
Commit 6dfba04

Eval hf models using lm_eval (#2179)
1 parent 396a567 commit 6dfba04

File tree: 6 files changed (+274, -2 lines)


benchmarks/_models/eval_hf_models.py

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import itertools
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from benchmarks.microbenchmarks.utils import string_to_config
from torchao.quantization import *  # noqa: F401, F403
from torchao.quantization.utils import _lm_eval_available


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
    """Quantize the model and save it to the output directory."""
    print("Quantizing model with config: ", quant_config)
    if quant_config is None:
        quantization_config = None
    else:
        quantization_config = TorchAoConfig(quant_type=quant_config)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir, safe_serialization=False)
    return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
    """Run the lm_eval command using subprocess."""
    tasks_str = ",".join(tasks_list)
    command = [
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        f"pretrained={model_dir}",
        "--tasks",
        f"{tasks_str}",
        "--device",
        f"{device}",
        "--batch_size",
        f"{batch_size}",
    ]
    subprocess.run(command, check=True)


def get_model_size_in_bytes(model, ignore_embeddings=False):
    """
    Returns the model size in bytes. The option to ignore embeddings
    is useful for models with disproportionately large embeddings compared
    to other model parameters that get quantized/sparsified.
    """

    def flat_size(tensor):
        if hasattr(tensor, "__tensor_flatten__"):
            size = 0
            # 0th element is a list of attributes that
            # hold tensors
            for attr_name in tensor.__tensor_flatten__()[0]:
                sub_tensor = getattr(tensor, attr_name)
                size += flat_size(sub_tensor)
            return size
        else:
            return tensor.numel() * tensor.element_size()

    model_size = 0
    for _, child in model.named_children():
        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
            for p in itertools.chain(
                child.parameters(recurse=False), child.buffers(recurse=False)
            ):
                model_size += flat_size(p)
            model_size += get_model_size_in_bytes(child, ignore_embeddings)
    return model_size


def run(
    model_id,
    quantization,
    tasks,
    device,
    batch_size,
    model_output_dir,
):
    print(f"Running model {model_id} with quantization {quantization}")
    model_name = model_id.split("/")[-1]
    model_output_dir = f"quantized_model/{model_name}-{quantization}"
    quant_config = string_to_config(quantization, None)
    quantized_model, tokenizer = quantize_model_and_save(
        model_id, quant_config=quant_config, output_dir=model_output_dir
    )
    print("Compiling model ....")
    quantized_model = torch.compile(
        quantized_model,
        mode="reduce-overhead",
        fullgraph=True,
    )
    run_lm_eval(
        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
    )
    model_size = get_model_size_in_bytes(quantized_model, ignore_embeddings=True) / 1e9
    print(f"Model size: {model_size:.2f} GB")


if __name__ == "__main__":
    if not _lm_eval_available:
        print(
            "lm_eval is required to run this script. Please install it using pip install lm-eval."
        )
        exit(0)

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Quantize a model and evaluate its throughput."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.1-8B",
        help="The model ID to use.",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        help="The quantization method to use.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to run the model on."
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for lm_eval."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="What are we having for dinner?",
        help="Prompt for model throughput evaluation.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=10,
        help="Max new tokens to generate for throughput evaluation.",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=5,
        help="Number of runs to average over for throughput evaluation.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="quantized_models",
        help="Output directory for quantized model.",
    )
    args = parser.parse_args()

    # Use parsed arguments
    run(
        model_id=args.model_id,
        quantization=args.quantization,
        tasks=args.tasks,
        device=args.device,
        batch_size=args.batch_size,
        model_output_dir=args.output_dir,
    )

benchmarks/_models/eval_hf_models.sh

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


# For llama3.1-8B
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-tensor --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int4wo-128 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization int8dq --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-4 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization gemlitewo-128-8 --tasks wikitext hellaswag


# For llama3.2-3B
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-row --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8dq-tensor --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization float8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int4wo-128 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8wo --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization int8dq --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-4 --tasks wikitext hellaswag
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.2-3B --quantization gemlitewo-128-8 --tasks wikitext hellaswag

benchmarks/microbenchmarks/README.md

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ Currently, quantization string is in same format as the one being passed in llam
- `int8wo`: 8-bit weight-only quantization
- `int4wo-{group_size}`: 4-bit weight-only quantization with specified group size
- `int4wo-{group_size}-hqq`: 4-bit weight-only quantization with HQQ
- `gemlitewo-{bit_width}-{group_size}`: 4- or 8-bit integer weight-only quantization using the gemlite Triton kernel (e.g. `gemlitewo-4-128` for 4-bit weights with group size 128)

### Model Types
- `linear`: Simple linear layer

benchmarks/microbenchmarks/utils.py

Lines changed: 18 additions & 0 deletions
@@ -18,6 +18,7 @@
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    FPXWeightOnlyConfig,
    GemliteUIntXWeightOnlyConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
@@ -291,6 +292,23 @@ def string_to_config(
        else:
            granularity = PerTensor()
        return Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
    if "gemlitewo" in quantization:
        params = quantization.split("-")
        bit_width = int(params[1]) if len(params) > 1 else 4
        group_size = (
            int(params[2])
            if len(params) > 2 and bit_width == 4
            else None
            if bit_width == 8
            else 64
        )
        assert group_size in [
            32,
            64,
            128,
            256,
        ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
        return GemliteUIntXWeightOnlyConfig(group_size=group_size, bit_width=bit_width)
    return None
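
As an illustration (not part of the commit), the sketch below shows how the new branch maps a quantization string to a config object. It assumes `string_to_config` is importable as in `eval_hf_models.py` above and that its second positional argument (the sparsity string) may be `None`, matching the call `string_to_config(quantization, None)` in that script.

```
# Hypothetical usage sketch of the gemlitewo branch added above; not part of the commit.
from benchmarks.microbenchmarks.utils import string_to_config

# Documented format: gemlitewo-{bit_width}-{group_size}. The second positional
# argument (sparsity) is passed as None, as in eval_hf_models.py.
cfg = string_to_config("gemlitewo-4-128", None)
# Expected: a GemliteUIntXWeightOnlyConfig with bit_width=4 and group_size=128.
print(cfg)

# With only the bit width given, the parser falls back to group_size=64 for 4-bit weights.
print(string_to_config("gemlitewo-4", None))
```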

third_party/cutlass

Submodule cutlass updated 41 files

torchao/_models/README.md

Lines changed: 40 additions & 1 deletion
@@ -1,4 +1,43 @@
# LLAMA

## Eval on Llama 3.1 8B and Llama 3.2 3B

We use lm-eval tasks for evaluating TorchAO quantization APIs on HuggingFace models. The results are in the tables below:

| Model Name   | Quantization Technique | Acc   | Acc Norm | Word Perplexity | Model Size (GB) |
|--------------|------------------------|-------|----------|-----------------|-----------------|
| Llama 3.1 8B | None                   | 60.01 | 78.84    | 7.33            | 15.01           |
| Llama 3.1 8B | int4wo-128             | 58.10 | 77.06    | 8.25            | 4.76            |
| Llama 3.1 8B | int8wo                 | 59.92 | 78.95    | 7.34            | 8.04            |
| Llama 3.1 8B | int8dq                 | 60.01 | 78.82    | 7.45            | 8.03            |
| Llama 3.1 8B | float8wo               | 59.83 | 78.61    | 7.37            | 8.03            |
| Llama 3.1 8B | float8dq (PerRow)      | 59.86 | 78.57    | 7.41            | 8.04            |
| Llama 3.1 8B | float8dq (PerTensor)   | 59.95 | 78.66    | 7.42            | 8.03            |
| Llama 3.1 8B | gemlite (gp=128)       | 58.48 | 77.34    | 8.07            | 4.76            |

| Model Name   | Quantization Technique | Acc   | Acc Norm | Word Perplexity | Model Size (GB) |
|--------------|------------------------|-------|----------|-----------------|-----------------|
| Llama 3.2 3B | None                   | 55.27 | 73.70    | 9.26            | 6.43            |
| Llama 3.2 3B | int4wo-128             | 53.13 | 71.31    | 10.36           | 2.29            |
| Llama 3.2 3B | int8wo                 | 55.15 | 73.44    | 9.28            | 3.61            |
| Llama 3.2 3B | int8dq                 | 55.00 | 73.29    | 9.43            | 3.61            |
| Llama 3.2 3B | float8wo               | 55.18 | 73.58    | 9.31            | 3.61            |
| Llama 3.2 3B | float8dq (PerRow)      | 55.18 | 73.37    | 9.33            | 3.61            |
| Llama 3.2 3B | float8dq (PerTensor)   | 55.16 | 73.53    | 9.35            | 3.61            |
| Llama 3.2 3B | gemlite (gp=128)       | 53.71 | 71.99    | 10.05           | 2.29            |

To generate the above results run:
```
sh benchmarks/_models/eval_hf_models.sh
```

To run lm-eval for a different Hugging Face model with a TorchAO quantization technique, run:
```
python benchmarks/_models/eval_hf_models.py --model_id meta-llama/Llama-3.1-8B --quantization float8dq-row --tasks wikitext hellaswag
```
Replace the model id, quantization, and tasks with your desired values. Please refer to the [HuggingFace <-> TorchAO](https://huggingface.co/docs/transformers/main/en//quantization/torchao) integration docs for more details about the supported quantization techniques.
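
For reference, here is a minimal sketch of the quantize-and-save flow that `eval_hf_models.py` wraps; the model id, config choice, and output path are illustrative, and it assumes `PerRow` is importable from `torchao.quantization`:

```
# Minimal sketch of what eval_hf_models.py does before invoking lm_eval; values are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow

model_id = "meta-llama/Llama-3.1-8B"
# Roughly what --quantization float8dq-row resolves to via string_to_config.
quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=TorchAoConfig(quant_type=quant_config),
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save the quantized checkpoint so lm_eval can load it with --model_args pretrained=<dir>.
output_dir = "quantized_model/Llama-3.1-8B-float8dq-row"
model.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir)
```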
# SAM2
sam2 is a fork of https://github.com/facebookresearch/sam2 at commit c2ec8e14a185632b0a5d8b161928ceb50197eddc

It includes
