File tree Expand file tree Collapse file tree 3 files changed +5
-5
lines changed
src/compressed_tensors/quantization
tests/test_quantization/test_utils Expand file tree Collapse file tree 3 files changed +5
-5
lines changed Original file line number Diff line number Diff line change 21
21
DynamicType ,
22
22
QuantizationArgs ,
23
23
QuantizationStrategy ,
24
- QuantizationType ,
25
24
round_to_quantized_type ,
26
25
)
27
26
from compressed_tensors .quantization .quant_config import QuantizationStatus
@@ -405,7 +404,7 @@ def _quantize(
405
404
406
405
# if a global scale is optionally provided, use it
407
406
# to further scale the local `scale` parameter
408
- if global_scale :
407
+ if global_scale is not None :
409
408
scale = scale .to (global_scale .dtype ) / global_scale
410
409
411
410
scaled = x / scale
@@ -438,7 +437,7 @@ def _dequantize(
438
437
439
438
# if a global scale is optionally provided, use it
440
439
# to further scale the local `scale` parameter
441
- if global_scale :
440
+ if global_scale is not None :
442
441
scale = scale .to (global_scale .dtype ) / global_scale
443
442
444
443
dequant_value = x_q .to (scale .dtype )
Original file line number Diff line number Diff line change @@ -110,6 +110,7 @@ def calculate_qparams(
110
110
else :
111
111
scales = max_val_pos / (float (bit_range ) / 2 )
112
112
113
+ # TODO: in the case of MoEs, the global_scale may also be 0 and may need to be clamped
113
114
if scales .dtype == FP8_E4M3_DATA .dtype :
114
115
# torch.clamp not supported for FP8
115
116
# use the next largest fp8 value from 0
@@ -495,4 +496,4 @@ def generate_gparam(
495
496
max_vals = torch .max (updated_max_val , torch .zeros_like (updated_max_val ))
496
497
max_val_pos = torch .max (torch .abs (min_vals ), torch .abs (max_vals ))
497
498
global_scale = scale_data .max * quant_data .max / max_val_pos
498
- return global_scale .to (dtype )
499
+ return global_scale .to (dtype ). reshape ([ 1 ])
Original file line number Diff line number Diff line change @@ -70,6 +70,6 @@ def test_fused_global_scales():
70
70
min_val , max_val = torch .aminmax (layer .weight )
71
71
global_scale = generate_gparam (min_val .data , max_val .data )
72
72
# max value should equal (448 * 6) / global_scale
73
- assert max_tensor_value == pytest .approx (
73
+ assert max_tensor_value . item () == pytest .approx (
74
74
FP4_E2M1_DATA .max * FP8_E4M3_DATA .max / global_scale , abs = 0.001
75
75
)
You can’t perform that action at this time.
0 commit comments