Skip to content

Commit 92ddea9

Browse files
committed
cleanup
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f74fe3e commit 92ddea9

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tests/test_transform/factory/test_correctness.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
from tests.testing_utils import requires_accelerate, requires_gpu
2424

2525

26-
_test_schemes = [
27-
TransformScheme(type=name) for name in TransformFactory.registered_names()
28-
] + [
29-
TransformScheme(type=name, randomize=True)
30-
for name in TransformFactory.registered_names()
31-
]
26+
def all_schemes():
27+
base = [TransformScheme(type=name) for name in TransformFactory.registered_names()]
28+
randomized = [
29+
TransformScheme(type=name, randomize=True)
30+
for name in TransformFactory.registered_names()
31+
]
32+
return base + randomized
3233

3334

3435
class TransformableModel(torch.nn.Module):
@@ -45,7 +46,7 @@ def forward(self, x):
4546
return x
4647

4748

48-
@pytest.mark.parametrize("scheme", _test_schemes)
49+
@pytest.mark.parametrize("scheme", all_schemes())
4950
def test_correctness_linear(scheme):
5051
size = (4, 8)
5152
module = torch.nn.Linear(*size, bias=True)
@@ -72,7 +73,7 @@ def test_correctness_linear(scheme):
7273
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
7374

7475

75-
@pytest.mark.parametrize("scheme", _test_schemes)
76+
@pytest.mark.parametrize("scheme", all_schemes())
7677
def test_correctness_model(scheme, offload=False):
7778
# load model
7879
model = TransformableModel(2, 4, 8, 16, 32, 64)
@@ -110,6 +111,6 @@ def test_correctness_model(scheme, offload=False):
110111

111112
@requires_gpu
112113
@requires_accelerate()
113-
@pytest.mark.parametrize("scheme", _test_schemes)
114+
@pytest.mark.parametrize("scheme", all_schemes())
114115
def test_correctness_model_offload(scheme):
115116
test_correctness_model(scheme, offload=True)

0 commit comments

Comments
 (0)