
Commit 7634409

add code-generation evaluation for woq gptq (#1475)
Signed-off-by: changwangss <chang1.wang@intel.com>
Signed-off-by: YIYANGCAI <yiyang.cai@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>
1 parent c88d765 commit 7634409

4 files changed (+100 lines, -25 lines)

examples/.config/model_params_pytorch.json

Lines changed: 22 additions & 1 deletion
@@ -471,6 +471,13 @@
         "main_script": "run_clm_no_trainer.py",
         "batch_size": 8
       },
+      "opt_125m_woq_gptq_debug_int4":{
+        "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+        "dataset_location": "",
+        "input_model": "",
+        "main_script": "run_clm_no_trainer.py",
+        "batch_size": 8
+      },
       "opt_125m_woq_teq":{
         "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
         "dataset_location": "",
@@ -513,7 +520,14 @@
         "main_script": "run_clm_no_trainer.py",
         "batch_size": 1
       },
-      "gpt_j_woq_rtn":{
+      "gpt_j_woq_rtn_int4":{
+        "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+        "dataset_location": "",
+        "input_model": "",
+        "main_script": "run_clm_no_trainer.py",
+        "batch_size": 1
+      },
+      "gpt_j_woq_gptq_debug_int4":{
         "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
         "dataset_location": "",
         "input_model": "",
@@ -527,6 +541,13 @@
         "main_script": "run_clm_no_trainer.py",
         "batch_size": 1
       },
+      "falcon_7b_woq_gptq_debug_int4":{
+        "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/llm",
+        "dataset_location": "",
+        "input_model": "",
+        "main_script": "run_clm_no_trainer.py",
+        "batch_size": 1
+      },
       "xlm-roberta-base_MRPC": {
         "model_src_dir": "nlp/huggingface_models/text-classification/quantization/ptq_static/fx",
         "dataset_location": "",

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -2,12 +2,12 @@ accelerate
 protobuf
 sentencepiece != 0.1.92
 datasets >= 1.1.3
+peft
 torch >= 1.10
 transformers
 pytest
 wandb
 einops
 neural-compressor
 intel-extension-for-transformers
-git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e5872f63e49d49ff7ef4c9b3
-git+https://github.com/huggingface/peft.git@6c44096c7b8d55a2ecf24be9bc68393467e1584a
+git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2
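
Two dependency changes: peft moves from a pinned git commit to a plain release install, and the lm-evaluation-harness pin advances to a newer commit. The equivalent manual installs, for reference:

    pip install peft
    pip install "git+https://github.com/EleutherAI/lm-evaluation-harness.git@cc9778fbe4fa1a709be2abed9deb6180fd40e7e2"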

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_clm_no_trainer.py

Lines changed: 63 additions & 21 deletions
@@ -52,7 +52,7 @@
                     help="calibration iters.")
 parser.add_argument("--tasks", nargs='+', default=["lambada_openai",
                     "hellaswag", "winogrande", "piqa", "wikitext"],
-                    type=str, help="tasks list for accuracy validation")
+                    type=str, help="tasks list for accuracy validation, text-generation and code-generation tasks are different.")
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 # ============SmoothQuant configs==============
 parser.add_argument("--sq", action="store_true")
@@ -78,7 +78,40 @@
                     this should align with your model config, \
                     and your dataset builder args: args.pad_max_length')
 parser.add_argument('--gptq_debug', action='store_true', help='Whether to use debug model ')
-# =======================================
+# ==============code generation args===========
+parser.add_argument("--code_generation", action="store_true")
+parser.add_argument("--n_samples", default=200, type=int)
+parser.add_argument(
+    "--limit", default=None, type=int, help="Limit number of samples to eval"
+)
+parser.add_argument("--allow_code_execution", action="store_true")
+parser.add_argument("--prefix", default="")
+parser.add_argument("--generation_only", action="store_true")
+parser.add_argument("--postprocess", action="store_false")
+parser.add_argument("--save_references", action="store_true")
+parser.add_argument("--save_generations", action="store_true")
+parser.add_argument("--instruction_tokens", default=None)
+parser.add_argument("--save_generations_path", default="generations.json")
+parser.add_argument("--load_generations_path", default=None)
+parser.add_argument("--metric_output_path", default="evaluation_results.json")
+parser.add_argument("--max_length_generation", default=512, type=int)
+parser.add_argument("--temperature", default=0.8, type=float)
+parser.add_argument("--top_p", default=0.8, type=float)
+parser.add_argument("--top_k", default=0, type=int)
+parser.add_argument("--do_sample", action="store_true")
+parser.add_argument("--check_references", action="store_true")
+parser.add_argument("--max_memory_per_gpu", type=str, default=None)
+parser.add_argument(
+    "--modeltype",
+    default="causal",
+    help="AutoModel to use, it can be causal or seq2seq",
+)
+parser.add_argument(
+    "--limit_start",
+    type=int,
+    default=0,
+    help="Optional offset to start from when limiting the number of samples",
+)
 
 args = parser.parse_args()
 if args.ipex:
@@ -262,7 +295,7 @@ def calib_func(prepared_model):
     if args.gptq_debug:
         from neural_compressor.adaptor.torch_utils.weight_only import gptq_quantize
 
-        conf = {
+        gptq_conf = {
             ".*": {
                 'wbits': args.woq_bits, # 1-8 bits
                 'group_size': args.woq_group_size, # -1 (per-channel)
@@ -272,20 +305,16 @@ def calib_func(prepared_model):
         }
         q_model_gptq_debug, gptq_config = gptq_quantize(
             user_model,
-            weight_config=conf,
+            weight_config=gptq_conf,
             dataloader=calib_dataloader,
             nsamples=args.gptq_nsamples,
             use_max_length=args.gptq_use_max_length,
-            pad_max_length=args.gptq_pad_max_length
+            pad_max_length=args.gptq_pad_max_length,
         )
-        from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
 
-        results = evaluate(
-            model="hf-causal",
-            model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
-            user_model=q_model_gptq_debug, tasks=["lambada_openai"],
-            batch_size=4
-        )
+        # save the fake quantized model
+        os.makedirs(args.output_dir, exist_ok=True)
+        torch.save(q_model_gptq_debug, os.path.join(args.output_dir, "gptq_best_model.pt"))
         exit(0)
 
     else:
@@ -317,7 +346,6 @@ def calib_func(prepared_model):
 eval_dataset = load_dataset('lambada', split='validation')
 evaluator = Evaluator(eval_dataset, tokenizer)
 
-
 def eval_func(model):
     acc = evaluator.evaluate(model)
     return acc
@@ -347,15 +375,29 @@ def eval_func(model):
 
 if args.accuracy:
     user_model.eval()
-    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+    if args.gptq_debug:
+        user_model = torch.load(os.path.join(args.output_dir, "gptq_best_model.pt"))
+    if args.code_generation:
+        from intel_extension_for_transformers.llm.evaluation.lm_code_eval import evaluate
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(args.model)
+        results = evaluate(
+            model=user_model,
+            tokenizer=tokenizer,
+            tasks=",".join(args.tasks),
+            batch_size=args.batch_size,
+            args=args,
+        )
+    else:
+        from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+        results = evaluate(
+            model="hf-causal",
+            model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
+            user_model=user_model,
+            batch_size=args.batch_size,
+            tasks=args.tasks,
+        )
 
-    results = evaluate(
-        model="hf-causal",
-        model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
-        user_model=user_model,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-    )
     dumped = json.dumps(results, indent=2)
     if args.save_accuracy_path:
         with open(args.save_accuracy_path, "w") as f:
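
The control-flow change: with --gptq_debug, the script now fake-quantizes via gptq_quantize, serializes the result to <output_dir>/gptq_best_model.pt, and exits; a separate --accuracy run reloads that checkpoint and dispatches to lm_code_eval when --code_generation is set (note lm_code_eval receives tasks as a comma-joined string, while lm_eval takes the list directly). A sketch of the two-step flow; flags not visible in this diff (--quantize, --approach) are assumptions about the surrounding script, and "humaneval" is an illustrative task name:

    # Step 1: GPTQ fake quantization; saves gptq_best_model.pt under --output_dir.
    python run_clm_no_trainer.py --model facebook/opt-125m --quantize \
        --approach weight_only --woq_algo GPTQ --woq_bits 4 --woq_scheme asym \
        --woq_group_size 128 --gptq_use_max_length --gptq_debug \
        --output_dir saved_results

    # Step 2: reload the saved model and run code-generation evaluation.
    python run_clm_no_trainer.py --model facebook/opt-125m --accuracy --gptq_debug \
        --code_generation --allow_code_execution --tasks humaneval \
        --batch_size 1 --output_dir saved_results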

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/llm/run_quant.sh

Lines changed: 13 additions & 1 deletion
@@ -50,6 +50,10 @@ function run_tuning {
         model_name_or_path="facebook/opt-125m"
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo GPTQ"
+    elif [ "${topology}" = "opt_125m_woq_gptq_debug_int4" ]; then
+        model_name_or_path="facebook/opt-125m"
+        approach="weight_only"
+        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_scheme asym --woq_group_size 128 --gptq_use_max_length --gptq_debug"
     elif [ "${topology}" = "opt_125m_woq_teq" ]; then
         model_name_or_path="facebook/opt-125m"
         approach="weight_only"
@@ -69,13 +73,21 @@ function run_tuning {
     elif [ "${topology}" = "gpt_j_ipex_sq" ]; then
         model_name_or_path="EleutherAI/gpt-j-6b"
         extra_cmd=$extra_cmd" --ipex --sq --alpha 1.0"
-    elif [ "${topology}" = "gpt_j_woq_rtn" ]; then
+    elif [ "${topology}" = "gpt_j_woq_rtn_int4" ]; then
         model_name_or_path="EleutherAI/gpt-j-6b"
         approach="weight_only"
         extra_cmd=$extra_cmd" --woq_algo RTN --woq_bits 4 --woq_group_size 128 --woq_scheme asym --woq_enable_mse_search"
+    elif [ "${topology}" = "gpt_j_woq_gptq_debug_int4" ]; then
+        model_name_or_path="EleutherAI/gpt-j-6b"
+        approach="weight_only"
+        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --gptq_use_max_length --gptq_debug"
     elif [ "${topology}" = "falcon_7b_sq" ]; then
         model_name_or_path="tiiuae/falcon-7b-instruct"
         extra_cmd=$extra_cmd" --sq --alpha 0.5"
+    elif [ "${topology}" = "falcon_7b_woq_gptq_debug_int4" ]; then
+        model_name_or_path="tiiuae/falcon-7b-instruct"
+        approach="weight_only"
+        extra_cmd=$extra_cmd" --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym --gptq_use_max_length --gptq_debug"
     fi
 
     python -u run_clm_no_trainer.py \
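
Each new branch maps a topology name to a model id plus a GPTQ flag set, and run_quant.sh appends extra_cmd to the python -u run_clm_no_trainer.py invocation above. For example, the falcon_7b_woq_gptq_debug_int4 branch effectively composes a command like the following sketch (the fixed arguments other than those in extra_cmd are assumptions about the rest of the script):

    # Effective command for topology=falcon_7b_woq_gptq_debug_int4 (a sketch;
    # --quantize and --approach spellings are assumed, extra_cmd flags are from the diff).
    python -u run_clm_no_trainer.py \
        --model tiiuae/falcon-7b-instruct \
        --quantize --approach weight_only \
        --woq_algo GPTQ --woq_bits 4 --woq_group_size 128 --woq_scheme asym \
        --gptq_use_max_length --gptq_debug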
