
Commit be30822

fix condition

1 parent e8c6c8f

File tree

1 file changed (+4, -3)

  • src/compressed_tensors/quantization/utils/helpers.py

src/compressed_tensors/quantization/utils/helpers.py (+4, -3)
@@ -88,14 +88,15 @@ def calculate_qparams(
         ):
             assert global_scale is not None
             scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)  # Not needed
-            # scales = scales / global_scale
             scales = scales.to(FP8_E4M3_DATA.dtype)
         else:
             # Divide over bit range over max value?
             scales = max_val_pos / (float(bit_range) / 2)
 
-        # TODO: clamp not implemented for FP8 '
-        # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # TODO: clamp not implemented for FP8 - we shouldn't need to clamp this anyway as we're
+        # casting to FP8 on line 92?
+        if scales.dtype != FP8_E4M3_DATA.dtype:
+            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
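
For context, a minimal sketch of the behaviour the fixed condition gives: scales are clamped to a positive epsilon only while they are still in a higher-precision dtype, since (per the TODO above) torch.clamp is not implemented for FP8 tensors and the FP4 branch has already cast its scales to FP8. FP8_DTYPE and FP4_E2M1_MAX below are hypothetical stand-ins for the library's FP8_E4M3_DATA.dtype and FP4_E2M1_DATA.max; this is an illustration under those assumptions, not the library code itself.

    import torch

    # Hypothetical stand-ins for compressed_tensors' FP8_E4M3_DATA / FP4_E2M1_DATA
    # (assumptions: E4M3 maps to torch.float8_e4m3fn, E2M1 max value is 6.0).
    FP8_DTYPE = torch.float8_e4m3fn
    FP4_E2M1_MAX = 6.0


    def clamp_scales(scales: torch.Tensor) -> torch.Tensor:
        # The fixed condition: only clamp when the scales are not already FP8,
        # because torch.clamp is reported as unimplemented for that dtype.
        if scales.dtype != FP8_DTYPE:
            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        return scales


    # FP4 branch: scales were already cast to FP8, so the clamp is skipped.
    max_val_pos = torch.tensor([0.5, 2.0])
    global_scale = torch.tensor(1.0)
    fp4_scales = (global_scale * (max_val_pos / FP4_E2M1_MAX)).to(FP8_DTYPE)
    print(clamp_scales(fp4_scales).dtype)  # torch.float8_e4m3fn, untouched

    # Other branch: float32 scales still get the clamp, so no scale is exactly zero.
    int_scales = torch.tensor([0.0, 0.1])
    print(clamp_scales(int_scales))  # smallest entry raised to float32 eps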
