Skip to content

Commit 720a177

Browse files
authored
unbreak CI by fixing MX tests (#2208)
Summary: #2201 broke CI: 1. some MX tests for fp4 are running on A10G instances, with skipping not being properly applied (https://hud.pytorch.org/pytorch/ao/commit/4bfd7c09ef4592eacbbf990aea6d6bda608865c1#42164784332-box) 2. some SQNR thresholds were to tight for fp4 (https://hud.pytorch.org/pytorch/ao/commit/4bfd7c09ef4592eacbbf990aea6d6bda608865c1#42164784332-box) This PR fixes both of these to get CI back to green (I hope). Note that I can't repro 1 locally, so we'll have to land and see if it works. Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
1 parent 58ac6c0 commit 720a177

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

test/prototype/mx_formats/test_mx_linear.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,6 @@ def test_inference_print_str():
396396
@pytest.mark.skipif(
397397
not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
398398
)
399-
@pytest.mark.skipif(not is_sm_at_least_100, reason="Reqs sm100")
400399
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
401400
@pytest.mark.parametrize("bias", [True, False])
402401
@pytest.mark.parametrize("compile", [True, False])
@@ -405,9 +404,14 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
405404
"""
406405
Smoke test for inference compile
407406
"""
407+
# TODO(future): figure out why these CUDA capability conditions are not properly
408+
# applied when inside `pytest.mark.skipif` for this test
408409
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
409410
if not is_sm_at_least_89():
410411
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
412+
elif elem_dtype == torch.float4_e2m1fn_x2:
413+
if not is_sm_at_least_100():
414+
pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
411415

412416
m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
413417
m_mx = copy.deepcopy(m)

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
6666
if elem_dtype is torch.float8_e4m3fn:
6767
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0)
6868
else:
69-
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0)
69+
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 13.0)
7070

7171

7272
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")

0 commit comments

Comments
 (0)