
Commit ed361ff
[Reland] ROCm CI (Infra + Skips) (#1581)
This PR skips the unit tests that currently fail on ROCm and adds the infra changes needed to enable ROCm CI. **NOTE:** This PR enables ROCm CI testing for torchao _only for pushes to the main branch_. The ROCm tests should start showing up here once this PR is merged: https://hud.pytorch.org/hud/pytorch/ao/main/1?per_page=50&name_filter=regression

Torchao PRs can also trigger the ROCm CI runs using the `ciflow/rocm` PR label (#1749). Enabling ROCm CI testing on *all* torchao PRs will be done in a follow-up PR.

This pull request introduces the `skip_if_rocm` decorator across various test files to skip tests that are not yet supported on ROCm. The changes ensure that tests are conditionally skipped when ROCm is detected, improving the test suite's compatibility across environments.

# Key changes include:

### Cherry-pick ROCm CI infra changes from #999

### Configure the workflow to trigger ROCm CI only on pushes to the main branch, or on PRs carrying the `ciflow/rocm` label

### Introduction of the `skip_if_rocm` decorator

* Added the `skip_if_rocm` import in multiple test files to conditionally skip tests not supported on ROCm: `test/dtypes/test_affine_quantized.py`, `test/dtypes/test_floatx.py`, `test/float8/test_base.py`, `test/hqq/test_hqq_affine.py`, `test/integration/test_integration.py`, `test/kernel/test_galore_downproj.py`, `test/prototype/test_awq.py`, `test/prototype/test_low_bit_optim.py`, `test/prototype/test_splitk.py`, `test/quantization/test_galore_quant.py`, `test/quantization/test_marlin_qqq.py`, `test/sparsity/test_marlin.py`, `test/test_ops.py`, `test/test_s8s4_linear_cutlass.py`, `torchao/utils.py`

### Application of the `skip_if_rocm` decorator

* Applied `@skip_if_rocm("ROCm development in progress")` to multiple test functions to skip them when running on ROCm: `test/dtypes/test_affine_quantized.py`, `test/dtypes/test_floatx.py`, `test/float8/test_base.py`, `test/hqq/test_hqq_affine.py`, `test/integration/test_integration.py`, `test/kernel/test_galore_downproj.py`, `test/prototype/test_awq.py`, `test/prototype/test_low_bit_optim.py`, `test/prototype/test_splitk.py`, `test/quantization/test_galore_quant.py`, `test/quantization/test_marlin_qqq.py`, `test/sparsity/test_marlin.py`

### Module-level skips for ROCm

* Added module-level skips in `test/test_ops.py` and `test/test_s8s4_linear_cutlass.py` so that every test in those modules is skipped when ROCm is detected.
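The decorator itself lives in `torchao/utils.py`, whose hunk is not shown in this extract. A minimal sketch of the pattern, assuming only the behavior the tests rely on (the actual torchao implementation may differ in its details):

```python
import functools

import pytest
import torch


def skip_if_rocm(message=None):
    """Skip the decorated test when running on a ROCm build of PyTorch.

    torch.version.hip is a version string on ROCm builds and None on
    CUDA/CPU builds, so it doubles as a ROCm detector.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if torch.version.hip is not None:
                pytest.skip(message or "Skipping the test in ROCm")
            return func(*args, **kwargs)

        return wrapper

    return decorator
```

Applied as `@skip_if_rocm("ROCm enablement in progress")`, this skips the test at call time on ROCm and is a no-op everywhere else.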
1 parent 878ec7a commit ed361ff

25 files changed: +153 -13 lines changed
Lines changed: 49 additions & 0 deletions

```yaml
name: Run Regression Tests on ROCm

on:
  push:
    branches:
      - main
    tags:
      - ciflow/rocm/*

concurrency:
  group: regression_test-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
  cancel-in-progress: true

env:
  HF_TOKEN: ${{ secrets.HF_TOKEN }}

jobs:
  test-nightly:
    strategy:
      fail-fast: false
      matrix:
        include:
          - name: ROCM Nightly
            runs-on: linux.rocm.gpu.torchao
            torch-spec: '--pre torch==2.7.0.dev20250122 --index-url https://download.pytorch.org/whl/nightly/rocm6.3'
            gpu-arch-type: "rocm"
            gpu-arch-version: "6.3"

    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    with:
      timeout: 120
      no-sudo: ${{ matrix.gpu-arch-type == 'rocm' }}
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}
      submodules: recursive
      script: |
        conda create -n venv python=3.9 -y
        conda activate venv
        python -m pip install --upgrade pip
        pip install ${{ matrix.torch-spec }}
        pip install -r dev-requirements.txt
        pip install .
        export CONDA=$(dirname $(dirname $(which conda)))
        export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH
        pytest test --verbose -s
```
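For context on the `tags: ciflow/rocm/*` trigger above: applying the `ciflow/rocm` label to a PR causes PyTorch's CI bot to push a tag matching that pattern, which is presumably how the label mentioned in the commit message maps onto this workflow. The same trigger can be exercised by hand; a sketch, with `1234` standing in for a hypothetical PR number:

```sh
# Pushing any tag that matches ciflow/rocm/* fires the workflow's `tags` trigger.
git tag ciflow/rocm/1234
git push origin ciflow/rocm/1234
```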

test/dtypes/test_affine_quantized.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -25,6 +25,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
     is_sm_at_least_89,
+    skip_if_rocm,
 )

 is_cusparselt_available = (
@@ -104,6 +105,7 @@ def test_tensor_core_layout_transpose(self):
         "apply_quant",
         get_quantization_functions(is_cusparselt_available, True, "cuda", True),
     )
+    @skip_if_rocm("ROCm enablement in progress")
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         if isinstance(apply_quant, AOBaseConfig):
@@ -196,6 +198,7 @@ def apply_uint6_weight_only_quant(linear):
         "apply_quant", get_quantization_functions(is_cusparselt_available, True)
     )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_print_quantized_module(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         if isinstance(apply_quant, AOBaseConfig):
@@ -213,6 +216,7 @@ class TestAffineQuantizedBasic(TestCase):

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_flatten_unflatten(self, device, dtype):
         if device == "cuda" and dtype == torch.bfloat16 and is_fbcode():
             raise unittest.SkipTest("TODO: Failing for cuda + bfloat16 in fbcode")
```

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 import unittest

+import pytest
 import torch
 from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
@@ -27,6 +28,9 @@
 except ModuleNotFoundError:
     has_gemlite = False

+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+

 class TestAffineQuantizedTensorParallel(DTensorTestBase):
     """Basic test case for tensor subclasses"""
```

test/dtypes/test_floatx.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -27,7 +27,7 @@
     fpx_weight_only,
     quantize_,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm

 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
@@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"
```

test/dtypes/test_nf4.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -33,6 +33,7 @@
     nf4_weight_only,
     to_nf4,
 )
+from torchao.utils import skip_if_rocm

 bnb_available = False

@@ -111,6 +112,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):

     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65C1-L81C47
@@ -133,6 +135,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):

     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):
         """
```

test/dtypes/test_uint4.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -28,7 +28,7 @@
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


 def _apply_weight_only_uint4_quant(model):
@@ -92,6 +92,7 @@ def test_basic_tensor_ops(self):
         # only test locally
         # print("x:", x[0])

+    @skip_if_rocm("ROCm enablement in progress")
     def test_gpu_quant(self):
         for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
             x = torch.randn(*x_shape)
@@ -104,6 +105,7 @@ def test_gpu_quant(self):
         # make sure it runs
         opt(x)

+    @skip_if_rocm("ROCm enablement in progress")
     def test_pt2e_quant(self):
         from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
             QuantizationConfig,
```

test/float8/test_base.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -18,6 +18,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    skip_if_rocm,
 )

 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -426,6 +427,7 @@ def test_linear_from_config_params(
    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
    @pytest.mark.parametrize("linear_bias", [True, False])
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm enablement in progress")
    def test_linear_from_recipe(
        self,
        recipe_name,
```

test/float8/test_float8_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 import torch

 from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm

 if not TORCH_VERSION_AT_LEAST_2_5:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -30,6 +30,7 @@
         # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
     ],
 )
+@skip_if_rocm("ROCm enablement in progress")
 def test_round_scale_down_to_power_of_2_valid_inputs(
     test_case: dict,
 ):
```

test/float8/test_fsdp2/test_fsdp2.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -43,6 +43,9 @@
 if not is_sm_at_least_89():
     pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)

+if torch.version.hip is not None:
+    pytest.skip("ROCm enablement in progress", allow_module_level=True)
+

 class TestFloat8Common:
     def broadcast_module(self, module: nn.Module) -> None:
```

test/hqq/test_hqq_affine.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -11,6 +11,7 @@
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
+    skip_if_rocm,
 )

 cuda_available = torch.cuda.is_available()
@@ -109,6 +110,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )

+    @skip_if_rocm("ROCm enablement in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
```