
Commit 9ef2f06

mx cast: remove clamping of output tensor for torch.compile path (#1911)
1 parent 44c5476 commit 9ef2f06

File tree

2 files changed: +64 -3 lines changed


test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 52 additions & 0 deletions
@@ -334,3 +334,55 @@ def test_to_mx_inductor_single_kernel():
     to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
     out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_cast_to_float8_e4m3fn_saturation_behavior():
+    # TODO(#1912): make the saturated cast work in eager mode and remove this
+    # test
+    max_val = torch.finfo(torch.float8_e4m3fn).max
+
+    # create example data inside the representable range
+    data_in_range_bf16 = torch.tensor(
+        [
+            max_val,
+            -1 * max_val,
+        ],
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+
+    # create example data outside the representable range
+    data_out_of_range_bf16 = torch.tensor(
+        [
+            max_val * 2,
+            -1 * (max_val * 2),
+        ],
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+
+    # verify that in eager mode PyTorch casting to float8 is unsaturated
+    data_in_range_f8 = data_in_range_bf16.to(torch.float8_e4m3fn)
+    data_out_of_range_f8 = data_out_of_range_bf16.to(torch.float8_e4m3fn)
+    assert not torch.any(torch.isnan(data_in_range_f8))
+    assert torch.all(torch.isnan(data_out_of_range_f8))
+
+    # verify that in triton, casting to float8 is saturated
+    # for simplicity, use torch.compile to generate triton code
+    def to_f8(x):
+        x = x.to(torch.float8_e4m3fn)
+        return x
+
+    to_f8_c = torch.compile(to_f8)
+    data_in_range_f8_c = to_f8_c(data_in_range_bf16)
+    data_out_of_range_f8_c = to_f8_c(data_out_of_range_bf16)
+    assert not torch.any(torch.isnan(data_in_range_f8_c))
+    assert not torch.any(torch.isnan(data_out_of_range_f8_c))
+    torch.testing.assert_close(
+        data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
+    )

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 12 additions & 3 deletions
@@ -228,9 +228,18 @@ def to_mx(
         max_pos = F4_E2M1_MAX
     else:
         raise AssertionError("unsupported")
-    data_lp = torch.clamp(
-        data_hp / scale_fp32.unsqueeze(1), min=-1 * max_pos, max=max_pos
-    )
+    data_lp = data_hp / scale_fp32.unsqueeze(1)
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+        and not torch._dynamo.is_compiling()
+    ):
+        # As of 20250317, the PyTorch eager mode cast to `torch.float8_e4m3fn`
+        # is unsaturated, while the same cast in triton is saturated. When
+        # compiling to triton, this clamp is therefore redundant, and removing
+        # it gives a speedup in compute bound cases.
+        # TODO(#1912): make the saturated cast work in eager mode and remove
+        # this workaround.
+        data_lp = torch.clamp(data_lp, min=-1 * max_pos, max=max_pos)

     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
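For context, a minimal sketch of exercising the two cast paths this commit distinguishes, mirroring the to_mx call pattern used in test_to_mx_inductor_single_kernel above. The import path follows the file location in this diff; the input shape and block_size value are illustrative assumptions, not part of this commit:

import torch

# assumption: MXTensor is importable from the module changed in this commit
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# illustrative inputs; shape and block_size are assumed values for this sketch
x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
block_size = 32

# eager path: to_mx still applies the explicit clamp, because the eager
# cast to torch.float8_e4m3fn is unsaturated
mx_eager = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size)

# compiled path: the triton cast emitted by inductor is saturated, so to_mx
# skips the now-redundant clamp
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
mx_compiled = to_mx_c(x, torch.float8_e4m3fn, block_size)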

0 commit comments
