
Commit 111b3ce

add llama2 examples for smoothquant (#1470)
Signed-off-by: chensuyue <suyue.chen@intel.com>
Co-authored-by: Lu, Yintong <yintong.lu@intel.com>
1 parent 18dd8f8 commit 111b3ce

File tree

4 files changed: +288 -1 lines changed

docs/source/smooth_quant.md

Lines changed: 8 additions & 1 deletion
@@ -324,7 +324,7 @@ IPEX (Intel Extension for PyTorch): 2.0/2.1
 
 Dataset: lambada_openai
 
-Task: text-generation
+Task: text-generation provided by [ITREX](https://github.com/intel/intel-extension-for-transformers/tree/main/examples/huggingface/pytorch/text-generation/quantization)
 
 alpha [0.4, 0.6] is sweet spot region in SmoothQuant paper.
 
@@ -370,6 +370,13 @@ A list of models that achieved a <1% accuracy drop is shown below.
 | databricks/dolly-v2-3b* | 0.6297 | 0.6247 | alpha=0.5, Ipex 2.1 |
 | tiiuae/falcon-7b-instruct | 0.6437 | 0.6392 | alpha=0.7, Pytorch |
 
+The results listed below were obtained with IPEX `optimize_transformers` applied during model initialization for better performance. Please refer to the step-by-step [instruction](../../examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex/README.md) for details.
+| Model/Last token accuracy | FP32 Accuracy | INT8 (w/ SmoothQuant) | Notes |
+|:----------:|:------:|:------:|-----------------------------------|
+| LLaMa-2-7b-hf* | 0.7392 | 0.7332 | alpha=Auto, Ipex 2.1 |
+| LLaMa-2-13b-hf* | 0.7677 | 0.7632 | alpha=Auto, Ipex 2.1 |
+
+
 Please note that for models with an asterisk (*), we set all add ops to FP32 during the quantization step to achieve desirable results.
 ## Example
 
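For reference, "alpha=Auto" in the table above means Intel® Neural Compressor searches a per-layer alpha instead of using one fixed value, and the asterisk means "add" ops are kept in FP32. Below is a minimal sketch of how such a recipe is expressed, condensed from the `run_llama2_sq.py` script added in this commit; model loading, IPEX `optimize_transformers`, the calibration dataloader, and the `quantization.fit` call are omitted here and shown in the full script.

```python
from neural_compressor import PostTrainingQuantConfig

# Keep all "add" ops in FP32 (this is what the asterisk (*) in the tables refers to).
op_type_dict = {"add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}}

# alpha='auto' asks INC to search a per-layer alpha within [alpha_min, alpha_max];
# these bounds are the llama2-7b recipe from the script added below.
recipes = {
    "smooth_quant": True,
    "smooth_quant_args": {
        "alpha": "auto",
        "folding": False,
        "default_alpha": 0.8,
        "auto_alpha_args": {
            "alpha_min": 0.8,
            "alpha_max": 0.99,
            "alpha_step": 0.01,
            "shared_criterion": "mean",
        },
    },
}

conf = PostTrainingQuantConfig(
    backend="ipex",
    op_type_dict=op_type_dict,
    recipes=recipes,
)
```

A fixed alpha in the [0.4, 0.6] sweet-spot region can be used instead by passing, e.g., `"alpha": 0.5` in `smooth_quant_args`.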

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/ipex/README.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
Step-by-Step
============
This document provides step-by-step instructions for running llama2 SmoothQuant with Intel® Neural Compressor and Intel® Extension for PyTorch.

# Prerequisite
```
# Install dependencies
pip install -r requirements.txt
```

# Run Quantization

## Llama-2-7b
```bash
python run_llama2_sq.py \
    --model-id meta-llama/Llama-2-7b-hf \
    --batch-size 56 \
    --sq-recipes "llama2-7b"
```

## Llama-2-13b
```bash
python run_llama2_sq.py \
    --model-id meta-llama/Llama-2-13b-hf \
    --batch-size 56 \
    --sq-recipes "llama2-13b" \
    --padding
```

> Notes:
> - The INT8 model will be saved into "./saved_results", including "./saved_results/best_configure.json" and "./saved_results/best_model.pt", which can be loaded and evaluated by IPEX.
> - The "--sq-recipes" parameter selects the recipes used for quantization; details can be found in the script.
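As a usage note for the saved artifacts above, here is a minimal sketch of reloading the INT8 model for evaluation, assuming `best_model.pt` is the TorchScript module written by `q_model.save("./saved_results")` (which is how the IPEX backend saves the quantized model in the script below):

```python
import torch
import intel_extension_for_pytorch as ipex  # imported so the IPEX kernels used by the traced graph are available

# best_configure.json holds the quantization configuration;
# best_model.pt is the traced INT8 model that can be loaded directly.
model = torch.jit.load("./saved_results/best_model.pt")
model.eval()

# The traced graph expects the same input layout used during calibration:
# (input_ids, attention_mask, position_ids, past_key_values).
```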
requirements.txt

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
neural-compressor==2.4
transformers==4.32.0
datasets
accelerate
sentencepiece
protobuf
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.1.0+cpu
intel-extension-for-pytorch==2.1.0
run_llama2_sq.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
import argparse

from datasets import load_dataset
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoConfig

import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader

import intel_extension_for_pytorch as ipex

parser = argparse.ArgumentParser('LLaMA generation script (int8 path)', add_help=False)

parser.add_argument(
    "-m", "--model-id", default=None, type=str, required=True, help="your llama model"
)
parser.add_argument(
    "--sq-recipes", default=None, type=str, required=True, help="llama2-7b or llama2-13b"
)
parser.add_argument(
    "--max-new-tokens", default=32, type=int, help="output max new tokens"
)
parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k")
parser.add_argument("--output-dir", nargs="?", default="./saved_results")

parser.add_argument(
    "--int8-bf16-mixed",
    action="store_true",
    help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
)
parser.add_argument("--input-tokens", default="32", type=str)
parser.add_argument("--prompt", default=None, type=str)
parser.add_argument("--padding", action="store_true", help="whether do padding in calib_dataloader")
parser.add_argument("--batch-size", default=1, type=int, help="batch size")
parser.add_argument("--alpha", default=0.8, type=float, help="alpha value for smoothquant")
parser.add_argument("--greedy", action="store_true")

args = parser.parse_args()

try:
    ipex._C.disable_jit_linear_repack()
except Exception:
    pass

# amp autocast
if args.int8_bf16_mixed:
    amp_enabled = True
    amp_dtype = torch.bfloat16
else:
    amp_enabled = False
    amp_dtype = torch.float32

num_beams = 1 if args.greedy else 4

# load model
config = AutoConfig.from_pretrained(args.model_id, torchscript=True)
if not hasattr(config, "text_max_length") and args.prompt is None:
    config.text_max_length = int(args.input_tokens) + int(args.max_new_tokens)

user_model = LlamaForCausalLM.from_pretrained(
    args.model_id, config=config, low_cpu_mem_usage=True, torch_dtype=torch.float
)

tokenizer = LlamaTokenizer.from_pretrained(args.model_id)
print("Data type of the model:", user_model.dtype)

# dummy past key value
beam_idx_tmp = torch.zeros(
    (2048, int(args.batch_size * num_beams)), dtype=torch.long
).contiguous()
global_past_key_value = [
    (
        torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
        torch.zeros(
            [
                1,
                user_model.config.num_attention_heads,
                1,
                int(
                    user_model.config.hidden_size
                    / user_model.config.num_attention_heads
                ),
            ]
        ).contiguous(),
        torch.zeros(
            [
                1,
                user_model.config.num_attention_heads,
                1,
                int(
                    user_model.config.hidden_size
                    / user_model.config.num_attention_heads
                ),
            ]
        ).contiguous(),
        beam_idx_tmp,
    )
    for i in range(user_model.config.num_hidden_layers)
]


class Evaluator:

    def __init__(self, dataset, tokenizer, batch_size=1, pad_val=1, pad_max=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.pad_val = pad_val
        self.pad_max = pad_max

        # tokenize the dataset
        self.dataset = self.dataset.map(self.tokenize_function, batched=True)
        self.dataset.set_format(type="torch", columns=["input_ids"])

    @torch.no_grad()
    def tokenize_function(self, examples):
        if "prompt" in examples:
            example = self.tokenizer(examples["prompt"])
        elif "text" in examples:
            example = self.tokenizer(examples["text"])
        elif "code" in examples:
            example = self.tokenizer(examples["code"])
        return example

    @torch.no_grad()
    def collate_batch(self, batch):
        position_ids_padded = []
        input_ids_padded = []
        last_ind = []
        attention_mask_padded = []
        for text in batch:
            input_ids = text["input_ids"]
            if not args.padding:
                input_ids = (
                    input_ids[: int(self.pad_max)]
                    if len(input_ids) > int(self.pad_max)
                    else input_ids
                )  # no_padding
            else:
                pad_len = self.pad_max - input_ids.shape[0]
                input_ids = pad(input_ids, (0, pad_len), value=self.pad_val)
            last_ind.append(input_ids.shape[0] - 1)
            attention_mask = torch.ones(len(input_ids))
            position_ids = torch.arange(len(input_ids))
            input_ids_padded.append(input_ids)
            attention_mask_padded.append(attention_mask)
            position_ids_padded.append(position_ids)
        return (
            (
                torch.vstack(input_ids_padded),
                torch.vstack(attention_mask_padded),
                torch.vstack(position_ids_padded),
                tuple(global_past_key_value),
            ),
            torch.tensor(last_ind),
        )


calib_dataset = load_dataset(args.dataset, split="train")
user_model.eval()
if args.sq_recipes == "llama2-7b":
    pad_max = 2048
elif args.sq_recipes == "llama2-13b":
    pad_max = 1024
else:
    pad_max = 512
calib_evaluator = Evaluator(calib_dataset, tokenizer, args.batch_size, pad_max=pad_max)
calib_dataloader = DataLoader(
    calib_evaluator.dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=calib_evaluator.collate_batch,
)


def calib_func(prepared_model):
    for i, (
        (input_ids, attention_mask, position_ids, past_key_values),
        last_ind,
    ) in enumerate(calib_dataloader):
        if i == 512:
            break
        prepared_model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
        )


example_inputs = None
for i, (
    (input_ids, attention_mask, position_ids, past_key_values),
    last_ind,
) in enumerate(calib_dataloader):
    example_inputs = (input_ids, attention_mask, position_ids, past_key_values)
    break

qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=args.alpha)
user_model = ipex.optimize_transformers(
    user_model.eval(),
    dtype=amp_dtype,
    quantization_config=qconfig,
    inplace=True,
    deployment_mode=False,
)

# steps for SmoothQuant with Intel® Neural Compressor
from neural_compressor import PostTrainingQuantConfig, quantization

# quantization recipes
excluded_precisions = [] if args.int8_bf16_mixed else ["bf16"]
op_type_dict = {"add": {"weight": {"dtype": ["fp32"]}, "activation": {"dtype": ["fp32"]}}}
recipes = {}
if args.sq_recipes == "llama2-7b":
    recipes = {"smooth_quant": True, "smooth_quant_args": {'alpha': 'auto', 'folding': False, 'default_alpha': 0.8,
                                                           'auto_alpha_args': {"alpha_min": 0.8, "alpha_max": 0.99,
                                                                               "alpha_step": 0.01,
                                                                               "shared_criterion": "mean"}}}
elif args.sq_recipes == "llama2-13b":
    recipes = {"smooth_quant": True, "smooth_quant_args": {'alpha': 'auto', 'folding': False, 'default_alpha': 0.8,
                                                           'auto_alpha_args': {"alpha_min": 0.75, "alpha_max": 0.99,
                                                                               "alpha_step": 0.01,
                                                                               "shared_criterion": "max"}}}


conf = PostTrainingQuantConfig(
    backend="ipex",
    excluded_precisions=excluded_precisions,
    op_type_dict=op_type_dict,
    recipes=recipes,
    example_inputs=example_inputs,
)
q_model = quantization.fit(
    user_model,
    conf,
    calib_dataloader=calib_dataloader,
    calib_func=calib_func,
)
q_model.save(args.output_dir)
