
Commit d7ce8ec

[NVFP4] Small Nits (#341)
* small nits
* remove import; remove comment
* fix test
1 parent e554fba commit d7ce8ec

3 files changed: +5 −5 lines changed


src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 2 additions & 3 deletions
@@ -21,7 +21,6 @@
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -405,7 +404,7 @@ def _quantize(

     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

     scaled = x / scale
@@ -438,7 +437,7 @@ def _dequantize(

     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale

     dequant_value = x_q.to(scale.dtype)
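In `_quantize` and `_dequantize`, `global_scale` is an optional tensor, so truthiness and presence are not the same check: a provided scale that compares falsy, or that has more than one element, would make `if global_scale:` misbehave. A minimal standalone sketch of the difference (illustrative Python, not taken from the repo):

import torch

# a one-element scale tensor that happens to be 0 is "falsy", so
# `if global_scale:` would silently skip the rescaling branch
zero_scale = torch.tensor([0.0])
print(bool(zero_scale))            # False

# a multi-element tensor cannot be coerced to bool at all
try:
    bool(torch.tensor([1.0, 2.0]))
except RuntimeError as err:
    print(err)                     # "Boolean value of Tensor ... is ambiguous"

# `is not None` only tests whether a global scale was provided, which is
# the intent of the guard in _quantize/_dequantize
print(zero_scale is not None)      # True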

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 2 additions & 1 deletion
@@ -110,6 +110,7 @@ def calculate_qparams(
     else:
         scales = max_val_pos / (float(bit_range) / 2)

+    # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
     if scales.dtype == FP8_E4M3_DATA.dtype:
         # torch.clamp not supported for FP8
         # use the next largest fp8 value from 0
@@ -495,4 +496,4 @@ def generate_gparam(
     max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
-    return global_scale.to(dtype)
+    return global_scale.to(dtype).reshape([1])
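The appended `.reshape([1])` makes `generate_gparam` return a one-element tensor instead of a 0-dim scalar. A simplified, self-contained sketch of the tail of that function, using hypothetical stand-in constants for `FP8_E4M3_DATA.max` (448) and `FP4_E2M1_DATA.max` (6) and an assumed symmetric min-side clamp:

import torch

# hypothetical stand-ins for FP8_E4M3_DATA.max and FP4_E2M1_DATA.max
FP8_MAX = 448.0
FP4_MAX = 6.0

def generate_gparam_sketch(updated_min_val: torch.Tensor,
                           updated_max_val: torch.Tensor,
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # mirrors the tail of generate_gparam shown in the diff above
    min_vals = torch.min(updated_min_val, torch.zeros_like(updated_min_val))
    max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
    max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
    global_scale = FP8_MAX * FP4_MAX / max_val_pos
    # reshape([1]) turns the 0-dim result into a 1-element tensor so the
    # per-tensor global scale has a consistent shape downstream
    return global_scale.to(dtype).reshape([1])

gs = generate_gparam_sketch(torch.tensor(-3.0), torch.tensor(5.0))
print(gs, gs.shape)  # tensor([537.6000]) torch.Size([1])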

tests/test_quantization/test_utils/test_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -70,6 +70,6 @@ def test_fused_global_scales():
     min_val, max_val = torch.aminmax(layer.weight)
     global_scale = generate_gparam(min_val.data, max_val.data)
     # max value should be = (448 * 6) / global_scale
-    assert max_tensor_value == pytest.approx(
+    assert max_tensor_value.item() == pytest.approx(
         FP4_E2M1_DATA.max * FP8_E4M3_DATA.max / global_scale, abs=0.001
     )
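The test now compares via `.item()`, so `pytest.approx` sees a plain Python float rather than a 0-dim `torch.Tensor`. A small illustrative example of that pattern, with made-up values:

import pytest
import torch

weight = torch.tensor([[0.5, -2.0], [3.25, -1.0]])
max_tensor_value = torch.abs(weight).max()   # 0-dim tensor(3.2500)

# .item() converts the 0-dim tensor to a Python float, so the assertion is an
# ordinary float-vs-approx comparison rather than relying on Tensor.__eq__
assert max_tensor_value.item() == pytest.approx(3.25, abs=0.001)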

0 commit comments
