
Commit 2ca3016

enable torch.compile for mxfp8_cublas recipe (#1841)
1 parent ce05b3f commit 2ca3016
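The practical effect of this change is that the cuBLAS-backed mxfp8 gemm can now run under torch.compile: torchinductor does not yet understand torch.float8_e8m0fnu scales, so the e8m0 cast is hidden inside a custom op that accepts the scales as uint8. The sketch below condenses the new test added in this commit (see the diffs further down); it assumes a CUDA capability 10.0 GPU and a PyTorch build that has torch.float8_e8m0fnu, and the mxfp8_mm helper name is illustrative only.

import torch

from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

# Scales cross the torch.compile boundary as raw uint8 bits; the custom op
# views them back to float8_e8m0fnu internally, out of inductor's sight.
M, K, N, BLOCK_SIZE = 128, 256, 512, 32
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

def mxfp8_mm(a, b, a_scale_u8, b_scale_u8):
    # `b` must already be column-major here, as torch._scaled_mm requires
    return _scaled_mm_with_uint8_scales(
        a, b, a_scale_u8, b_scale_u8, out_dtype=torch.bfloat16
    )

mxfp8_mm = torch.compile(mxfp8_mm)
out = mxfp8_mm(a, b.t(), a_scale.view(torch.uint8), b_scale.view(torch.uint8))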

2 files changed: +103 -10 lines

test/prototype/mx_formats/test_mx_linear.py: 70 additions & 7 deletions

@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn

+from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -24,14 +25,14 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -169,11 +170,18 @@ def test_activation_checkpointing():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
 )
 @pytest.mark.parametrize(
     "recipe_name",
-    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+    [
+        "mxfp8_emulated",
+        "mxfp4_emulated",
+        "mxfp8_cublas",
+        "mxfp8_cutlass",
+        "mxfp4_cutlass",
+    ],
 )
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -186,13 +194,13 @@ def test_linear_compile(recipe_name, bias):
     if not is_sm_at_least_89():
         pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

-    if recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
         if not is_sm_at_least_100():
             pytest.skip("CUDA capability >= 10.0 required for MX gemms")

-    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
         # TODO(future PR): fix this, things are clearly broken with bias=True
-        pytest.skip("this test is broken for cutlass recipes with bias=True")
+        pytest.skip("this test is broken for non-emulated recipes with bias=True")

     M, K, N = 128, 256, 512
     input_shape = (M, K)
@@ -285,6 +293,61 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MX gemms require CUDA capability 10.0",
+)
+def test_scaled_mm_wrapper():
+    # today, e8m0 isn't supported in torchinductor or triton
+    # for now, work around this by creating a wrapper around torch._scaled_mm
+    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
+    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
+
+    M, K, N = 128, 256, 512
+    BLOCK_SIZE = 32
+    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
+    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
+
+    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+
+    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
+
+    def wrapped(a, b, a_scale, b_scale, out_dtype):
+        if is_row_major(b.stride()):
+            b = b.t().contiguous().t()
+        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
+        return res
+
+    wrapped = torch.compile(wrapped)
+
+    # correct memory format of `b`
+    out2 = wrapped(
+        a,
+        b.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out2, atol=0, rtol=0)
+
+    # incorrect memory format of `b`
+    b_col_major = b.t().contiguous().t()
+    out3 = wrapped(
+        a,
+        b_col_major.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out3, atol=0, rtol=0)
+
+
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
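The wrapped helper in the new test guards on is_row_major(b.stride()) because torch._scaled_mm requires its second operand in column-major layout. The sketch below illustrates the stride-based distinction with a hypothetical looks_row_major helper (a simplified stand-in, not torchao's is_row_major):

import torch

def looks_row_major(stride):
    # for a 2-D tensor, row-major (contiguous) strides look like (ncols, 1)
    return stride[0] > stride[1]

b = torch.randn(512, 256)
print(b.stride(), looks_row_major(b.stride()))  # (256, 1) True: row-major
b_col_major = b.t().contiguous().t()            # same values, column-major layout
print(b_col_major.stride(), looks_row_major(b_col_major.stride()))  # (1, 512) False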

torchao/prototype/mx_formats/mx_ops.py: 33 additions & 3 deletions

@@ -35,11 +35,41 @@
     tensor_size_hpx3_to_fp6x4,
 )
 from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 aten = torch.ops.aten

 MX_OPS_TABLE: Dict[Any, Any] = {}

+if TORCH_VERSION_AT_LEAST_2_5:
+
+    @torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
+    def _scaled_mm_with_uint8_scales(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        a_scale: torch.Tensor,
+        b_scale: torch.Tensor,
+        out_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """
+        Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
+        work around the lack of support for `torch.float8_e8m0fnu` in
+        torchinductor. We do so by hiding the cast of scales to e8m0 inside a
+        custom op.
+        """
+        # cast back to e8m0 where torchinductor can't see it
+        a_scale = a_scale.view(torch.float8_e8m0fnu)
+        b_scale = b_scale.view(torch.float8_e8m0fnu)
+        res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
+        return res
+
+    @_scaled_mm_with_uint8_scales.register_fake
+    def _(a, b, a_scale, b_scale, out_dtype):
+        m, k = a.shape
+        k2, n = b.shape
+        res = torch.empty(m, n, dtype=out_dtype, device=a.device)
+        return res
+

 def implements(aten_ops):
     """Register aten ops to the mx op table"""
@@ -89,11 +119,11 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
+            res = _scaled_mm_with_uint8_scales(
                 a._data,
                 b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
+                a_scale_block,
+                b_scale_block,
                 out_dtype=torch.bfloat16,
             )
         else:
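Because torch.float8_e8m0fnu and torch.uint8 are both one byte wide, the .view() calls inside the custom op are zero-copy bitwise reinterpretations, so passing the block scales across the op boundary as uint8 loses no information. A minimal standalone check, assuming a PyTorch build that has torch.float8_e8m0fnu:

import torch

# e8m0 stores only an 8-bit exponent with bias 127, so the byte 127 encodes 2**0 = 1.0.
s_u8 = torch.full((4, 8), 127, dtype=torch.uint8)
s_e8m0 = s_u8.view(torch.float8_e8m0fnu)  # zero-copy dtype reinterpretation

assert s_e8m0.dtype == torch.float8_e8m0fnu
assert torch.equal(s_e8m0.view(torch.uint8), s_u8)  # round-trips bit-exactly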
