Commit b1ea75c

mx cast: torch.log2 -> bit shifts (#1908)
1 parent 152c23c commit b1ea75c
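
This commit replaces the torch.log2 calls in the MX scale calculation with direct bit manipulation of the floating-point representation. A minimal sketch of the underlying identity (illustrative, not code from this commit): for a positive normal float32 value, floor(log2(x)) is exactly the unbiased IEEE-754 exponent, which a bitcast and a shift extract without any transcendental op.

import torch

x = torch.tensor([0.75, 1.0, 3.5, 1024.0])
bits = x.view(torch.int32)          # reinterpret the float32 bits as int32
exp = ((bits >> 23) & 0xFF) - 127   # 23 mantissa bits, exponent bias 127
assert torch.equal(exp, torch.floor(torch.log2(x)).to(torch.int32))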

2 files changed: +36 -9 lines

torchao/prototype/mx_formats/constants.py
Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
 E8M0_EXPONENT_NAN_VAL = 255
 
 F32_EXP_BIAS = 127
+BF16_EXP_BIAS = 127
 F6_E2M3_EXP_BIAS = 1
 F6_E3M2_EXP_BIAS = 3
 F4_E2M1_EXP_BIAS = 1
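
BF16_EXP_BIAS is 127 because bfloat16 keeps float32's 8-bit exponent field and therefore the same bias. A quick illustrative check (a sketch, not code from this commit; bfloat16 has 7 mantissa bits):

import torch

one_bf16 = torch.tensor(1.0, dtype=torch.bfloat16).view(torch.int16)
# the biased exponent of 1.0 is exactly the bias
assert (one_bf16 >> 7) & 0xFF == 127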

torchao/prototype/mx_formats/mx_tensor.py
Lines changed: 35 additions & 9 deletions

@@ -23,6 +23,7 @@
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
+    BF16_EXP_BIAS,
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
@@ -39,6 +40,7 @@
     F8E4M3_MAX_POW2,
     F8E5M2_MAX,
     F8E5M2_MAX_POW2,
+    F32_EXP_BIAS,
     F32_MIN_NORMAL,
     SUPPORTED_ELEM_DTYPES,
 )
@@ -59,6 +61,7 @@
 
 # TODO(later): read from somewhere else?
 SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
+EBITS_BF16, MBITS_BF16 = 8, 7
 EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
@@ -141,28 +144,51 @@ def to_mx(
     else:
         raise AssertionError("unsupported element dtype")
 
+    if data_hp.dtype is torch.float32:
+        hp_int_dtype = torch.int32
+        hp_mbits = MBITS_F32
+        hp_ebits = EBITS_F32
+        hp_exp_bias = F32_EXP_BIAS
+    else:
+        assert data_hp.dtype is torch.bfloat16
+        hp_int_dtype = torch.int16
+        hp_mbits = MBITS_BF16
+        hp_ebits = EBITS_BF16
+        hp_exp_bias = BF16_EXP_BIAS
+
     # rounding before calculating the largest power of 2
     # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
     if scaling_mode == ScaleCalculationMode.EVEN:
         nan_mask = torch.isnan(max_abs)
-        max_abs = max_abs.to(torch.float32).view(torch.int32)
-        val_to_add = 1 << (MBITS_F32 - mbits - 1)
-        mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
+        max_abs = max_abs.view(hp_int_dtype)
+        val_to_add = 1 << (hp_mbits - mbits - 1)
+        mask = ((1 << (hp_ebits + SBITS)) - 1) << hp_mbits
         max_abs = (max_abs + val_to_add) & mask
-        max_abs = max_abs.view(torch.float32)
-        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
+        max_abs = max_abs.view(data_hp.dtype)
+        max_abs[nan_mask] = torch.tensor(
+            float("nan"), device=max_abs.device, dtype=max_abs.dtype
+        )
 
     # Calculate the scale for different modes
+    max_abs_int32 = (max_abs + eps).view(hp_int_dtype)
+    extracted_pow2 = ((max_abs_int32 >> hp_mbits) & 0b11111111) - hp_exp_bias
+    extracted_pow2 = extracted_pow2.to(data_hp.dtype)
+
     if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
-        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
+        scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
     elif scaling_mode == ScaleCalculationMode.CEIL:
-        scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
+        # round up: add one to scale if the mantissa is larger than 0
+        # 0x7FFFFF is equal to 23 ones
+        mantissa_gt_one = (max_abs_int32 & 0x7FFFFF) > 0
+        extracted_pow2 += mantissa_gt_one
+        scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
     else:
         raise AssertionError("unsupported scaling calculation mode")
 
     # Clamp to exponents that can be represented in e8m0
+    # add one to positive range to capture NaNs
     scale_e8m0_unbiased = torch.clamp(
-        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
+        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + 1
     )
 
     # Create the biased e8m0 representation and cast it to 8 bits
@@ -172,7 +198,7 @@ def to_mx(
     # Conversion to torch.uint8 sets NaN values to 0, fix this by
     # explicitly setting known NaN values to 255
     scale_e8m0_biased = torch.where(
-        torch.isnan(scale_e8m0_unbiased),
+        torch.isnan(max_abs),
         E8M0_EXPONENT_NAN_VAL,
         scale_e8m0_biased,
     )
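
Two of the bit tricks above, sketched in isolation (illustrative sketches assuming float32 input, so 23 mantissa bits and mantissa mask 0x7FFFFF; variable names are hypothetical). First, the CEIL branch: ceil(log2(x)) equals the extracted exponent plus one whenever any mantissa bit is set, i.e. whenever x is not an exact power of two.

import torch

x = torch.tensor([0.5, 1.0, 1.5, 4.0, 5.0])
bits = x.view(torch.int32)
pow2 = ((bits >> 23) & 0xFF) - 127               # floor(log2(x))
pow2 += ((bits & 0x7FFFFF) > 0).to(pow2.dtype)   # +1 unless x is a power of two
assert torch.equal(pow2, torch.ceil(torch.log2(x)).to(torch.int32))

Second, the EVEN-mode pre-rounding: adding half a ULP of the target element format and then masking off all mantissa bits rounds the mantissa at the target precision before the exponent is read out, so a value whose rounded mantissa overflows carries into the exponent.

import torch

MBITS_F32, EBITS_F32, SBITS = 23, 8, 1
mbits = 3  # mantissa bits of a hypothetical target element format
x = torch.tensor([1.9375])
bits = x.view(torch.int32)
val_to_add = 1 << (MBITS_F32 - mbits - 1)             # half a target-format ULP
mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32  # keep sign + exponent only
rounded = ((bits + val_to_add) & mask).view(torch.float32)
assert rounded.item() == 2.0  # 1.9375 rounds up past the power-of-two boundary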
