Skip to content

Commit 779956f

Browse files
committed
cleanup 2
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 92ddea9 commit 779956f

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

tests/test_transform/factory/test_memory.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
from tests.testing_utils import requires_accelerate, requires_gpu
2727

2828

29-
_test_schemes = [
30-
TransformScheme(type=name) for name in TransformFactory.registered_names()
31-
] + [
32-
TransformScheme(type=name, randomize=True)
33-
for name in TransformFactory.registered_names()
34-
]
29+
def all_schemes():
30+
base = [TransformScheme(type=name) for name in TransformFactory.registered_names()]
31+
randomized = [
32+
TransformScheme(type=name, randomize=True)
33+
for name in TransformFactory.registered_names()
34+
]
35+
return base + randomized
3536

3637

3738
class TransformableModel(torch.nn.Module):
@@ -48,7 +49,7 @@ def forward(self, x):
4849
return x
4950

5051

51-
@pytest.mark.parametrize("scheme", _test_schemes)
52+
@pytest.mark.parametrize("scheme", all_schemes())
5253
def test_memory_sharing(scheme, offload=False):
5354
# load scheme and factory
5455
scheme = TransformScheme(
@@ -98,12 +99,12 @@ def test_memory_sharing(scheme, offload=False):
9899

99100
@requires_gpu
100101
@requires_accelerate()
101-
@pytest.mark.parametrize("scheme", _test_schemes)
102+
@pytest.mark.parametrize("scheme", all_schemes())
102103
def test_memory_sharing_offload(scheme):
103104
test_memory_sharing(scheme, offload=True)
104105

105106

106-
@pytest.mark.parametrize("scheme", _test_schemes)
107+
@pytest.mark.parametrize("scheme", all_schemes())
107108
def test_memory_sharing_training(scheme):
108109
scheme.requires_grad = True
109110
test_memory_sharing(scheme, offload=False)

0 commit comments

Comments
 (0)