16
16
import torch
17
17
from compressed_tensors .transform import (
18
18
TransformArgs ,
19
+ TransformConfig ,
19
20
TransformFactory ,
20
21
TransformScheme ,
21
22
)
22
23
from compressed_tensors .utils import offloaded_dispatch
23
24
from tests .testing_utils import requires_accelerate , requires_gpu
24
25
25
26
26
- class TransformableModel (torch .nn .Module ):
27
- def __init__ (self , * sizes ):
28
- super ().__init__ ()
29
- self .fcs = torch .nn .ModuleList ([])
30
- self .fcs .append (torch .nn .Linear (sizes [0 ], sizes [1 ], bias = False ))
31
- for index in range (1 , len (sizes ) - 1 ):
32
- self .fcs .append (torch .nn .Linear (sizes [index ], sizes [index + 1 ], bias = False ))
27
+ def scheme_kwargs ():
28
+ all_types = TransformFactory .registered_names ()
29
+ base = [{"type" : type } for type in all_types ]
30
+ randomized = [{"type" : type , "randomize" : True } for type in all_types ]
31
+ return base + randomized
33
32
34
- def forward (self , x ):
35
- for layer in self .fcs :
36
- x = layer (x )
37
- return x
38
33
39
-
40
- @pytest .mark .parametrize (
41
- "scheme" ,
42
- [TransformScheme (type = name ) for name in TransformFactory .registered_names ()],
43
- )
44
- def test_correctness_linear (scheme ):
34
+ @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
35
+ def test_correctness_linear (scheme_kwargs ):
45
36
size = (4 , 8 )
46
37
module = torch .nn .Linear (* size , bias = True )
38
+ scheme = TransformScheme (** scheme_kwargs )
47
39
factory = TransformFactory .from_scheme (scheme , name = "" )
48
40
49
41
input_tfm = factory .create_transform (
@@ -67,50 +59,39 @@ def test_correctness_linear(scheme):
67
59
assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
68
60
69
61
70
- @pytest .mark .parametrize (
71
- "scheme" ,
72
- [TransformScheme (type = name ) for name in TransformFactory .registered_names ()],
73
- )
74
- def test_correctness_model (scheme , offload = False ):
62
+ @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
63
+ def test_correctness_model (scheme_kwargs , model_apply , offload = False ):
75
64
# load model
76
- model = TransformableModel ( 2 , 4 , 8 , 16 , 32 , 64 )
65
+ model = model_apply [ 0 ]
77
66
if offload :
78
67
model = offloaded_dispatch (model , torch .device ("cuda" ))
79
68
80
- # create factory
81
- scheme .apply = [
82
- # weight output -> input
83
- TransformArgs (targets = "fcs.0" , location = "weight_output" ),
84
- TransformArgs (targets = "fcs.1" , location = "input" , inverse = True ),
85
- # output -> weight input
86
- TransformArgs (targets = "fcs.1" , location = "output" ),
87
- TransformArgs (targets = "fcs.2" , location = "weight_input" , inverse = True ),
88
- # output -> input
89
- TransformArgs (targets = "fcs.2" , location = "output" ),
90
- TransformArgs (targets = "fcs.3" , location = "input" , inverse = True ),
91
- # weight output -> weight input
92
- TransformArgs (targets = "fcs.3" , location = "weight_output" ),
93
- TransformArgs (targets = "fcs.4" , location = "weight_input" , inverse = True ),
94
- ]
95
- factory = TransformFactory .from_scheme (scheme , name = "" )
96
-
97
- # create inputs
69
+ # get output
98
70
input = torch .rand ((17 , model .fcs [0 ].in_features ))
99
71
if offload :
100
72
input = input .to (torch .device ("cuda" ))
73
+ true_output = model (input )
74
+
75
+ # apply transforms
76
+ config = TransformConfig (
77
+ config_groups = {
78
+ "" : TransformScheme (
79
+ ** scheme_kwargs ,
80
+ apply = model_apply [1 ],
81
+ )
82
+ }
83
+ )
84
+ for name , scheme in config .config_groups .items ():
85
+ factory = TransformFactory .from_scheme (scheme , name = name )
86
+ factory .apply_to_model (model )
101
87
102
88
# compare outputs
103
- true_output = model (input )
104
- factory .apply_to_model (model )
105
89
output = model (input )
106
90
assert torch .allclose (true_output , output , atol = 1e-5 , rtol = 0.0 )
107
91
108
92
109
93
@requires_gpu
110
94
@requires_accelerate ()
111
- @pytest .mark .parametrize (
112
- "scheme" ,
113
- [TransformScheme (type = name ) for name in TransformFactory .registered_names ()],
114
- )
115
- def test_correctness_model_offload (scheme ):
116
- test_correctness_model (scheme , offload = True )
95
+ @pytest .mark .parametrize ("scheme_kwargs" , scheme_kwargs ())
96
+ def test_correctness_model_offload (scheme_kwargs , model_apply ):
97
+ test_correctness_model (scheme_kwargs , model_apply , offload = True )
0 commit comments