Commit 14eff10

mx: support inference_mode and rank 3+ (#3238)
Update [ghstack-poisoned]
1 parent f303f4c · commit 14eff10

2 files changed: +39 −2 lines

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 18 additions & 2 deletions
```diff
@@ -71,11 +71,20 @@ def cuda_kernel_profiler(kernel_pattern):
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
 @pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("use_inference_mode", [True, False])
+@pytest.mark.parametrize("x_rank", [2, 3])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
+def test_inference_workflow_mx(
+    elem_dtype,
+    bias: bool,
+    compile: bool,
+    emulate: bool,
+    use_inference_mode: bool,
+    x_rank: int,
+):
     """
     Smoke test for inference compile
     """
@@ -112,8 +121,15 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
         m_mx = torch.compile(m_mx, fullgraph=True)
 
     x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    if x_rank == 3:
+        x = x.unsqueeze(0)
+
     y_ref = m(x)
-    y_mx = m_mx(x)
+    if use_inference_mode:
+        with torch.inference_mode():
+            y_mx = m_mx(x)
+    else:
+        y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
     SQNR_THRESHOLD = 25.0 if elem_dtype == torch.float8_e4m3fn else 15.0
     assert sqnr >= SQNR_THRESHOLD, (
```
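Why the test needs a dedicated `use_inference_mode` case: under `torch.no_grad()`, `F.linear` is still decomposed through the autograd layer into `aten.addmm`/`aten.mm` before reaching a tensor subclass's `__torch_dispatch__`, but under `torch.inference_mode()` the autograd dispatch keys are skipped for inference tensors, so `aten.linear.default` can arrive at the subclass directly; that is what the new `mx_linear` handler in the second file covers. A minimal sketch (the `OpLogger` mode is hypothetical, not part of this commit) that makes the difference observable:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode


class OpLogger(TorchDispatchMode):
    """Print each aten op that reaches the Python dispatch layer."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))


lin = torch.nn.Linear(32, 64)
x = torch.randn(128, 32)

with torch.no_grad(), OpLogger():
    lin(x)  # linear is decomposed: expect aten.addmm.default here

with torch.inference_mode(), OpLogger():
    lin(x)  # autograd keys skipped: aten.linear.default can show up directly
```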

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -775,6 +775,27 @@ def mx_addmm(func, types, args, kwargs):
     return _addmm_mx_dispatch(a, b, func, bias=bias)
 
 
+@implements([aten.linear.default])
+def mx_linear(func, types, args, kwargs):
+    assert isinstance(args[0], torch.Tensor) and isinstance(args[1], MXTensor)
+    a = args[0]
+
+    # make a 2d
+    orig_a_shape = a.shape
+    a_2d = a.view(-1, orig_a_shape[-1])
+
+    b = args[1].t()
+    if len(args) > 2:
+        bias = args[2]
+        res = _addmm_mx_dispatch(a_2d, b, aten.addmm.default, bias)
+    else:
+        res = _addmm_mx_dispatch(a_2d, b, aten.mm.default)
+
+    # reshape back to original shape
+    res = res.view(*orig_a_shape[:-1], res.shape[-1])
+    return res
+
+
 @implements([aten.t.default])
 def mx_t(func, types, args, kwargs):
     # For now, only transpose(input, 0, 1) is supported.
```
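The rank-3+ support in `mx_linear` rests on a standard identity: applying a linear layer to an N-d input is equivalent to flattening all leading dims into one, doing the 2-d matmul, and restoring the leading dims on the output. A standalone sketch with plain tensors (hypothetical shapes, not from the commit) that checks this equivalence:

```python
import torch

linear = torch.nn.Linear(32, 64, bias=True)
x = torch.randn(4, 128, 32)  # rank-3 input: (batch, seq, in_features)

y_ref = linear(x)

# mirror mx_linear: flatten leading dims, 2-d addmm, restore leading dims
orig_shape = x.shape
x_2d = x.view(-1, orig_shape[-1])                       # (4 * 128, 32)
y_2d = torch.addmm(linear.bias, x_2d, linear.weight.t())
y = y_2d.view(*orig_shape[:-1], y_2d.shape[-1])         # (4, 128, 64)

torch.testing.assert_close(y, y_ref)
```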
