2 files changed in src/compressed_tensors/quantization (+6 -9 lines changed)

@@ -380,7 +380,7 @@ def _quantize(
 ) -> torch.Tensor:
 
     if global_scale:
-        scale = scale.to(global_scale.dtype) * global_scale
+        scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
     if zero_point is not None:
@@ -409,7 +409,7 @@ def _dequantize(
 ) -> torch.Tensor:
 
     if global_scale:
-        scale = scale.to(global_scale.dtype) * global_scale
+        scale = scale.to(global_scale.dtype) / global_scale
 
     dequant_value = x_q.to(scale.dtype)
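Both hunks above flip the same operation: the stored per-group scale is now divided by global_scale rather than multiplied, which pairs with the calculate_qparams change below where global_scale is folded into the stored scale. The following is a minimal standalone sketch of that convention, not the library's _quantize/_dequantize: the helper names and the made-up global_scale value are hypothetical, and round-to-integer stands in for the real FP4 E2M1 rounding; only the scale handling mirrors the diff.

import torch

def quantize_sketch(x, scale, global_scale=None, q_min=-6.0, q_max=6.0):
    # The stored scale already carries global_scale, so divide it back out first
    if global_scale is not None:
        scale = scale.to(global_scale.dtype) / global_scale
    # Crude round-to-integer stand-in for real FP4 E2M1 rounding
    return torch.clamp(torch.round(x / scale), q_min, q_max)

def dequantize_sketch(x_q, scale, global_scale=None):
    # Same rescaling as on the quantize side
    if global_scale is not None:
        scale = scale.to(global_scale.dtype) / global_scale
    return x_q.to(scale.dtype) * scale

x = torch.randn(4, 16)
global_scale = torch.tensor(100.0)
stored_scale = global_scale * (x.abs().amax() / 6.0)  # per-group scale with the global factor folded in
x_q = quantize_sketch(x, stored_scale, global_scale)
x_hat = dequantize_sketch(x_q, stored_scale, global_scale)
print((x - x_hat).abs().max())  # small round-trip error

Because the stored scale is global_scale * (amax / FP4_max), dividing by global_scale at quantize and dequantize time recovers the same effective per-group step as before; only where the global factor lives changes.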
@@ -87,17 +87,14 @@ def calculate_qparams(
             and quantization_args.type == QuantizationType.FLOAT
         ):
             assert global_scale is not None
-            breakpoint()
-            scales = max_val_pos / FP4_E2M1_DATA.max  # Not needed
-            scales = scales / global_scale
-            scales = scales.to(FP8_E4M3_DATA.dtype)  # .to(torch.float32)
-
+            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)  # Not needed
+            # scales = scales / global_scale
+            scales = scales.to(FP8_E4M3_DATA.dtype)
         else:
             # Divide over bit range over max value?
-            scales = max_val_pos / (float(bit_range) / 2)
+            scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: clamp not implemented for FP8 '
-        breakpoint()
         # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
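Numerically, the new calculate_qparams path stores global_scale * (amax / FP4_E2M1_DATA.max) cast to FP8 E4M3, and the forward-pass division above undoes that factor. Below is a rough sketch under the assumption that FP4_E2M1_DATA.max is 6.0 and FP8_E4M3_DATA.dtype is torch.float8_e4m3fn; both constants and group_scales_sketch are illustrative stand-ins, not the library's actual values or function.

import torch

FP4_MAX = 6.0                    # stand-in for FP4_E2M1_DATA.max
FP8_DTYPE = torch.float8_e4m3fn  # stand-in for FP8_E4M3_DATA.dtype

def group_scales_sketch(max_val_pos, global_scale):
    # Fold global_scale in before the FP8 cast so the per-group scales land in
    # a range FP8 E4M3 can represent
    scales = global_scale * (max_val_pos / FP4_MAX)
    return scales.to(FP8_DTYPE)

max_val_pos = torch.tensor([0.02, 1.5, 12.0])  # per-group max magnitudes
global_scale = torch.tensor(100.0)

stored = group_scales_sketch(max_val_pos, global_scale)
# _quantize/_dequantize above divide the stored scale by global_scale again,
# recovering roughly max_val_pos / FP4_MAX per group (up to FP8 rounding)
effective = stored.to(global_scale.dtype) / global_scale
print(stored, effective)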