@@ -175,10 +175,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
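
As a sanity check (not part of the patch), here is a minimal standalone sketch showing that the new torch.exp2 expression computes the same values as the removed torch.pow construction; the tensor names mirror to_mx but are filled with toy values, and the device handling is simplified (the original uses scale_e8m0_biased.device):

import torch

# Toy stand-ins for the tensors used in to_mx; values are arbitrary.
scale_e8m0_unbiased = torch.tensor([-3.0, 0.0, 5.0])
max_abs = torch.ones_like(scale_e8m0_unbiased)

# Old formulation: materialize a tensor of twos, then raise it elementwise.
old_scale_fp = torch.pow(
    torch.full(max_abs.size(), 2.0, device=scale_e8m0_unbiased.device),
    scale_e8m0_unbiased,
)

# New formulation: a single elementwise 2**x op, cast to float32.
new_scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)

assert torch.equal(old_scale_fp, new_scale_fp)  # both compute 2 ** scale_e8m0_unbiased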
@@ -233,14 +230,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
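
For reference, a small usage sketch of the rewritten get_fp_scale (again not part of the patch), assuming the module's e8m0 constants have their standard values (E8M0_EXPONENT_BIAS = 127, E8M0_EXPONENT_NAN_VAL = 255) and adding an explicit return for readability; it decodes a few biased e8m0 exponents into float scales, with the 255 encoding mapping to NaN:

import torch

E8M0_EXPONENT_BIAS = 127     # assumed: standard e8m0 exponent bias
E8M0_EXPONENT_NAN_VAL = 255  # assumed: e8m0 encoding reserved for NaN

def get_fp_scale(scale_e8m0):
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # 2**s_offset; the integer input is promoted to float32 by exp2
    s_fp = torch.exp2(s_offset)
    # If a block exponent was 255, set values of that block to NaN
    s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
    return s_fp

scales = get_fp_scale(torch.tensor([127, 130, 255], dtype=torch.uint8))
print(scales)  # tensor([1., 8., nan]) -- 2**0, 2**3, and NaN for the 255 encoding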