26
26
from tests .testing_utils import requires_accelerate , requires_gpu
27
27
28
28
29
- _test_schemes = [
30
- TransformScheme (type = name ) for name in TransformFactory .registered_names ()
31
- ] + [
32
- TransformScheme (type = name , randomize = True )
33
- for name in TransformFactory .registered_names ()
34
- ]
29
+ def all_schemes ():
30
+ base = [TransformScheme (type = name ) for name in TransformFactory .registered_names ()]
31
+ randomized = [
32
+ TransformScheme (type = name , randomize = True )
33
+ for name in TransformFactory .registered_names ()
34
+ ]
35
+ return base + randomized
35
36
36
37
37
38
class TransformableModel (torch .nn .Module ):
@@ -48,7 +49,7 @@ def forward(self, x):
48
49
return x
49
50
50
51
51
- @pytest .mark .parametrize ("scheme" , _test_schemes )
52
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
52
53
def test_memory_sharing (scheme , offload = False ):
53
54
# load scheme and factory
54
55
scheme = TransformScheme (
@@ -98,12 +99,12 @@ def test_memory_sharing(scheme, offload=False):
98
99
99
100
@requires_gpu
100
101
@requires_accelerate ()
101
- @pytest .mark .parametrize ("scheme" , _test_schemes )
102
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
102
103
def test_memory_sharing_offload (scheme ):
103
104
test_memory_sharing (scheme , offload = True )
104
105
105
106
106
- @pytest .mark .parametrize ("scheme" , _test_schemes )
107
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
107
108
def test_memory_sharing_training (scheme ):
108
109
scheme .requires_grad = True
109
110
test_memory_sharing (scheme , offload = False )
0 commit comments