35
35
"change_linear_weights_to_int8_dqtensors" ,
36
36
"change_linear_weights_to_int8_woqtensors" ,
37
37
"change_linear_weights_to_int4_woqtensors" ,
38
+ "swap_conv2d_1x1_to_linear"
38
39
]
39
40
40
41
@@ -45,19 +46,17 @@ def _replace_with_custom_fn_if_matches_filter(
45
46
For each `child` in `model`, replaces it with `replacement_fn(child)`
46
47
if `filter_fn(child)` is `True`
47
48
"""
48
- name_to_child = dict (model .named_children ())
49
- for name , child in name_to_child .items ():
50
- if cur_fqn == "" :
51
- new_fqn = name
52
- else :
53
- new_fqn = f"{ cur_fqn } .{ name } "
54
- if filter_fn (child , new_fqn ):
55
- new_child = replacement_fn (child )
56
- setattr (model , name , new_child )
57
- else :
58
- _replace_with_custom_fn_if_matches_filter (
59
- child , replacement_fn , filter_fn , new_fqn
49
+ if filter_fn (model , cur_fqn [:- 1 ]):
50
+ model = replacement_fn (model )
51
+ return model
52
+ else :
53
+ for name , child in model .named_children ():
54
+ new_child = _replace_with_custom_fn_if_matches_filter (
55
+ child , replacement_fn , filter_fn , f"{ cur_fqn } { name } ."
60
56
)
57
+ setattr (model , name , new_child )
58
+ return model
59
+
61
60
62
61
def _is_linear (mod , * args ):
63
62
return (
@@ -81,7 +80,7 @@ def apply_weight_only_int8_quant(model):
81
80
)
82
81
83
82
84
- def apply_dynamic_quant (model ):
83
+ def apply_dynamic_quant (model , filter_fn = None ):
85
84
"""
86
85
Applies dynamic symmetric per-token activation and per-channel weight
87
86
quantization to all linear layers in the given model using
@@ -90,7 +89,7 @@ def apply_dynamic_quant(model):
90
89
_replace_with_custom_fn_if_matches_filter (
91
90
model ,
92
91
lambda mod : DynamicallyPerAxisQuantizedLinear .from_float (mod ),
93
- _is_linear ,
92
+ _is_linear if filter_fn is None else filter_fn ,
94
93
)
95
94
96
95
@@ -104,18 +103,23 @@ def insert_subclass(lin):
104
103
return insert_subclass
105
104
106
105
107
- def change_linear_weights_to_int8_dqtensors (model ):
106
+ def change_linear_weights_to_int8_dqtensors (model , filter_fn = None ):
108
107
"""
109
108
Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
110
109
Tensor subclass, effectively applying the same form of quantization
111
110
as apply_dynamic_quant while not modifying the linear modules.
112
111
"""
112
+ if filter_fn is None :
113
+ filter_fn = (
114
+ lambda * args :
115
+ _is_linear (* args ) and
116
+ _in_features_greater_than_16 (* args )
117
+ )
118
+
113
119
_replace_with_custom_fn_if_matches_filter (
114
120
model ,
115
121
_get_subclass_inserter (Int8DynamicallyQuantizedLinearWeight ),
116
- lambda * args :
117
- _is_linear (* args ) and
118
- _in_features_greater_than_16 (* args )
122
+ filter_fn
119
123
)
120
124
121
125
@@ -140,8 +144,36 @@ def change_linear_weights_to_int4_woqtensors(model, **kwargs):
140
144
effectively applying the same form of quantization
141
145
as apply_dynamic_quant while not modifying the linear modules.
142
146
"""
147
+ filter_fn = kwargs .pop ("filter_fn" , _is_linear )
148
+
143
149
_replace_with_custom_fn_if_matches_filter (
144
150
model ,
145
151
_get_subclass_inserter (Int4WeightOnlyQuantizedLinearWeight , ** kwargs ),
146
- _is_linear ,
152
+ filter_fn ,
153
+ )
154
+
155
def swap_conv2d_1x1_to_linear(model):
    """
    Replace every 1x1 ``Conv2d`` module in ``model`` (in place) with an
    equivalent ``Linear`` module, wrapped so that NCHW activations are
    permuted to channels-last before the linear and back afterwards.
    This lets the resulting linear layers be picked up by the linear
    quantization passes in this file.

    Args:
        model: module tree to modify in place.

    NOTE(review): only ``kernel_size == (1, 1)`` is checked.  A 1x1 conv
    with stride != 1, padding != 0, dilation != 1, or groups != 1 is NOT
    equivalent to a linear layer — confirm callers never pass such convs,
    or tighten the filter below.
    """

    class PermuteSandwich(torch.nn.Module):
        """Permute NCHW -> NHWC, apply ``mod``, permute back to NCHW."""

        def __init__(self, mod):
            super().__init__()
            self.mod = mod

        def forward(self, *args):
            # (N, C, H, W) -> (N, H, W, C) so the linear acts on the channel
            # dim, then apply the inverse permutation.  (Original wrote
            # ``-0`` for the leading axis, which is just 0 — fixed.)
            return self.mod(args[0].permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def replace_conv2d_1x1(conv):
        # Build a linear with the same channel mapping as the 1x1 conv.
        assert conv.kernel_size == (1, 1)
        # Bug fix: the original passed ``bias=(conv.bias is None)``, creating
        # the Linear with a bias exactly when the conv had none.  The later
        # ``lin.bias = conv.bias`` assignment masked the inversion (Module
        # attribute assignment re-registers the parameter), but at the cost
        # of a pointless allocation and a confusing contract.
        lin = torch.nn.Linear(
            conv.in_channels, conv.out_channels, bias=(conv.bias is not None)
        )
        # Conv weight has shape (out, in, 1, 1); drop the spatial dims so it
        # matches Linear's (out, in) layout.
        lin.weight = torch.nn.Parameter(conv.weight.squeeze(-1, -2))
        lin.bias = conv.bias
        return PermuteSandwich(lin)

    _replace_with_custom_fn_if_matches_filter(
        model,
        replace_conv2d_1x1,
        filter_fn=lambda mod, *args: isinstance(mod, torch.nn.Conv2d)
        and mod.kernel_size == (1, 1),
    )
0 commit comments