
Commit 7446433

Autoquant v2 initial version (#1240)
* Autoquant v2 initial version

Summary:

We refactored v1 to benchmark subgraphs of (prev_op -> linear -> post_op), which gives a more accurate timing estimate than benchmarking each linear in isolation. One issue is that we now need to account for the batch size of the subgraph, so the batch size dimension would need to use a symbolic shape; that does not seem to have good support in torch.compile right now.

More improvements:
* the current batch size adjustment code is hardcoded to the llama model; we need to find a way to generalize it
* use the canonicalized subgraph as the cache key to reduce the number of times we need to benchmark
* add accuracy sanity checks

Test Plan:

Tested with torchao/_models/llama/generate.py:

```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant_v2-int4
```

Reviewers:

Subscribers:

Tasks:

Tags:

* tested on llama2 and sam
* ruff
* ruff
* import
* cleanup
* more ruff
* ruff
* ruff format
* rename autoquant v2
* cleanup
* ruff
* move to prototype folder
* remove prototype import
* calibration_seq_length
1 parent ca52cdc commit 7446433
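
The subgraph-timing rationale in the summary lends itself to a quick illustration. Below is a hedged, illustrative sketch (not the actual autoquant_v2 implementation) of why timing a (prev_op -> linear -> post_op) subgraph can differ from timing the linear alone; the shapes, the silu/add surroundings, and the `bench` helper are all assumptions made for demonstration:

```python
# Illustrative only: a linear benchmarked in isolation vs. the same linear
# benchmarked inside a (prev_op -> linear -> post_op) subgraph, where
# torch.compile can fuse the surrounding ops and change the measured cost.
import torch

def bench(fn, x, iters=100):
    # crude CUDA event timing helper for this sketch
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warmup (also triggers compilation)
        fn(x)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

linear = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)

t_linear = bench(linear, x)
# the same linear wrapped between a silu and a residual add, compiled as one unit
subgraph = torch.compile(lambda t: torch.nn.functional.silu(linear(t)) + t)
t_subgraph = bench(subgraph, x)
print(f"linear alone: {t_linear:.3f} ms, prev_op->linear->post_op: {t_subgraph:.3f} ms")
```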

5 files changed: +2086 −18 lines changed

torchao/_models/llama/generate.py

Lines changed: 86 additions & 16 deletions

```diff
@@ -205,28 +205,31 @@ def main(
 
 
     if quantization:
-        from torchao.quantization.quant_api import (
+        from torchao.quantization import (
             quantize_,
+            autoquant,
             int8_weight_only,
             int8_dynamic_activation_int8_weight,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             fpx_weight_only,
             uintx_weight_only,
-            autoquant,
             float8_weight_only,
             float8_dynamic_activation_float8_weight,
         )
+        from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
+        from torchao.utils import unwrap_tensor_subclass
+
         from torchao.quantization.granularity import PerTensor, PerRow
         from torchao.utils import unwrap_tensor_subclass
         if "spinquant" in quantization:
             from torchao.prototype.spinquant import apply_spinquant
             apply_spinquant(model)
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
-        if "int8dq" in quantization:
+        elif "int8dq" in quantization:
             quantize_(model, int8_dynamic_activation_int8_weight())
-        if "int4wo" in quantization:
+        elif "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
             else:
@@ -246,14 +249,14 @@ def main(
                     layout=MarlinQQQLayout(),
                 ),
             )
-            else:
+            else:
                 from torchao.dtypes import MarlinSparseLayout
                 quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
-        if "embed-int8wo" in quantization:
+        elif "embed-int8wo" in quantization:
             quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))
-        if quantization.startswith("awq"):
+        elif quantization.startswith("awq"):
             from torchao._models._eval import TransformerEvalWrapper
             from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
             from torchao.prototype.awq.example import get_calib_dataset
@@ -274,13 +277,13 @@ def main(
                 input_prep_func=prepare_inputs_for_model,
                 device=device,
             ).run_eval(
-                tasks=['wikitext'],
+                tasks=['wikitext'],
                 limit=1,
             )
             is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
             use_hqq = "hqq" in quantization
             quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
-        if "uintx" in quantization:
+        elif "uintx" in quantization:
             # uintx-nbits-group_size, e.g. "uintx-2-64"
             if "hqq" in quantization:
                 # uintx-nbits-group_size-hqq
@@ -294,9 +297,9 @@ def main(
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
-        if "float8wo" in quantization:
+        elif "float8wo" in quantization:
             quantize_(model, float8_weight_only())
-        if "float8dq" in quantization:
+        elif "float8dq" in quantization:
             granularity = str(quantization.split("-")[-1])
             if granularity=="tensor":
                 granularity = PerTensor()
@@ -305,13 +308,79 @@ def main(
             else:
                 granularity = PerTensor()
             quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
-        if "autoquant" in quantization:
+        elif "autoquant_v2" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
+            if "autoquant_v2-int4" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            elif "autoquant_v2-float8" == quantization:
+                model = autoquant_v2(model, manual=True, qtensor_class_list = torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
+            else:
+                model = autoquant_v2(model, manual=True, example_input=inputs)
+
+            print("running generate")
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+            print("running finalize autoquant")
+            # do autoquantization
+            model.finalize_autoquant()
+        elif "autoquant" in quantization:
+            from torchao._models._eval import InputRecorder
+            from torchao._models.llama.model import prepare_inputs_for_model
+
+            calibration_seq_length = 256
+            calibration_limit = 1
+            inputs = InputRecorder(
+                tokenizer,
+                calibration_seq_length,
+                prepare_inputs_for_model,
+                False,  # pad_calibration_inputs
+                model.config.vocab_size,
+                device="cuda"
+            ).record_inputs(
+                ["wikitext"],
+                1,
+            ).get_inputs()[0].values[0]
+            inputs = prepare_inputs_for_model(inputs)
+            with torch.device("cuda"):
+                model.setup_caches(
+                    max_batch_size=1, max_seq_length=calibration_seq_length
+                )
+
             if "autoquant-int4" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST, example_input=inputs)
             elif "autoquant-float8" == quantization:
-                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+                model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST, example_input=inputs)
             else:
-                model = autoquant(model, manual=True)
+                model = autoquant(model, manual=True, example_input=inputs)
 
             generate(
                 model,
@@ -325,6 +394,7 @@ def main(
 
             # do autoquantization
             model.finalize_autoquant()
+
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
@@ -489,7 +559,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str,
+    parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
             +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
```
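
For reference, the manual autoquant_v2 flow that the new branch above wires up reduces to three steps: quantize with deferred kernel selection, run representative inputs, then finalize. A minimal sketch against a toy module; the module and shapes are illustrative assumptions, while the `autoquant_v2` / `finalize_autoquant` calls mirror the diff:

```python
# Minimal sketch of the manual autoquant_v2 flow added above. The toy model and
# shapes are illustrative; the API calls mirror the generate.py diff.
import torch
import torchao  # noqa: F401 (autoquant_v2 lives under torchao.prototype)
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)

# manual=True defers the final choice until finalize_autoquant();
# example_input gives the subgraph benchmarks realistic shapes.
model = autoquant_v2(model, manual=True, example_input=example_input)
model(example_input)        # run real inputs so each linear subgraph is profiled
model.finalize_autoquant()  # commit to the best-performing option per subgraph
```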

torchao/_models/sam/README.md

Lines changed: 3 additions & 1 deletion

````diff
@@ -17,5 +17,7 @@ sh setup.sh
 
 Finally, you can run benchmarks with
 ```
-sh benchmark_sam.sh
+sh benchmark.sh
 ```
+
+You can check out the result in results.csv
````

torchao/_models/sam/eval_combo.py

Lines changed: 31 additions & 1 deletion

```diff
@@ -9,7 +9,14 @@
 import time
 import resource
 
-from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
+import torchao
+from torchao.quantization import (
+    quantize_,
+    int8_dynamic_activation_int8_weight,
+    int4_weight_only,
+    autoquant,
+)
+from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
 from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
 from torchao.utils import unwrap_tensor_subclass
@@ -336,6 +343,29 @@ def mlp_only(mod, name):
                   mlp_lin2_only)
        if not TORCH_VERSION_AT_LEAST_2_5:
            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
+
+    elif compress is not None and "autoquant_v2" in compress:
+        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
+        if "autoquant_v2-int4" == compress:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+        elif "autoquant_v2-float8" == compress:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.prototype.quantization.autoquant_v2.OTHER_AUTOQUANT_CLASS_LIST)
+        else:
+            autoquant_v2(predictor.model.image_encoder, example_input=example_input, manual=True)
+
+        predictor.model.image_encoder(example_input)
+        predictor.model.image_encoder.finalize_autoquant()
+
+    elif compress is not None and "autoquant" in compress:
+        example_input = torch.randn(1, 3, 1024, 1024, dtype=torch.bfloat16, device=device)
+        if "autoquant-int4" == compress:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
+        elif "autoquant-float8" == compress:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True, qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST)
+        else:
+            autoquant(predictor.model.image_encoder, example_input=example_input, manual=True)
+        predictor.model.image_encoder(example_input)
+        predictor.model.image_encoder.finalize_autoquant()
     else:
         assert compress is None, f"Unsupported compress mode {compress}"
 
```
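
The SAM integration above follows the same manual flow but applies it in place, with a plain tensor as the example input and an optional `qtensor_class_list` to narrow the candidate search. A minimal sketch of that pattern on a toy encoder; the encoder module and its shapes are assumptions for illustration, while the call signature and the class-list name come from the diff:

```python
# Sketch of the eval_combo.py pattern: restrict autoquant_v2 to the int4
# candidate list, run one example batch, then finalize. Toy encoder only;
# the real code applies this to predictor.model.image_encoder.
import torch
import torchao
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

encoder = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(3 * 64 * 64, 256),
).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(1, 3, 64, 64, device="cuda", dtype=torch.bfloat16)

# applied in place, mirroring the SAM integration above
autoquant_v2(
    encoder,
    example_input=example_input,
    manual=True,
    qtensor_class_list=torchao.prototype.quantization.autoquant_v2.DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
)
encoder(example_input)        # benchmark the candidates on a real batch
encoder.finalize_autoquant()  # pick the winners
```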
