src/compressed_tensors/quantization/utils
1 file changed: +4 -3 lines

@@ -88,14 +88,15 @@ def calculate_qparams(
         ):
             assert global_scale is not None
             scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)  # Not needed
-            # scales = scales / global_scale
             scales = scales.to(FP8_E4M3_DATA.dtype)
         else:
             # Divide over bit range over max value?
             scales = max_val_pos / (float(bit_range) / 2)

-            # TODO: clamp not implemented for FP8 '
-            # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+            # TODO: clamp not implemented for FP8 - we shouldn't need to clamp this anyway as we're
+            # casting to FP8 on line 92?
+            if scales.dtype != FP8_E4M3_DATA.dtype:
+                scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
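
For readers skimming the hunk, here is a minimal, self-contained sketch of the scale logic after this change. It is an illustration under assumptions, not the library's code: torch.float8_e4m3fn stands in for FP8_E4M3_DATA.dtype, 6.0 stands in for FP4_E2M1_DATA.max (the largest FP4 E2M1 value), and the bit_range value is only illustrative.

import torch

# Hypothetical stand-in for FP4_E2M1_DATA.max: E2M1 encodes +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}.
FP4_E2M1_MAX = 6.0

def compute_scales(max_val_pos, bit_range, global_scale=None):
    if global_scale is not None:
        # FP4 branch of the hunk: fold the global scale in, then store scales as FP8.
        scales = global_scale * (max_val_pos / FP4_E2M1_MAX)
        scales = scales.to(torch.float8_e4m3fn)
    else:
        # Integer branch: symmetric scale over half the bit range.
        scales = max_val_pos / (float(bit_range) / 2)

    # The guard this diff adds: skip clamping for FP8 scales (per the TODO above,
    # torch.clamp is not implemented for FP8); otherwise floor the scales at
    # float32 eps so an all-zero channel (max_val_pos == 0) cannot produce a
    # zero scale and a division by zero later during quantization.
    if scales.dtype != torch.float8_e4m3fn:
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
    return scales

max_val_pos = torch.tensor([0.0, 1.5, 3.0])        # per-channel max magnitudes
print(compute_scales(max_val_pos, bit_range=254))  # illustrative 8-bit range; first entry is eps, not 0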