Commit b11b96a

swap operations
1 parent eec7bd3 commit b11b96a

2 files changed, +6 -9 lines changed

src/compressed_tensors/quantization/lifecycle/forward.py (+2 -2)

@@ -380,7 +380,7 @@ def _quantize(
 ) -> torch.Tensor:

     if global_scale:
-        scale = scale.to(global_scale.dtype) * global_scale
+        scale = scale.to(global_scale.dtype) / global_scale

     scaled = x / scale
     if zero_point is not None:
@@ -409,7 +409,7 @@ def _dequantize(
 ) -> torch.Tensor:

     if global_scale:
-        scale = scale.to(global_scale.dtype) * global_scale
+        scale = scale.to(global_scale.dtype) / global_scale

     dequant_value = x_q.to(scale.dtype)
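
With this change the per-group scale stored alongside the quantized values is divided by the global scale before it is applied, in both _quantize and _dequantize. A minimal sketch of that round trip in plain PyTorch, with a hypothetical helper name (fake_quant_roundtrip) and clamping to the FP4 E2M1 range standing in for real FP4 rounding:

import torch

def fake_quant_roundtrip(x, scale, global_scale, fp4_max=6.0):
    # The stored scale is defined relative to the global scale, so the global
    # scale is divided back out before use (the operation swapped in this commit).
    eff_scale = scale.to(global_scale.dtype) / global_scale

    # Quantize: divide by the effective scale and clamp to the FP4 E2M1 range
    # (actual FP4 quantization would also round to the E2M1 grid).
    x_q = torch.clamp(x / eff_scale, -fp4_max, fp4_max)

    # Dequantize: multiply back by the same effective scale.
    return x_q * eff_scale

x = torch.randn(4, 16)
global_scale = torch.tensor(448.0)                           # assumed global scale
scale = global_scale * x.abs().amax(-1, keepdim=True) / 6.0  # stored per-row scale
x_hat = fake_quant_roundtrip(x, scale, global_scale)
print((x - x_hat).abs().max())                               # near-zero without rounding

Dividing here is consistent with helpers.py below, where the stored scale is computed as global_scale * (max_val_pos / FP4_E2M1_DATA.max): the division recovers max_val_pos / FP4_E2M1_DATA.max as the effective step size.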

src/compressed_tensors/quantization/utils/helpers.py (+4 -7)

@@ -87,17 +87,14 @@ def calculate_qparams(
             and quantization_args.type == QuantizationType.FLOAT
         ):
             assert global_scale is not None
-            breakpoint()
-            scales = max_val_pos / FP4_E2M1_DATA.max  # Not needed
-            scales = scales / global_scale
-            scales = scales.to(FP8_E4M3_DATA.dtype)  # .to(torch.float32)
-
+            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)  # Not needed
+            # scales = scales / global_scale
+            scales = scales.to(FP8_E4M3_DATA.dtype)
         else:
             # Divide over bit range over max value?
-            scales = max_val_pos / (float(bit_range) / 2)
+            scales = max_val_pos / (float(bit_range) / 2)

         # TODO: clamp not implemented for FP8 '
-        breakpoint()
         # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
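
On the helpers.py side, the global scale is now folded into the per-group scale before it is cast down to FP8, instead of being divided out at this stage. A sketch of that idea, assuming FP4 E2M1 max = 6.0, FP8 E4M3 max = 448.0, a PyTorch build that provides torch.float8_e4m3fn, and a hypothetical compute_group_scales helper rather than the library's calculate_qparams signature:

import torch

FP4_E2M1_MAX = 6.0    # largest finite FP4 E2M1 value
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def compute_group_scales(max_val_pos, global_scale):
    # Premultiply by the global scale so small per-group scales keep precision
    # when stored in FP8; the global scale is divided back out at use time.
    scales = global_scale * (max_val_pos / FP4_E2M1_MAX)
    return scales.to(torch.float8_e4m3fn)

w = torch.randn(8, 64)
max_val_pos = w.abs().amax(dim=-1, keepdim=True)             # per-group absolute max
# One common choice of global scale: push the largest group scale near FP8 max.
global_scale = FP8_E4M3_MAX * FP4_E2M1_MAX / w.abs().max()
stored = compute_group_scales(max_val_pos, global_scale)     # what gets saved in FP8
effective = stored.to(torch.float32) / global_scale          # what _quantize applies

Premultiplying by a large global scale gives the per-group scales enough dynamic range to survive the FP8 E4M3 cast; the commented-out scales = scales / global_scale line marks where that division used to happen before this commit moved it into forward.py.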
