Commit 4a5ab2d

enable compile with mxfp8 and mxfp4 cutlass gemm (#1838)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent fa5f45e commit 4a5ab2d

File tree

3 files changed: 31 additions, 21 deletions

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 19 additions & 18 deletions
@@ -171,30 +171,36 @@ def test_activation_checkpointing():
 @pytest.mark.skipif(
     is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
 )
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+@pytest.mark.parametrize(
+    "recipe_name",
+    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+)
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-@pytest.mark.parametrize(
-    "use_autocast",
-    [
-        False,
-    ],
-)
-def test_linear_compile(elem_dtype, bias, use_autocast):
+def test_linear_compile(recipe_name, bias):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if recipe_name in ["mxfp8_emulated", "mxfp8_cutlass"]:
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
-    M, K, N = 4, 8, 6
+
+    if recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+        if not is_sm_at_least_100():
+            pytest.skip("CUDA capability >= 10.0 required for MX gemms")
+
+    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+        # TODO(future PR): fix this, things are clearly broken with bias=True
+        pytest.skip("this test is broken for cutlass recipes with bias=True")
+
+    M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    config = MXLinearConfig.from_recipe_name(recipe_name)
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -203,13 +209,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")
 
-    if use_autocast:
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            y_ref = m_mx(x_ref)
-            y = m_mx_c(x)
-    else:
-        y_ref = m_mx(x_ref)
-        y = m_mx_c(x)
+    y_ref = m_mx(x_ref)
+    y = m_mx_c(x)
     torch.testing.assert_close(y_ref, y, atol=0, rtol=0)
 
     y_ref.backward(g)
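
For reference, a minimal sketch of the usage pattern the updated test now covers: build a model, apply an MX recipe by name, and compile it. This is not part of the diff; the recipe name, shapes, and import paths are taken from the test at this commit, the bfloat16 dtype choice is an assumption, and the cutlass recipes require a CUDA device with compute capability >= 10.0.

import copy

import torch
import torch.nn as nn

# Import paths as used by the test in this commit; they may move between torchao versions.
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

M, K, N = 128, 256, 512
m = nn.Sequential(nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16))

# Build the config from a named recipe instead of passing elem_dtype/block_size by hand.
config = MXLinearConfig.from_recipe_name("mxfp8_cutlass")
swap_linear_with_mx_linear(m, config=config)

# What this PR enables: the swapped model also runs under torch.compile,
# and compiled numerics match eager exactly (the test asserts atol=0, rtol=0).
m_c = torch.compile(copy.deepcopy(m), fullgraph=True, backend="inductor")

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
torch.testing.assert_close(m(x), m_c(x.clone()), atol=0, rtol=0)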

torchao/ops.py

Lines changed: 10 additions & 2 deletions
@@ -27,8 +27,16 @@
 lib.define(
     "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
-lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
-lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+# Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor
+# to honor the layout constraints for `b` in the two ops below.
+lib.define(
+    "mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)
+lib.define(
+    "mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor",
+    tags=[torch._C.Tag.needs_fixed_stride_order],
+)
 
 
 def register_custom_op(name):
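
The only behavioral change here is the `tags=[torch._C.Tag.needs_fixed_stride_order]` argument, which tells inductor to keep the argument layouts it sees at the op boundary rather than re-laying them out during compilation; that is what lets `mx_mm` rely on `b` staying column-major (see the assert added in mx_ops.py below). As a hedged illustration only, a toy op in a hypothetical namespace registered the same way could look like this (the `demo_mx` namespace and op name are made up and not part of torchao):

import torch

_lib = torch.library.Library("demo_mx", "FRAGMENT")  # hypothetical namespace for illustration

# The tag asks the compiler to preserve the input strides seen at the op boundary,
# instead of choosing layouts freely when fusing or reordering surrounding ops.
_lib.define(
    "gemm_col_major_b(Tensor a, Tensor b) -> Tensor",
    tags=[torch._C.Tag.needs_fixed_stride_order],
)


def _gemm_col_major_b_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # A real kernel (like mx_fp8_bf16) would require `b` to be column-major here;
    # without the tag, a compiled graph might hand it a different layout.
    assert b.t().is_contiguous()
    return a @ b


_lib.impl("gemm_col_major_b", _gemm_col_major_b_impl, "CUDA")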

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,6 @@
 import torch
 from torch.utils._pytree import tree_map
 
-# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
 import torchao.ops
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import DTYPE_FP4
@@ -73,7 +72,9 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._gemm_kernel_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS):
         # real MX gemm backed by torchao's CUTLASS kernels
         M, K, N = a.shape[0], a.shape[1], b.shape[1]
+        assert a._data.is_contiguous()
         assert b._data.t().is_contiguous()
+
         # TODO(future PR): use block_size instead of hardcoding 32
         a_scale = a._scale_e8m0.view(M, K // 32)
         b_scale = b._scale_e8m0.view(N, K // 32)
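
The new asserts encode the layout the CUTLASS-backed MX gemm expects: `a` row-major contiguous, `b` column-major (its transpose contiguous). Below is a small plain-tensor sketch of that layout, with shapes matching the block size of 32 hardcoded above; the tensors here are ordinary bf16 tensors, not MXTensor, and serve only to illustrate the asserted conditions.

import torch

M, K, N = 128, 256, 512

# Row-major activations: a.is_contiguous() holds.
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

# A typical (N, K) row-major weight, transposed into a (K, N) column-major view,
# which is exactly the b._data.t().is_contiguous() condition asserted above.
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
b = w.t()

assert a.is_contiguous()
assert b.t().is_contiguous() and not b.is_contiguous()

# With a hardcoded block size of 32, the e8m0 scales hold one scale per 32
# elements along K: (M, K // 32) for `a` and (N, K // 32) for `b`.
a_scale_shape = (M, K // 32)  # -> (128, 8)
b_scale_shape = (N, K // 32)  # -> (512, 8)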
