@@ -20,6 +20,14 @@
     annotate_matmul_16a8w,
 )
 
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+
 from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
 from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
@@ -47,6 +55,8 @@
 
 from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
 
 sys.setrecursionlimit(4096)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -78,6 +88,33 @@ def forward(
         return self.model.forward(tokens, self.atten_mask)
 
 
+def add_mse_weight_observer(quant_dtype, quantizer):
+    weight_dtype = (
+        torch.int4
+        if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+        else torch.int8
+    )
+    per_channel_q_config = quantizer.default_quant_config.quant_config
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+        quant_min=(
+            -7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
+        ),
+        quant_max=(7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max),
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=0,
+        observer_or_fake_quant_ctr=PerChannelParamObserver.with_args(
+            **{"steps": 200, "use_mse": True}
+        ),
+    )
+    quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
+        input_activation=per_channel_q_config.input_activation,
+        output_activation=per_channel_q_config.output_activation,
+        weight=weight_qspec,
+        bias=_derived_bias_quant_spec,
+    )
+
+
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
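The new per-channel weight config above hands range selection to PerChannelParamObserver with steps=200 and use_mse=True. As a rough mental model only (a minimal sketch, not the ExecuTorch observer's actual code, whose internals may differ), an MSE-based range search for symmetric per-channel int4 weights can be pictured like this:

# Illustrative sketch only: symmetric per-channel MSE range search in the spirit
# of PerChannelParamObserver.with_args(steps=200, use_mse=True); not the real observer.
import torch

def mse_weight_scales(weight: torch.Tensor, quant_max: int = 7, steps: int = 200) -> torch.Tensor:
    w = weight.flatten(1)                       # [out_channels, elements_per_channel]
    absmax = w.abs().amax(dim=1).clamp_min(1e-12)
    best_scale = absmax / quant_max
    best_err = torch.full_like(absmax, float("inf"))
    for i in range(1, steps + 1):
        # Candidate clipping range: shrink absmax in `steps` increments.
        scale = (absmax * i / steps) / quant_max
        q = torch.clamp(torch.round(w / scale.unsqueeze(1)), -quant_max, quant_max)
        err = ((q * scale.unsqueeze(1) - w) ** 2).mean(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_scale = torch.where(better, scale, best_scale)
    return best_scale                           # one scale per output channel (ch_axis=0)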
@@ -142,13 +179,13 @@ def permute(w, heads):
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.bfloat16)
+    model.to(dtype=torch.float)
     model.to(device=args.device)
 
     tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
     tokens = tokens.to(device=args.device)
     atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    atten_mask = atten_mask.to(dtype=torch.float)
     inputs = (tokens, atten_mask)
 
     if args.embedding_quantize:
@@ -174,7 +211,8 @@ def permute(w, heads):
         )
         quantizer.add_custom_quant_annotations(custom_annotations)
 
-        model.has_quant_io = True
+        if args.range_setting == "mse_weight":
+            add_mse_weight_observer(quant_dtype, quantizer)
 
         with torch.no_grad():
             model = torch.export.export(model, inputs, strict=True).module()
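For orientation, the hunk above sits inside the script's PTQ path, where the exported module goes through the standard PT2E flow using the prepare_pt2e/convert_pt2e imports added earlier. A condensed sketch of that flow is below; the function name and calibration loop are placeholders, not code from this diff.

# Condensed PT2E post-training-quantization flow; illustrative only.
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_for_eval(model, quantizer, inputs, calibration_batches):
    with torch.no_grad():
        exported = torch.export.export(model, inputs, strict=True).module()
        prepared = prepare_pt2e(exported, quantizer)   # insert observers/fake-quants
        for batch in calibration_batches:              # run calibration data through observers
            prepared(*batch)
        return convert_pt2e(prepared)                  # fold observers into quant/dequant ops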
@@ -245,6 +283,23 @@ def main() -> None:
     torch.manual_seed(seed)
     modelname = "llama2"
     parser = build_args_parser()
+    parser.add_argument(
+        "-P",
+        "--ptq",
+        help="If specified, run PTQ quantization. Default is 16-bit activations and 4-bit weights. Supported modes: 8a8w, 16a4w and 16a4w_block.",
+        type=str,
+    )
+    parser.add_argument(
+        "--range_setting",
+        help="Range-setting method for weights (e.g. mse_weight). If not specified, min-max is used for weights and activations.",
+        type=str,
+    )
+    parser.add_argument(
+        "--limit",
+        help="Number of examples per task (only use this for testing). If <1, it is treated as a fraction of the total number of examples.",
+        type=str,
+    )
+
     args = parser.parse_args()
     args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
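With the three new flags, a typical evaluation run would pass something like --ptq 16a4w --range_setting mse_weight --limit 0.1 on the command line; after parsing, values like the ones below drive the branches shown above (the values are illustrative, not taken from this diff). Note that --limit is declared with type=str, so downstream code is expected to accept the string form.

# Hypothetical parsed values for an example invocation; not part of the diff.
args.ptq = "16a4w"                 # PTQ preset: 16-bit activations, 4-bit weights
args.range_setting = "mse_weight"  # triggers add_mse_weight_observer(quant_dtype, quantizer)
args.limit = "0.1"                 # <1 means a fraction of the evaluation examples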
@@ -257,15 +312,9 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
-    # To do fewer samples for faster evaluation
-    args.limit = 0.1
-    # args.samples = {'wikitext': list(range(1))}
-
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
     torch.set_default_device(args.device)
 
-    args.ptq = "8a8w"
-
     eval_llama(modelname, args)
 
 