Commit 44c5476

mx cast: torch.pow -> bit shifts (#1910)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent b1ea75c commit 44c5476

1 file changed: +3, -7

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 7 deletions
@@ -204,11 +204,7 @@ def to_mx(
     )
 
     # For now, calculate the scale in floating point.
-    # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp32 = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -217,7 +213,7 @@ def to_mx(
     # Note: it would be more correct to set the minimum to 2**-127, but this
     # does not work in triton either as it looks like subnormal value handling
     # has some gaps. So, for now just set to the minimum normal value.
-    scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)
+    scale_fp32 = torch.clamp(scale_fp32, min=F32_MIN_NORMAL)
 
     # scale and saturated cast the data elements to max of target dtype
     if elem_dtype == torch.float8_e4m3fn:
@@ -233,7 +229,7 @@ def to_mx(
     else:
         raise AssertionError("unsupported")
     data_lp = torch.clamp(
-        data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
+        data_hp / scale_fp32.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
 
     # cast to target dtype
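
For context on why the bit shift matches the old torch.pow path: an e8m0 scale stores only a biased exponent, so writing that byte into the exponent field of a float32 (bits 23..30) reproduces 2**unbiased_exponent directly, without materializing a tensor of 2.0s. The sketch below is illustrative rather than taken from the torchao source; it assumes MBITS_F32 = 23 (the float32 mantissa width) and an exponent bias of 127, and the biased exponent values are made up for the demonstration.

import torch

# Illustrative sketch (not the torchao implementation): shows that shifting a
# biased e8m0 exponent into the float32 exponent field matches
# torch.pow(2.0, unbiased_exponent).
# Assumptions: MBITS_F32 = 23 (float32 mantissa bits), e8m0 bias = 127.
MBITS_F32 = 23
E8M0_EXPONENT_BIAS = 127

# Hypothetical biased exponents; 1 and 254 are the extremes of the normal range.
scale_e8m0_biased = torch.tensor([1, 100, 127, 130, 254], dtype=torch.uint8)
scale_e8m0_unbiased = scale_e8m0_biased.to(torch.int32) - E8M0_EXPONENT_BIAS

# Old approach: build a tensor of 2.0 and raise it to the unbiased exponent.
scale_pow = torch.pow(
    torch.full(scale_e8m0_biased.shape, 2.0),
    scale_e8m0_unbiased,
)

# New approach: place the biased exponent into bits 23..30 of an int32
# (sign = 0, mantissa = 0) and reinterpret the bits as float32.
scale_shift = (scale_e8m0_biased.to(torch.int32) << MBITS_F32).view(torch.float32)

assert torch.equal(scale_pow, scale_shift)

# Note: a biased exponent of 0 bit-casts to 0.0, which sits below the float32
# normal range; the surrounding kernel code clamps the scale to F32_MIN_NORMAL
# (see the diff above), so that case does not reach the division.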
