compute_max_diff,
)

- if torch.version.hip is not None:
-     pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+ IS_CUDA = torch.cuda.is_available() and torch.version.cuda
+ IS_ROCM = torch.cuda.is_available() and torch.version.hip

try:
    import torchao.ops
@@ -58,7 +58,7 @@ def _create_floatx_inputs(
fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -88,7 +88,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
test_utils=test_utils,
)

- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
@@ -278,7 +278,7 @@ def make_test_id(param):
return f"tiles_{param}"


- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -296,7 +296,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):


# TODO: Fix "test_aot_dispatch_dynamic" test failure
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -342,7 +342,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
return dq.reshape(n, k)


- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
    "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -410,7 +410,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(


# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
    "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -476,7 +476,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
assert diff_op_ao < 1e-1


- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
    "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -587,7 +587,7 @@ def reshape_w(w):
)


- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
    MARLIN_TEST_PARAMS,
@@ -677,7 +677,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
)


- @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
    MARLIN_TEST_PARAMS,
@@ -756,5 +756,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
)


+ @pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
+ def test_swizzle_mm():
+     test_utils = [
+         "test_schema",
+         "test_autograd_registration",
+         "test_faketensor",
+     ]
+
+     # TODO: Figure out why test fails unless torch >= 2.5
+     if TORCH_VERSION_AT_LEAST_2_5:
+         test_utils.append("test_aot_dispatch_dynamic")
+
+     mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
+     mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")
+
+     opcheck(
+         torch.ops.torchao.swizzle_mm,
+         (mat1, mat2, False, False),
+         test_utils=test_utils,
+     )
+
+
if __name__ == "__main__":
    pytest.main(sys.argv)
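For reference: the opcheck calls in this test file use torch.library.opcheck, which validates a custom operator's schema, fake-tensor (meta) implementation, and dispatch registrations against sample inputs. Below is a minimal, self-contained sketch of that pattern against a hypothetical mylib::scaled_add operator (not part of torchao or this PR), assuming PyTorch >= 2.4 for torch.library.custom_op and opcheck.

# Hedged sketch only: mylib::scaled_add is a made-up op used to illustrate the
# opcheck pattern exercised by test_swizzle_mm above; it is not torchao code.
import torch
from torch.library import custom_op, opcheck


@custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor:
    # Eager implementation of the custom op.
    return a + scale * b


@scaled_add.register_fake
def _(a, b, scale):
    # Fake (meta) implementation: shape/dtype propagation only, no data.
    return torch.empty_like(a)


def test_scaled_add_opcheck():
    sample_args = (torch.randn(16, 32), torch.randn(16, 32), 2.0)
    opcheck(
        torch.ops.mylib.scaled_add,
        sample_args,
        test_utils=["test_schema", "test_faketensor"],
    )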