25
25
from tests .testing_utils import requires_accelerate , requires_gpu
26
26
27
27
28
- def scheme_kwargs ():
29
- all_types = TransformFactory .registered_names ()
30
- base = [{"type" : type } for type in all_types ]
31
- randomized = [{"type" : type , "randomize" : True } for type in all_types ]
32
- return base + randomized
33
-
34
-
35
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
36
- def test_correctness_linear (scheme_kwargs ):
28
+ @pytest .mark .parametrize ("type" , TransformFactory .registered_names ())
29
+ @pytest .mark .parametrize ("randomized" , (True , False ))
30
+ def test_correctness_linear (type , randomized ):
37
31
size = (4 , 8 )
38
32
module = torch .nn .Linear (* size , bias = True )
39
- scheme = TransformScheme (** scheme_kwargs )
33
+ scheme = TransformScheme (type = type , randomized = randomized )
40
34
factory = TransformFactory .from_scheme (scheme , name = "" )
41
35
42
36
input_tfm = factory .create_transform (
@@ -60,8 +54,9 @@ def test_correctness_linear(scheme_kwargs):
60
54
assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
61
55
62
56
63
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
64
- def test_correctness_model (scheme_kwargs , model_apply , offload = False ):
57
+ @pytest .mark .parametrize ("type" , TransformFactory .registered_names ())
58
+ @pytest .mark .parametrize ("randomized" , (True , False ))
59
+ def test_correctness_model (type , randomized , model_apply , offload = False ):
65
60
# load model
66
61
model = model_apply [0 ]
67
62
if offload :
@@ -76,10 +71,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
76
71
# apply transforms
77
72
config = TransformConfig (
78
73
config_groups = {
79
- "" : TransformScheme (
80
- ** scheme_kwargs ,
81
- apply = model_apply [1 ],
82
- )
74
+ "" : TransformScheme (type = type , randomized = randomized , apply = model_apply [1 ])
83
75
}
84
76
)
85
77
apply_transform_config (model , config )
@@ -91,6 +83,7 @@ def test_correctness_model(scheme_kwargs, model_apply, offload=False):
91
83
92
84
@requires_gpu
93
85
@requires_accelerate ()
94
- @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
95
- def test_correctness_model_offload (scheme_kwargs , model_apply ):
96
- test_correctness_model (scheme_kwargs , model_apply , offload = True )
86
+ @pytest .mark .parametrize ("type" , TransformFactory .registered_names ())
87
+ @pytest .mark .parametrize ("randomized" , (True , False ))
88
+ def test_correctness_model_offload (type , randomized , model_apply ):
89
+ test_correctness_model (type , randomized , model_apply , offload = True )
0 commit comments