 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
+    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
...
     F8E4M3_MAX_POW2,
     F8E5M2_MAX,
     F8E5M2_MAX_POW2,
+    F32_EXP_BIAS,
     F32_MIN_NORMAL,
     SUPPORTED_ELEM_DTYPES,
 )
...
 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
+EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
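For context (not part of the diff): the new `EBITS_BF16`/`MBITS_BF16` constants describe bfloat16's 1-8-7 sign/exponent/mantissa layout. bfloat16 shares float32's 8 exponent bits and bias of 127 but keeps only 7 mantissa bits, so its bit pattern can be manipulated through an int16 view the same way float32 is through int32. A minimal standalone check, with the constants redefined locally for illustration:

```python
import torch

# Standalone sanity check, not part of the diff: constants are redefined locally.
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_BF16, MBITS_BF16 = 8, 7
F32_EXP_BIAS = BF16_EXP_BIAS = 127  # bias = 2 ** (EBITS - 1) - 1

x = torch.tensor([6.0])  # 6.0 = 1.5 * 2**2, so the unbiased exponent is 2

# float32: reinterpret the bits as int32, shift off the 23 mantissa bits, unbias
exp_f32 = ((x.view(torch.int32) >> MBITS_F32) & 0xFF) - F32_EXP_BIAS
# bfloat16: reinterpret the bits as int16, shift off the 7 mantissa bits, unbias
exp_bf16 = ((x.to(torch.bfloat16).view(torch.int16) >> MBITS_BF16) & 0xFF) - BF16_EXP_BIAS

assert exp_f32.item() == exp_bf16.item() == 2
```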
@@ -141,28 +144,51 @@ def to_mx(
     else:
         raise AssertionError("unsupported element dtype")

+    if data_hp.dtype is torch.float32:
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS
+    else:
+        assert data_hp.dtype is torch.bfloat16
+        hp_int_dtype = torch.int16
+        hp_mbits = MBITS_BF16
+        hp_ebits = EBITS_BF16
+        hp_exp_bias = BF16_EXP_BIAS
+
     # rounding before calculating the largest power of 2
     # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
     if scaling_mode == ScaleCalculationMode.EVEN:
         nan_mask = torch.isnan(max_abs)
-        max_abs = max_abs.to(torch.float32).view(torch.int32)
-        val_to_add = 1 << (MBITS_F32 - mbits - 1)
-        mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
+        max_abs = max_abs.view(hp_int_dtype)
+        val_to_add = 1 << (hp_mbits - mbits - 1)
+        mask = ((1 << (hp_ebits + SBITS)) - 1) << hp_mbits
         max_abs = (max_abs + val_to_add) & mask
-        max_abs = max_abs.view(torch.float32)
-        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
+        max_abs = max_abs.view(data_hp.dtype)
+        max_abs[nan_mask] = torch.tensor(
+            float("nan"), device=max_abs.device, dtype=max_abs.dtype
+        )

     # Calculate the scale for different modes
+    max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
+    extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
+    extracted_pow2 = extracted_pow2.to(data_hp.dtype)
+
     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
-        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
+        scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
     elif scaling_mode == ScaleCalculationMode.CEIL:
-        scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
+        # round up: add one to scale if the mantissa is larger than 0
+        # 0x7FFFFF is equal to 23 ones
+        mantissa_gt_one = (max_abs_int32 & 0x7FFFFF) > 0
+        extracted_pow2 += mantissa_gt_one
+        scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
     else:
         raise AssertionError("unsupported scaling calculation mode")

     # Clamp to exponents that can be represented in e8m0
+    # add one to positive range to capture NaNs
     scale_e8m0_unbiased = torch.clamp(
-        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
+        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
     )

     # Create the biased e8m0 representation and cast it to 8 bits
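The core change in this hunk: `floor(log2(max_abs))` and `ceil(log2(max_abs))` are replaced by reading the exponent field straight out of the floating-point bit pattern, with the CEIL mode rounding up whenever any mantissa bit is set. A standalone sketch of the idea for float32 inputs (local names, not the torchao API; the real code first adds `eps` to `max_abs` so an all-zero block does not hit the denormal range):

```python
import torch

# Standalone sketch: reading the exponent field out of the float32 bit pattern
# reproduces floor(log2(x)) for positive normal values, and adding one whenever
# any mantissa bit is set (x is not an exact power of two) reproduces ceil(log2(x)).
MBITS_F32, F32_EXP_BIAS = 23, 127

x = torch.tensor([0.1, 1.0, 3.0, 4.0, 1000.0])
bits = x.view(torch.int32)

extracted_pow2 = ((bits >> MBITS_F32) & 0xFF) - F32_EXP_BIAS
assert torch.equal(extracted_pow2.float(), torch.floor(torch.log2(x)))  # FLOOR / EVEN

mantissa_nonzero = (bits & 0x7FFFFF) > 0  # 0x7FFFFF masks the 23 mantissa bits
assert torch.equal(
    (extracted_pow2 + mantissa_nonzero).float(), torch.ceil(torch.log2(x))
)  # CEIL
```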
@@ -172,7 +198,7 @@ def to_mx(
     # Conversion to torch.uint8 sets NaN values to 0, fix this by
     # explicitly setting known NaN values to 255
     scale_e8m0_biased = torch.where(
-        torch.isnan(scale_e8m0_unbiased),
+        torch.isnan(max_abs),
         E8M0_EXPONENT_NAN_VAL,
         scale_e8m0_biased,
     )
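Why the NaN check moves from `scale_e8m0_unbiased` to `max_abs`: with the bit-extraction path the unbiased scale is computed from integer exponent fields, so it is finite even when `max_abs` is NaN (all eight exponent bits set gives 255 - 127 = 128, which the widened clamp keeps). A small sketch with the relevant constants defined locally, mirroring but not importing the torchao names:

```python
import torch

# Local sketch: a NaN max_abs yields a finite unbiased scale under the new path,
# so the old torch.isnan(scale_e8m0_unbiased) check would never fire.
MBITS_F32, F32_EXP_BIAS = 23, 127
E8M0_EXPONENT_BIAS, E8M0_EXPONENT_NAN_VAL = 127, 255

max_abs = torch.tensor([float("nan")])
bits = max_abs.view(torch.int32)

# NaN has all exponent bits set, so extraction yields 255 - 127 = 128, not NaN
scale_e8m0_unbiased = ((bits >> MBITS_F32) & 0xFF).float() - F32_EXP_BIAS
assert not torch.isnan(scale_e8m0_unbiased).any()

# the clamp upper bound of E8M0_EXPONENT_BIAS + 1 keeps 128 representable, and the
# explicit check on max_abs routes it to the e8m0 NaN encoding of 255
scale_e8m0_unbiased = torch.clamp(
    scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
)
scale_e8m0_biased = (scale_e8m0_unbiased + E8M0_EXPONENT_BIAS).to(torch.uint8)
scale_e8m0_biased = torch.where(
    torch.isnan(max_abs), E8M0_EXPONENT_NAN_VAL, scale_e8m0_biased
)
assert scale_e8m0_biased.item() == E8M0_EXPONENT_NAN_VAL
```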