
Commit b4d0768

README and benchmark improvements (#867)
README improvements

Summary:

quantization README:
1) added fp6 to benchmarks
2) rewrote the autoquant section to give a higher-level explanation before diving into the details
3) reordered the affine quantization section to first show the techniques, then dive into the details
4) added an fp6 section
5) moved the kv cache material to a new section
6) added a sparse-marlin section and removed the sparse-marlin benchmark from the top of the README, since we don't have a reasonable flow for users to apply it to their model without a pre-sparsified checkpoint
7) added a uintx section

Benchmark changes:
1) added instructions for adding things to the benchmarks so everything stays consistent (in the llama benchmark README)
2) organized and ran benchmarks for uintx, fp6, and sparse-marlin
3) added evaluations.sh to mirror benchmarks.sh
4) added sparse-marlin to eval.py
5) fixed some generate.py logging bugs
6) improved the generate.py quantization help text
7) fixed some eval.py bugs with uintx
8) added marlin to eval
9) fixed the eval help text

sparsity README:
1) added some details to sparsity

Test Plan: benchmarks.sh, evaluations.sh

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent e283743 commit b4d0768

7 files changed: +188 −106 lines

torchao/_models/llama/README.md

Lines changed: 4 additions & 0 deletions
@@ -27,3 +27,7 @@ To see how these techniques scale generally we've run `generate.py` with subsets
 | 32768 | 23.83 | 21.72 | 20.64 |
 | 65536 | 33.5 | 29.54 | 25.24 |
 | 131072 | 59.27 | 52.62 | 34.18 |
+
+## Adding Benchmarks For New Techniques
+
+If you want to add benchmarks that you think should be kept up to date, please try to keep the format consistent. For performance-focused techniques (e.g. if they require fine-tuning or something else), add an option to run them in generate.py and an execution command in benchmarks.sh in the relevant section. If it's a technique that's still in development, add it to the `OTHER BENCHMARKS` section; if there's a finalized API and you want those numbers in the main quantization README, add them to the `README BENCHMARKS` section. For accuracy-focused techniques, add them to eval.py and evaluations.sh in a similar vein. Ideally, techniques in the main README will have both benchmarks and evaluations set up here so they can be monitored and reproduced easily.
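(Illustration only, not part of the diff: a minimal sketch of what wiring a new performance-focused technique into generate.py could look like. The `mytechnique-<groupsize>` flag is a hypothetical placeholder; `quantize_` and `int4_weight_only` are existing torchao entry points, with `int4_weight_only` standing in for whatever config the new technique would actually expose.)

```python
# Illustration only (not project code): the general shape of a new --quantization
# branch in generate.py. "mytechnique-<groupsize>" is a hypothetical flag;
# int4_weight_only stands in for whatever config the new technique would expose.
import torch
from torchao.quantization import quantize_, int4_weight_only

def apply_new_technique(model: torch.nn.Module, quantization: str) -> None:
    if quantization and "mytechnique" in quantization:
        # follow the existing flag convention: technique name first, then options,
        # e.g. "mytechnique-64" -> group size 64
        group_size = int(quantization.split("-")[1])
        quantize_(model, int4_weight_only(group_size=group_size))  # stand-in config
```

The matching benchmark command would then go in the relevant section of benchmarks.sh, and accuracy-focused techniques would get a parallel branch in eval.py plus an entry in evaluations.sh.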

torchao/_models/llama/benchmark_results.txt

Lines changed: 21 additions & 10 deletions
Large diffs are not rendered by default.

torchao/_models/llama/benchmarks.sh

Lines changed: 31 additions & 21 deletions
@@ -1,44 +1,28 @@
 export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder
 
-
+# README BENCHMARKS
 export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# in readme
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
 
-# auto-round w/ quant_lm_head
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
-# auto-round w/o quant_lm_head
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0
 
 
 export MODEL_REPO=meta-llama/Meta-Llama-3-8B
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
-# in readme
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-int4 --write_result benchmark_results.txt
-# sparse marlin (NOTE: float16)
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
-# auto-round w/ quant_lm_head
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround
-# auto-round w/o quant_lm_head
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0
 
 
+# OTHER BENCHMARKS
 
+# kv cache quantization
 export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 8192 --kv_cache_quantization
@@ -55,3 +39,29 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask
+
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
+# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/ quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head
+
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
+# TODO: this is an accuracy technique with same perf as int4, should be in evaluations instead of generate.py
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround # auto-round w/ quant_lm_head
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround-cuda-0 # auto-round w/o quant_lm_head
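(Editorial note, not part of the diff: a rough sketch of what the `--quantization` flags above map to in torchao's Python API, based on the branches shown in the eval.py and generate.py diffs below. Import locations and the `MarlinSparseLayoutType` name follow the torchao version this commit targets and may differ in other releases.)

```python
# Sketch of how the --quantization flags above translate into torchao calls.
# Based on the eval.py / generate.py branches in this commit; names such as
# MarlinSparseLayoutType follow the torchao version the commit targets.
import torch
from torchao.quantization import (
    quantize_,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    fpx_weight_only,
    uintx_weight_only,
)
from torchao.dtypes import MarlinSparseLayoutType

def apply_benchmark_flag(model: torch.nn.Module, flag: str) -> None:
    if flag == "int8wo":
        quantize_(model, int8_weight_only())
    elif flag == "int8dq":
        quantize_(model, int8_dynamic_activation_int8_weight())
    elif flag == "fp6":
        # fp6 is fpx with 3 exponent and 2 mantissa bits; the benchmarks run it in float16
        quantize_(model, fpx_weight_only(3, 2))
    elif flag == "int4wo-64":
        quantize_(model, int4_weight_only(group_size=64))
    elif flag == "sparse-marlin":
        # run in float16; needs a pre-sparsified (2:4) checkpoint for meaningful results
        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
    elif flag == "uintx-2-8":
        quantize_(model, uintx_weight_only(torch.uint2, 8))
```

The autoquant and autoround flags go through their own entry points in generate.py and are not shown here.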

torchao/_models/llama/eval.py

Lines changed: 17 additions & 7 deletions
@@ -44,7 +44,11 @@ def run_evaluation(
     pad_calibration_inputs: Optional[bool] = False,
 ):
     """Runs the evaluation of a model using LM Eval."""
-
+    print(
+        f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
+        +f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+        +f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
+    )
     torchao.quantization.utils.recommended_inductor_config_setter()
 
     assert checkpoint_path.is_file(), checkpoint_path
@@ -73,27 +77,28 @@ def run_evaluation(
             quantize_(model, fpx_weight_only(3, 2))
         if "int4wo" in quantization and not "gptq" in quantization:
             if "hqq" in quantization:
-                quantization = quantization[:-4]
                 use_hqq = True
             else:
                 use_hqq = False
-            groupsize=int(quantization.split("-")[-1])
+            groupsize=int(quantization.split("-")[1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
         if "uintx" in quantization:
            # uintx-nbits-groupsize
            # "uintx-2-64"
            if "hqq" in quantization:
                use_hqq = True
-               quantization = quantization[:-4]
            else:
                use_hqq = False
            _quant_args = quantization.split("-")
-           nbits = int(_quant_args[0])
+           nbits = int(_quant_args[1])
            _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
            dtype = _NBITS_TO_DTYPE[nbits]
-           group_size = int(_quant_args[1])
+           group_size = int(_quant_args[2])
            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
+        if "marlin" in quantization:
+            from torchao.dtypes import MarlinSparseLayoutType
+            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
         if "int4wo" in quantization and "gptq" in quantization:
            groupsize=int(quantization.split("-")[-2])
            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -140,7 +145,12 @@ def run_evaluation(
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
+    parser.add_argument('-q', '--quantization', type=str,
+        help=(
+            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-gptq, autoquant, autoquant-int4, '+
+            'int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+        )
+    )
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
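(Editorial note, not part of the diff: the index changes above follow from keeping the technique name at position 0 of the `--quantization` string instead of stripping suffixes like `-hqq` before splitting. A small standalone illustration of the convention:)

```python
# Standalone illustration of the flag-parsing convention the fixes above rely on:
# the technique name stays at index 0, so numeric options live at indices 1, 2, ...
# and the optional "-hqq" suffix no longer has to be stripped before splitting.
def parse_uintx_flag(flag: str):
    # e.g. "uintx-2-64" or "uintx-2-64-hqq"
    parts = flag.split("-")
    assert parts[0] == "uintx"
    nbits = int(parts[1])       # previously read from index 0
    group_size = int(parts[2])  # previously read from index 1
    use_hqq = "hqq" in flag
    return nbits, group_size, use_hqq

assert parse_uintx_flag("uintx-2-64") == (2, 64, False)
assert parse_uintx_flag("uintx-4-64-hqq") == (4, 64, True)
```

The int4wo fix follows the same convention: with the `-hqq` suffix no longer stripped from the string, `int4wo-64-hqq` reads its group size from index 1 rather than index -1 (which would now be `hqq`).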

torchao/_models/llama/generate.py

Lines changed: 12 additions & 7 deletions
@@ -219,10 +219,9 @@ def main(
         if "int4wo" in quantization:
             if "hqq" in quantization:
                 use_hqq=True
-                quantization = quantization[:-4]
             else:
                 use_hqq=False
-            groupsize=int(quantization.split("-")[-1])
+            groupsize=int(quantization.split("-")[1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
         if "marlin" in quantization:
@@ -273,22 +272,22 @@ def main(
             )
             model.to(device)
             model.reset_caches()
-        if "fp6" in quantization:
+        # TODO this needs to be expanded to all of fpx so they can
+        if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "uintx" in quantization:
             # uintx-nbits-groupsize, e.g. "uintx-2-64"
             if "hqq" in quantization:
                 # uintx-nbits-groupsize-hqq
-                quantization = quantization[:-4]
                 use_hqq = True
             else:
                 use_hqq = False
             _quant_args = quantization.split("-")
-            nbits = int(_quant_args[0])
+            nbits = int(_quant_args[1])
             assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
             _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
             dtype = _NBITS_TO_DTYPE[nbits]
-            group_size = int(_quant_args[1])
+            group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
@@ -459,7 +458,13 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
+    parser.add_argument('-q', '--quantization', type=str,
+        help=(
+            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+            +'autoquant-int4, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, '
+            +'uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin'
+        )
+    )
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
