
Commit 18dd8f8

format fix for llm example scripts (#1474)
Signed-off-by: chensuyue <suyue.chen@intel.com>
1 parent: eb615ed

File tree: 1 file changed (+52 / -37 lines)


examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 52 additions & 37 deletions
@@ -1,6 +1,7 @@
 import argparse
 import os
 import sys
+
 sys.path.append('./')
 import time
 import json
@@ -33,7 +34,7 @@
     '--seed',
     type=int, default=42, help='Seed for sampling the calibration data.'
 )
-parser.add_argument("--approach", type=str, default='static',
+parser.add_argument("--approach", type=str, default='static',
                     help="Select from ['dynamic', 'static', 'weight-only']")
 parser.add_argument("--int8", action="store_true")
 parser.add_argument("--ipex", action="store_true", help="Use intel extension for pytorch.")
@@ -50,38 +51,41 @@
 parser.add_argument("--calib_iters", default=512, type=int,
                     help="calibration iters.")
 parser.add_argument("--tasks", nargs='+', default=["lambada_openai",
-                    "hellaswag","winogrande","piqa","wikitext"],
-                    type=str, help="tasks list for accuracy validation")
+                    "hellaswag", "winogrande", "piqa", "wikitext"],
+                    type=str, help="tasks list for accuracy validation")
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 # ============SmoothQuant configs==============
 parser.add_argument("--sq", action="store_true")
 parser.add_argument("--alpha", default="auto", help="Smooth quant parameter.")
 # ============WeightOnly configs===============
-parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ', 'GPTQ'],
+parser.add_argument("--woq_algo", default="RTN", choices=['RTN', 'AWQ', 'TEQ', 'GPTQ'],
                     help="Weight-only parameter.")
 parser.add_argument("--woq_bits", type=int, default=8)
 parser.add_argument("--woq_group_size", type=int, default=-1)
 parser.add_argument("--woq_scheme", default="sym")
 parser.add_argument("--woq_enable_mse_search", action="store_true")
 parser.add_argument("--woq_enable_full_range", action="store_true")
 # =============GPTQ configs====================
-parser.add_argument("--gptq_actorder", action="store_true", help="Whether to apply the activation order GPTQ heuristic.")
-parser.add_argument('--gptq_percdamp', type=float, default=.01, help='Percent of the average Hessian diagonal to use for dampening.')
+parser.add_argument("--gptq_actorder", action="store_true",
+                    help="Whether to apply the activation order GPTQ heuristic.")
+parser.add_argument('--gptq_percdamp', type=float, default=.01,
+                    help='Percent of the average Hessian diagonal to use for dampening.')
 parser.add_argument('--gptq_block_size', type=int, default=128, help='Block size. sub weight matrix size to run GPTQ.')
 parser.add_argument('--gptq_nsamples', type=int, default=128, help='Number of calibration data samples.')
-parser.add_argument('--gptq_use_max_length', action="store_true", help='Set all sequence length to be same length of args.gptq_pad_max_length')
+parser.add_argument('--gptq_use_max_length', action="store_true",
+                    help='Set all sequence length to be same length of args.gptq_pad_max_length')
 parser.add_argument('--gptq_pad_max_length', type=int, default=2048, help='Calibration dataset sequence max length, \
                     this should align with your model config, \
                     and your dataset builder args: args.pad_max_length')
 parser.add_argument('--gptq_debug', action='store_true', help='Whether to use debug model ')
-parser.add_argument('--gptq_gpu', action='store_true', help='Whether to use gpu')
 # =======================================

 args = parser.parse_args()
 if args.ipex:
     import intel_extension_for_pytorch as ipex
 calib_size = 1

+
 class Evaluator:
     def __init__(self, dataset, tokenizer, batch_size=8, pad_val=1, pad_max=196, is_calib=False):
         self.dataset = dataset
@@ -149,7 +153,7 @@ def evaluate(self, model):
             pred = last_token_logits.argmax(dim=-1)
             total += label.size(0)
             hit += (pred == label).sum().item()
-            if (i+1) % 50 == 0:
+            if (i + 1) % 50 == 0:
                 print(hit / total)
                 print("Processed minibatch:", i)

@@ -187,6 +191,7 @@ def get_user_model():
     user_model.eval()
     return user_model, tokenizer

+
 if args.quantize:
     # dataset
     user_model, tokenizer = get_user_model()
@@ -201,43 +206,46 @@ def get_user_model():
         collate_fn=calib_evaluator.collate_batch,
     )

+
     def calib_func(prepared_model):
         for i, calib_input in enumerate(calib_dataloader):
             if i > args.calib_iters:
                 break
             prepared_model(calib_input[0])

+
     recipes = {}
     eval_func = None
     from neural_compressor import PostTrainingQuantConfig, quantization
+
     # specify the op_type_dict and op_name_dict
     if args.approach == 'weight_only':
         op_type_dict = {
-            '.*':{ # re.match
+            '.*': {  # re.match
                 "weight": {
-                    'bits': args.woq_bits, # 1-8 bits
+                    'bits': args.woq_bits,  # 1-8 bits
                     'group_size': args.woq_group_size,  # -1 (per-channel)
-                    'scheme': args.woq_scheme, # sym/asym
-                    'algorithm': args.woq_algo, # RTN/AWQ/TEQ
+                    'scheme': args.woq_scheme,  # sym/asym
+                    'algorithm': args.woq_algo,  # RTN/AWQ/TEQ
                 },
             },
         }
-        op_name_dict={
-            'lm_head':{"weight": {'dtype': 'fp32'},},
-            'embed_out':{"weight": {'dtype': 'fp32'},}, # for dolly_v2
+        op_name_dict = {
+            'lm_head': {"weight": {'dtype': 'fp32'}, },
+            'embed_out': {"weight": {'dtype': 'fp32'}, },  # for dolly_v2
         }
         recipes["rtn_args"] = {
             "enable_mse_search": args.woq_enable_mse_search,
             "enable_full_range": args.woq_enable_full_range,
         }
         recipes['gptq_args'] = {
-            'percdamp': args.gptq_percdamp,
-            'act_order':args.gptq_actorder,
-            'block_size': args.gptq_block_size,
-            'nsamples': args.gptq_nsamples,
-            'use_max_length': args.gptq_use_max_length,
-            'pad_max_length': args.gptq_pad_max_length
-        }
+            'percdamp': args.gptq_percdamp,
+            'act_order': args.gptq_actorder,
+            'block_size': args.gptq_block_size,
+            'nsamples': args.gptq_nsamples,
+            'use_max_length': args.gptq_use_max_length,
+            'pad_max_length': args.gptq_pad_max_length
+        }
         # GPTQ: use assistive functions to modify calib_dataloader and calib_func
         # TEQ: set calib_func=None, use default training func as calib_func
         if args.woq_algo in ["GPTQ", "TEQ"]:
@@ -253,30 +261,32 @@ def calib_func(prepared_model):
         # for test on various models, keep the code of directly call gptq_quantize
         if args.gptq_debug:
             from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize
+
             conf = {
-                ".*":{
-                    'wbits': args.woq_bits, # 1-8 bits
+                ".*": {
+                    'wbits': args.woq_bits,  # 1-8 bits
                     'group_size': args.woq_group_size,  # -1 (per-channel)
                     'sym': (args.woq_scheme == "sym"),
                     'act_order': args.gptq_actorder,
                 }
-            }
+            }
             q_model_gptq_debug, gptq_config = gptq_quantize(
-                user_model,
-                weight_config=conf,
-                dataloader=calib_dataloader,
-                nsamples = args.gptq_nsamples,
-                use_max_length = args.gptq_use_max_length,
-                pad_max_length = args.gptq_pad_max_length
+                user_model,
+                weight_config=conf,
+                dataloader=calib_dataloader,
+                nsamples=args.gptq_nsamples,
+                use_max_length=args.gptq_use_max_length,
+                pad_max_length=args.gptq_pad_max_length
             )
             from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+
             results = evaluate(
                 model="hf-causal",
-                model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32',
+                model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
                 user_model=q_model_gptq_debug, tasks=["lambada_openai"],
-                device=DEV.type,
                 batch_size=4
             )
+            exit(0)

         else:
             if re.search("gpt", user_model.config.model_type):
@@ -306,6 +316,8 @@ def calib_func(prepared_model):
         if isinstance(args.alpha, list):
             eval_dataset = load_dataset('lambada', split='validation')
             evaluator = Evaluator(eval_dataset, tokenizer)
+
+
             def eval_func(model):
                 acc = evaluator.evaluate(model)
                 return acc
@@ -323,6 +335,7 @@ def eval_func(model):
 if args.int8 or args.int8_bf16_mixed:
     print("load int8 model")
     from neural_compressor.utils.pytorch import load
+
     if args.ipex:
         user_model = load(os.path.abspath(os.path.expanduser(args.output_dir)))
     else:
@@ -335,9 +348,10 @@ def eval_func(model):
 if args.accuracy:
     user_model.eval()
     from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+
     results = evaluate(
         model="hf-causal",
-        model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32',
+        model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
         user_model=user_model,
         batch_size=args.batch_size,
         tasks=args.tasks,
@@ -358,11 +372,12 @@ def eval_func(model):
     user_model.eval()
     from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
     import time
+
     samples = args.iters * args.batch_size
     start = time.time()
     results = evaluate(
         model="hf-causal",
-        model_args='pretrained='+args.model+',tokenizer='+args.model+',dtype=float32',
+        model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
         user_model=user_model,
         batch_size=args.batch_size,
         tasks=args.tasks,
@@ -376,5 +391,5 @@ def eval_func(model):
         acc = results["results"][task_name]["acc"]
     print("Accuracy: %.5f" % acc)
     print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start)*1000 / samples))
+    print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
     print('Batch size = %d' % args.batch_size)
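
For context, the dictionaries reformatted above (op_type_dict, op_name_dict, recipes) are the inputs to Intel Neural Compressor's post-training weight-only quantization flow; the diff already imports PostTrainingQuantConfig and quantization for this purpose. The fit call itself lies outside the hunks in this commit, so the following is only a minimal sketch under that assumption, reusing names defined in the script above rather than showing the script's exact code.

    # Sketch only: how the reformatted config dicts are typically consumed.
    # The quantization.fit() call is not part of this diff, so treat the
    # exact arguments below as an illustrative assumption.
    from neural_compressor import PostTrainingQuantConfig, quantization

    conf = PostTrainingQuantConfig(
        approach="weight_only",      # matches args.approach == 'weight_only'
        op_type_dict=op_type_dict,   # per-op-type bits / group_size / scheme / algorithm
        op_name_dict=op_name_dict,   # keep lm_head / embed_out weights in fp32
        recipes=recipes,             # rtn_args / gptq_args tuning knobs
    )
    q_model = quantization.fit(
        user_model,
        conf,
        calib_dataloader=calib_dataloader,
        calib_func=calib_func,
    )
    q_model.save(args.output_dir)    # persist the quantized model for later --int8 loading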
