23
23
from tests .testing_utils import requires_accelerate , requires_gpu
24
24
25
25
26
- _test_schemes = [
27
- TransformScheme (type = name ) for name in TransformFactory .registered_names ()
28
- ] + [
29
- TransformScheme (type = name , randomize = True )
30
- for name in TransformFactory .registered_names ()
31
- ]
26
+ def all_schemes ():
27
+ base = [TransformScheme (type = name ) for name in TransformFactory .registered_names ()]
28
+ randomized = [
29
+ TransformScheme (type = name , randomize = True )
30
+ for name in TransformFactory .registered_names ()
31
+ ]
32
+ return base + randomized
32
33
33
34
34
35
class TransformableModel (torch .nn .Module ):
@@ -45,7 +46,7 @@ def forward(self, x):
45
46
return x
46
47
47
48
48
- @pytest .mark .parametrize ("scheme" , _test_schemes )
49
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
49
50
def test_correctness_linear (scheme ):
50
51
size = (4 , 8 )
51
52
module = torch .nn .Linear (* size , bias = True )
@@ -72,7 +73,7 @@ def test_correctness_linear(scheme):
72
73
assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
73
74
74
75
75
- @pytest .mark .parametrize ("scheme" , _test_schemes )
76
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
76
77
def test_correctness_model (scheme , offload = False ):
77
78
# load model
78
79
model = TransformableModel (2 , 4 , 8 , 16 , 32 , 64 )
@@ -110,6 +111,6 @@ def test_correctness_model(scheme, offload=False):
110
111
111
112
@requires_gpu
112
113
@requires_accelerate ()
113
- @pytest .mark .parametrize ("scheme" , _test_schemes )
114
+ @pytest .mark .parametrize ("scheme" , all_schemes () )
114
115
def test_correctness_model_offload (scheme ):
115
116
test_correctness_model (scheme , offload = True )
0 commit comments