@@ -205,28 +205,31 @@ def main(
 
 
     if quantization:
-        from torchao.quantization.quant_api import (
+        from torchao.quantization import (
             quantize_,
+            autoquant,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             fpx_weight_only,
             uintx_weight_only,
-            autoquant,
             float8_weight_only,
             float8_dynamic_activation_float8_weight,
         )
+        from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
+        from torchao.utils import unwrap_tensor_subclass
+
         from torchao.quantization.granularity import PerTensor, PerRow
         from torchao.utils import unwrap_tensor_subclass
         if "spinquant" in quantization:
             from torchao.prototype.spinquant import apply_spinquant
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        if "int8dq" in quantization:
+        elif "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
-        if "int4wo" in quantization:
+        elif "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq = True
             else:
@@ -246,14 +249,14 @@ def main(
                         layout=MarlinQQQLayout(),
                     ),
                 )
-            else: 
+            else:
                 from torchao.dtypes import MarlinSparseLayout
                 quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
-        if "embed-int8wo" in quantization:
+        elif "embed-int8wo" in quantization:
             quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
-        if quantization.startswith("awq"):
+        elif quantization.startswith("awq"):
             from torchao._models._eval import TransformerEvalWrapper
             from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
             from torchao.prototype.awq.example import get_calib_dataset
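Note on the embed-int8wo branch above: it is the one place this script quantizes something other than nn.Linear, using the filter_fn argument of quantize_ to select which modules receive the config. A minimal sketch of that pattern, with a toy two-layer module standing in for the llama model (only the quantize_ and int8_weight_only calls come from the diff):

    import torch
    from torchao.quantization import quantize_, int8_weight_only

    # Toy stand-in: only the nn.Embedding matches the filter,
    # so the nn.Linear is left unquantized.
    model = torch.nn.Sequential(
        torch.nn.Embedding(1000, 256),
        torch.nn.Linear(256, 256),
    )
    quantize_(
        model,
        int8_weight_only(group_size=64),  # weight-only int8, 64-wide groups
        filter_fn=lambda m, *args: isinstance(m, torch.nn.Embedding),
    )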
@@ -274,13 +277,13 @@ def main(
                 input_prep_func=prepare_inputs_for_model,
                 device=device,
             ).run_eval(
-                tasks=['wikitext'], 
+                tasks=['wikitext'],
                 limit=1,
             )
             is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
             use_hqq = "hqq" in quantization
             quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq), is_observed_linear)
-        if "uintx" in quantization:
+        elif "uintx" in quantization:
             # uintx-nbits-group_size, e.g. "uintx-2-64"
             if "hqq" in quantization:
                 # uintx-nbits-group_size-hqq
@@ -294,9 +297,9 @@ def main(
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
-        if "float8wo" in quantization:
+        elif "float8wo" in quantization:
             quantize_(model, float8_weight_only())
-        if "float8dq" in quantization:
+        elif "float8dq" in quantization:
             granularity = str(quantization.split("-")[-1])
             if granularity == "tensor":
                 granularity = PerTensor()
@@ -305,13 +308,79 @@ def main(
             else:
                 granularity = PerTensor()
             quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
-        if "autoquant" in quantization:
+        elif "autoquant_v2" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
+            if "autoquant_v2-int4" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            elif "autoquant_v2-float8" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            else:
+                model = autoquant_v2(model, manual=True, example_input=inputs)
+
+            print("running generate")
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+            print("running finalize autoquant")
+            # do autoquantization
+            model.finalize_autoquant()
+        elif "autoquant" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
             if "autoquant-int4" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
             elif "autoquant-float8" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
             else:
-                model = autoquant(model, manual=True)
+                model = autoquant(model, manual=True, example_input=inputs)
 
             generate(
                 model,
@@ -325,6 +394,7 @@ def main(
 
             # do autoquantization
             model.finalize_autoquant()
+
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
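Both autoquant branches above follow the same manual three-step flow: wrap the model, run it on representative inputs so each candidate kernel gets benchmarked, then finalize to commit the per-layer choices. A minimal sketch of that flow, assuming a CUDA device, with a toy model and a random tensor standing in for the llama model and the recorded wikitext batch:

    import torch
    from torchao.quantization import autoquant

    # Stand-ins for the real model and the recorded calibration input.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
    example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

    # 1) Wrap: manual=True defers kernel selection until finalize_autoquant().
    model = autoquant(model, manual=True, example_input=example_input)

    # 2) Profile: a forward pass times candidate quantized kernels on the
    #    shapes the model actually sees (the script runs generate() here).
    model(example_input)

    # 3) Commit: keep the fastest measured option for each layer.
    model.finalize_autoquant()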
@@ -489,7 +559,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, 
+    parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
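For reference, the composite flag strings documented in this help text are parsed by splitting on "-", as in the uintx branch above. A small illustration of the convention (the example value is hypothetical; only the split-and-index parsing mirrors the diff):

    # "uintx-<nbits>-<groupsize>[-hqq]", e.g. "uintx-2-64-hqq"
    quantization = "uintx-2-64-hqq"
    _quant_args = quantization.split("-")
    nbits = int(_quant_args[1])        # 2
    group_size = int(_quant_args[2])   # 64
    use_hqq = "hqq" in quantization    # True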