5
5
# LICENSE file in the root directory of this source tree.
6
6
import types
7
7
from dataclasses import dataclass
8
- from typing import Optional
8
+ from typing import List , Optional
9
9
10
10
import torch
11
11
30
30
ZeroPointDomain ,
31
31
)
32
32
from torchao .quantization .transform_module import (
33
+ _QUANTIZE_CONFIG_HANDLER ,
33
34
register_quantize_module_handler ,
34
35
)
36
+ from torchao .utils import DummyModule
35
37
36
38
from .core import (
37
39
AWQObservedLinear ,
38
40
AWQObserver ,
41
+ AWQObserver2 ,
39
42
)
40
43
41
44
assert len (_DTYPE_TO_BIT_WIDTH ) > 0 , (
@@ -50,6 +53,7 @@ def insert_awq_observer_(
50
53
quant_dtype : torch .dtype = torch .uint4 ,
51
54
scale_search_space_size : int = 20 ,
52
55
group_size : int = 128 ,
56
+ base_config : Optional [AOBaseConfig ] = None ,
53
57
):
54
58
"""
55
59
Inserts AWQObserver into Linear layers of a given model.
@@ -80,22 +84,30 @@ def insert_awq_observer_(
80
84
81
85
def replace_with_observer (layer ):
82
86
# creates observer and replaces linear layers with AWQObservedLinear layers
83
- observer = AWQObserver (
84
- layer .weight ,
85
- layer .bias ,
86
- quantization_granularity ,
87
- mapping_type ,
88
- quant_dtype ,
89
- n_validation_examples ,
90
- validation_sequence_len ,
91
- scale_search_space_size ,
92
- preserve_zero = preserve_zero ,
93
- zero_point_domain = zero_point_domain ,
94
- zero_point_dtype = zero_point_dtype ,
95
- quant_min = quant_min ,
96
- quant_max = quant_max ,
97
- eps = eps ,
98
- )
87
+ if base_config is None :
88
+ observer = AWQObserver (
89
+ layer .weight ,
90
+ layer .bias ,
91
+ quantization_granularity ,
92
+ mapping_type ,
93
+ quant_dtype ,
94
+ n_validation_examples ,
95
+ validation_sequence_len ,
96
+ scale_search_space_size ,
97
+ preserve_zero = preserve_zero ,
98
+ zero_point_domain = zero_point_domain ,
99
+ zero_point_dtype = zero_point_dtype ,
100
+ quant_min = quant_min ,
101
+ quant_max = quant_max ,
102
+ eps = eps ,
103
+ )
104
+ else :
105
+ observer = AWQObserver2 (
106
+ layer .weight ,
107
+ layer .bias ,
108
+ base_config ,
109
+ scale_search_space_size ,
110
+ )
99
111
return AWQObservedLinear .from_float (layer , observer )
100
112
101
113
_replace_with_custom_fn_if_matches_filter (model , replace_with_observer , _is_linear )
@@ -194,3 +206,97 @@ def _awq_uintx_transform(
194
206
linear .extra_repr = types .MethodType (_linear_extra_repr , module )
195
207
linear .bias = observed_linear .bias
196
208
return linear
209
+
210
+
211
@dataclass
class AWQConfig(AOBaseConfig):
    """
    Configuration for quantizing linear layers when passed into quantize_()

    Args:
        base_config (AOBaseConfig): The quantization config that we can apply awq on top of, e.g. 8da4w, int4 weight only
        step (str): a string of "prepare", "convert" or "load" indicating the step of AWQ process
            prepare: insert AWQ Observers to linear
            convert: convert the observed linear modules to linear modules with awq quantized weights
            load: convert the floating point model to a dummy awq quantized model
        example_input_shape (Optional[List[int]]): used in the load step to initialize a random
            example input (required when step is "load")
        scale_search_space_size (int): the number of scales to search for
        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
    """

    base_config: AOBaseConfig
    step: str
    example_input_shape: Optional[List[int]] = None
    scale_search_space_size: int = 20
    set_inductor_config: bool = True

    def __post_init__(self):
        # Validate eagerly so a typo in `step` fails at config construction,
        # not deep inside the quantize_() transform.
        OPTIONS = ["prepare", "convert", "load"]
        assert self.step in OPTIONS, f"Only {OPTIONS} are supported, got {self.step}"
236
+
237
+
238
@register_quantize_module_handler(AWQConfig)
def _awq_transform(
    module: torch.nn.Module,
    config: AWQConfig,  # fixed: was annotated AWQUIntXConfig, but this handler is registered for AWQConfig
) -> torch.nn.Module:
    """Transform handler for AWQConfig, dispatched by quantize_().

    Behavior depends on ``config.step``:
      - "prepare": wrap ``module`` in an AWQObservedLinear with an AWQObserver2
        so activation statistics can be collected during calibration.
      - "load": same wrapping, then run one random example input through the
        observed module so qparams exist — used when loading a pre-quantized
        checkpoint (requires ``config.example_input_shape``).
      - "convert" (the else branch): take an already-observed linear and
        produce a plain nn.Linear whose weight is AWQ-equalized and quantized
        via the handler for ``config.base_config``.

    Returns:
        The replacement module for ``module`` (observed linear, quantized
        linear, or ``module`` unchanged when convert sees a non-observed layer).
    """
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    step = config.step
    scale_search_space_size = config.scale_search_space_size
    observed_linear = None
    base_config = config.base_config

    if step == "prepare":
        observer = AWQObserver2(
            module.weight,
            module.bias,
            base_config,
            scale_search_space_size,
        )
        return AWQObservedLinear.from_float(module, observer)
    elif step == "load":
        # loading from pre-quantized checkpoint
        observer = AWQObserver2(
            module.weight,
            module.bias,
            base_config,
            scale_search_space_size,
        )
        observed_linear = AWQObservedLinear.from_float(module, observer)
        assert config.example_input_shape is not None, (
            "When step is load, we expect example_input_shape to be specified as well"
        )
        # One forward pass with a random input populates the observer's
        # statistics so calculate_qparams() below has something to work with.
        example_input = torch.randn(
            config.example_input_shape,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        observed_linear(example_input)
    else:
        # convert step: only observed linears can be converted; anything else
        # is deliberately skipped (best-effort), not an error.
        if not isinstance(module, AWQObservedLinear):
            print(f"convert: module is not AWQObservedLinear, skipping: {type(module)}")
            return module
        observed_linear = module

    assert observed_linear is not None
    equalization_scale = observed_linear.act_obs.calculate_qparams()

    # Quantize the equalized weight by routing a throwaway module holding
    # (weight * scale) through the base config's registered handler.
    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(config.base_config)]
    dummy_mod = DummyModule(observed_linear.weight * equalization_scale)
    quant_mod = base_config_handler(dummy_mod, config.base_config)
    qw = quant_mod.weight
    # Attach the equalization scale so activations are rescaled at runtime.
    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)

    linear = torch.nn.Linear(
        observed_linear.in_features,
        observed_linear.out_features,
        observed_linear.bias is not None,  # fixed: was `!= None` (identity check is the correct idiom)
        device=observed_linear.weight.device,
        dtype=observed_linear.weight.dtype,
    )
    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
    linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
    linear.bias = observed_linear.bias
    return linear
0 commit comments