@@ -204,11 +204,7 @@ def to_mx(
    )

    # For now, calculate the scale in floating point.
-   # TODO(future) audit if there is a need to bit shift exponents instead.
-   scale_fp = torch.pow(
-       torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-       scale_e8m0_unbiased,
-   )
+   scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32)

    # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
    # float32 denormal range. For now, manually adjust the fp scale. This is
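The added line above replaces a `torch.pow` call with a bit-level construction of the scale: shifting the biased e8m0 exponent left past the float32 mantissa bits (MBITS_F32) places it directly in the float32 exponent field. Below is a minimal standalone sketch of the equivalence, using assumed stand-in values for MBITS_F32 and E8M0_EXPONENT_BIAS, which the real code imports from its constants module:

```python
import torch

# Stand-ins for constants the diff imports elsewhere (assumed values).
MBITS_F32 = 23            # number of mantissa bits in float32
E8M0_EXPONENT_BIAS = 127  # e8m0 exponent bias

# Example biased e8m0 exponents, stored as uint8 as in the diff.
scale_e8m0_biased = torch.tensor([100, 127, 130], dtype=torch.uint8)

# New path: shift the biased exponent into the float32 exponent field and
# reinterpret the resulting int32 bit pattern as float32.
scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32)

# Old path: 2 ** unbiased_exponent computed in floating point.
scale_e8m0_unbiased = scale_e8m0_biased.to(torch.float32) - E8M0_EXPONENT_BIAS
scale_pow = torch.pow(
    torch.full_like(scale_e8m0_unbiased, 2.0), scale_e8m0_unbiased
)

# Both give [2**-27, 2**0, 2**3]; powers of two are exact in float32.
assert torch.equal(scale_fp32, scale_pow)
```

One edge case of the bit-shift path: a biased exponent of 0 reinterprets to 0.0 rather than 2**-127, which the clamp to F32_MIN_NORMAL below papers over anyway.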
@@ -217,7 +213,7 @@ def to_mx(
    # Note: it would be more correct to set the minimum to 2**-127, but this
    # does not work in triton either as it looks like subnormal value handling
    # has some gaps. So, for now just set to the minimum normal value.
-   scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)
+   scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

    # scale and saturated cast the data elements to max of target dtype
    if elem_dtype == torch.float8_e4m3fn:
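For reference on why the clamp uses F32_MIN_NORMAL rather than 2**-127: the latter is a float32 denormal, and per the comments above denormal handling in compile/inductor/triton has gaps. A quick eager-mode check, assuming F32_MIN_NORMAL refers to the smallest normal float32 value (2**-126):

```python
import torch

# 2**-127 sits below the float32 normal range, so it is a denormal value.
denormal = torch.tensor(2.0 ** -127, dtype=torch.float32)

# Smallest positive *normal* float32 value, 2**-126; presumably what
# F32_MIN_NORMAL in the diff refers to.
min_normal = torch.finfo(torch.float32).tiny

print(denormal.item())  # ~5.877e-39, still nonzero in eager mode
print(min_normal)       # ~1.175e-38
```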
@@ -233,7 +229,7 @@ def to_mx(
    else:
        raise AssertionError("unsupported")
    data_lp = torch.clamp(
-       data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
+       data_hp / scale_fp32.unsqueeze(1), min=-1 * max_pos, max=max_pos
    )

    # cast to target dtype
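Putting the changed pieces together, here is a minimal sketch of the scale-and-saturate step with toy inputs, under assumed values for F32_MIN_NORMAL and the float8_e4m3fn maximum (448.0); the final cast line is an assumption based on the trailing "# cast to target dtype" comment, not code shown in the diff:

```python
import torch

# Assumed values for the constants referenced in the diff.
F32_MIN_NORMAL = 2 ** -126   # smallest normal float32 value
F8E4M3_MAX_POS = 448.0       # max representable float8_e4m3fn value

# Toy inputs: blocks of high-precision data plus one fp32 scale per block.
data_hp = torch.randn(4, 32) * 1000.0
scale_fp32 = torch.tensor([2.0, 0.5, 0.0, 8.0])

# Avoid zero/denormal scales before dividing (see the comments in the diff).
scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)

# Scale each block, then saturate to the element dtype's range so that
# out-of-range values clip instead of overflowing to inf during the cast.
max_pos = F8E4M3_MAX_POS
data_lp = torch.clamp(
    data_hp / scale_fp32.unsqueeze(1), min=-1 * max_pos, max=max_pos
)

# cast to target dtype (assumed continuation of the code after this hunk)
data_lp = data_lp.to(torch.float8_e4m3fn)
```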