Skip to content

Commit c72ebc6

Browse files
authored
move decorators to testing/utils.py (#1761)
* move decorators to testing/utils.py
* add import
* fix import
* fix ruff formatting error
* ruff fixes
* ruff format
* compute_capability test
* update
* update rest of tests
* fix ruff
1 parent ed361ff commit c72ebc6

19 files changed

+67
-63
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
quantize_,
2121
)
2222
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
23+
from torchao.testing.utils import skip_if_rocm
2324
from torchao.utils import (
2425
TORCH_VERSION_AT_LEAST_2_5,
2526
TORCH_VERSION_AT_LEAST_2_6,
2627
is_fbcode,
2728
is_sm_at_least_89,
28-
skip_if_rocm,
2929
)
3030

3131
is_cusparselt_available = (

test/dtypes/test_floatx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
fpx_weight_only,
2828
quantize_,
2929
)
30-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm
30+
from torchao.testing.utils import skip_if_rocm
31+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
3132

3233
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
3334
_Floatx_DTYPES = [(3, 2), (2, 2)]

test/dtypes/test_nf4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
nf4_weight_only,
3434
to_nf4,
3535
)
36-
from torchao.utils import skip_if_rocm
36+
from torchao.testing.utils import skip_if_rocm
3737

3838
bnb_available = False
3939

test/dtypes/test_uint4.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from torchao.quantization.quant_api import (
2929
_replace_with_custom_fn_if_matches_filter,
3030
)
31-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm
31+
from torchao.testing.utils import skip_if_rocm
32+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
3233

3334

3435
def _apply_weight_only_uint4_quant(model):

test/float8/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
import torch
1515
import torch.nn as nn
1616

17+
from torchao.testing.utils import skip_if_rocm
1718
from torchao.utils import (
1819
TORCH_VERSION_AT_LEAST_2_5,
1920
is_sm_at_least_89,
2021
is_sm_at_least_90,
21-
skip_if_rocm,
2222
)
2323

2424
if not TORCH_VERSION_AT_LEAST_2_5:

test/float8/test_float8_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch
55

66
from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
7-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm
7+
from torchao.testing.utils import skip_if_rocm
8+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
89

910
if not TORCH_VERSION_AT_LEAST_2_5:
1011
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

test/hqq/test_hqq_affine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
quantize_,
1010
uintx_weight_only,
1111
)
12+
from torchao.testing.utils import skip_if_rocm
1213
from torchao.utils import (
1314
TORCH_VERSION_AT_LEAST_2_3,
14-
skip_if_rocm,
1515
)
1616

1717
cuda_available = torch.cuda.is_available()

test/integration/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from torchao.quantization.utils import (
7777
compute_error as SQNR,
7878
)
79+
from torchao.testing.utils import skip_if_rocm
7980
from torchao.utils import (
8081
TORCH_VERSION_AT_LEAST_2_3,
8182
TORCH_VERSION_AT_LEAST_2_4,
@@ -85,7 +86,6 @@
8586
benchmark_model,
8687
is_fbcode,
8788
is_sm_at_least_90,
88-
skip_if_rocm,
8989
unwrap_tensor_subclass,
9090
)
9191

test/kernel/test_fused_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from galore_test_utils import get_kernel, make_copy, make_data
1313

14-
from torchao.utils import skip_if_rocm
14+
from torchao.testing.utils import skip_if_rocm
1515

1616
torch.manual_seed(0)
1717
MAX_DIFF_no_tf32 = 1e-5

test/kernel/test_galore_downproj.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
1313
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
14-
from torchao.utils import skip_if_rocm
14+
from torchao.testing.utils import skip_if_rocm
1515

1616
torch.manual_seed(0)
1717

0 commit comments

Comments (0)