
Commit c8f7ac7

refactor gptq with prepare and convert API
Signed-off-by: xin3he <xin3.he@intel.com>
1 parent 5f3f388 commit c8f7ac7

File tree

3 files changed: +135 -94 lines changed

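For orientation before the per-file diffs: this refactor replaces the single-shot gptq_quantize() call with a prepare/convert pair, so calibration is driven by the caller between the two steps. A minimal usage sketch, assuming the user-facing prepare/convert helpers and GPTQConfig exported by neural_compressor.torch.quantization; the model choice and calibration text below are placeholders, not part of this commit:

# Hypothetical end-to-end sketch; the model name, calibration text, and the
# prepare/convert helpers imported below are assumptions, not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

quant_config = GPTQConfig()              # default per-op GPTQ settings
prepared = prepare(model, quant_config)  # patches the first transformer block to capture calibration inputs

# Calibration is now driven by the caller rather than by an internal run_fn/dataloader loop.
calib = tokenizer("Calibration text for GPTQ.", return_tensors="pt")
with torch.no_grad():
    try:
        prepared(**calib)
    except ValueError:
        pass

q_model = convert(prepared)              # runs block-wise GPTQ and returns the quantized model

The try/except ValueError guard mirrors the guard in the internal calibration loop that this commit removes: the patched first block interrupts the forward pass once its inputs have been cached.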

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 61 additions & 70 deletions
@@ -195,15 +195,14 @@ def __init__(
         self,
         model,
         weight_config={},
-        dataloader=None,
         nsamples=128,
         use_max_length=True,
         max_seq_length=2048,
         device=None,
         export_compressed_model=False,
         use_layer_wise=False,
         model_path="",
-        run_fn=None,
+        dataloader=None,
         *args,
         **kwargs,
     ):
@@ -226,7 +225,6 @@ def __init__(
             export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             model_path (str): Model path that is used to load state_dict per layer.
-            run_fn: a function to run model inference for collecting input information.
             device: cpu or cuda
         """
         # model
@@ -271,9 +269,7 @@ def __init__(
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
-        self.run_fn = run_fn
-        self.run_args = kwargs.get("run_args", None)
-        if run_fn is None:
+        if dataloader is not None:
             self.prepare_dataloader()

     def prepare_dataloader(self):
@@ -489,7 +485,7 @@ def track_hidden_states(self, data):
         return data[0]

     @torch.no_grad()
-    def pre_quantization(self):
+    def prepare_for_calibration(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
         try:
             self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
+        self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )

-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        logger.info("Collecting calibration inputs by running the run_fn provided by user.")
-        if self.run_fn:
-            if self.run_args:
-                self.run_fn(self.model, *self.run_args)
-                accelerator.mark_step()
-            else:
-                self.run_fn(self.model)
-                accelerator.mark_step()
-        else:
-            for batch in tqdm(self.dataloader):
-                if not self.use_layer_wise:
-                    batch = move_input_to_device(batch, self.device)
-                try:
-                    if isinstance(batch, tuple) or isinstance(batch, list):
-                        self.model(batch[0])
-                    elif isinstance(batch, dict):
-                        self.model(**batch)
-                    else:
-                        self.model(batch)
-                except ValueError:
-                    pass
+    @torch.no_grad()
+    def remove_prepare_for_calibration(self):
         # output inp data shape
         logger.info("All calibration data's shape =>")
         # check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
         logger.info("Done.")

         # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
+        self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
         for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
         # Step1: prepare quantization (calibration datasets)

         logger.info("Begin ====>")
-        self.pre_quantization()
         model_path = self.model_path

         # Step2: run gptq quantization in a transformer block-wise manner.
@@ -1144,41 +1118,58 @@ def ready(self):
         return torch.all(self.scale != 0)


-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    max_seq_length=2048,
-    use_max_length=True,
-    device=None,
-    export_compressed_model=False,
-    use_layer_wise=False,
-    model_path=None,
-    run_fn=None,
-    run_args=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if use_layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
+from neural_compressor.torch.algorithms import Quantizer as INCQuantizer
+
+
+class INCGPTQQuantizer(INCQuantizer):
+    def __init__(self, quant_config={}):
+        """Init an INCGPTQQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(
+        self,
         model,
-        weight_config,
-        dataloader,
-        nsamples,
-        use_max_length,
-        max_seq_length,
-        device,
-        export_compressed_model=export_compressed_model,
-        use_layer_wise=use_layer_wise,
-        model_path=model_path,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
+        nsamples=128,
+        max_seq_length=2048,
+        use_max_length=True,
+        device=None,
+        export_compressed_model=False,
+        use_layer_wise=False,
+        model_path=None,
+        *args,
+        **kwargs,
+    ):
+        """Run weight-only quantization with GPTQ."""
+        # TODO: unify weight_config keys, add docstring, and support default config
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        if use_layer_wise:
+            assert model_path is not None, "model_path should not be None when use layer wise mode"
+        from .gptq import GPTQuantizer
+
+        self.gptq_quantizer = GPTQuantizer(
+            model,
+            weight_config=self.quant_config,
+            nsamples=nsamples,
+            use_max_length=use_max_length,
+            max_seq_length=max_seq_length,
+            device=device,
+            export_compressed_model=export_compressed_model,
+            use_layer_wise=use_layer_wise,
+            model_path=model_path,
+        )
+        self.gptq_quantizer.prepare_for_calibration()
+        return self.gptq_quantizer.model
+
+    @torch.no_grad()
+    def convert(self, model, *args, **kwargs):
+        self.gptq_quantizer.model = model
+        self.gptq_quantizer.remove_prepare_for_calibration()
+        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model.gptq_config = gptq_config
+        logger.info("GPTQ quantizing done.")
+        return q_model
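Net effect of this file: pre_quantization() is split into prepare_for_calibration() and remove_prepare_for_calibration(), and the old gptq_quantize() function becomes the INCGPTQQuantizer class with separate prepare and convert methods. A rough sketch of driving the new class directly, using a Hugging Face model for illustration; the weight_config keys shown are assumptions rather than values taken from this diff:

# Illustrative sketch of the two-phase quantizer introduced above; the model choice,
# calibration input, and weight_config keys are assumptions, not taken from this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.algorithms.weight_only.gptq import INCGPTQQuantizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Per-layer GPTQ options; in practice gptq_entry assembles this mapping from GPTQConfig.
weight_config = {"model.decoder.layers.0.self_attn.k_proj": {"bits": 4, "group_size": 128, "sym": False}}

quantizer = INCGPTQQuantizer(quant_config=weight_config)

# Phase 1: prepare() builds the internal GPTQuantizer and calls prepare_for_calibration(),
# which swaps in a forward that caches the first transformer block's inputs.
prepared = quantizer.prepare(model, nsamples=8, max_seq_length=128)

# The caller feeds calibration data; the cached forward interrupts the pass with ValueError,
# mirroring the guard the removed internal loop used.
calib = tokenizer("Calibration text for GPTQ.", return_tensors="pt")
with torch.no_grad():
    try:
        prepared(**calib)
    except ValueError:
        pass

# Phase 2: convert() restores the original forward via remove_prepare_for_calibration(),
# then execute_quantization() performs block-wise GPTQ.
q_model = quantizer.convert(prepared)
print(q_model.gptq_config)  # quantization details attached by convert()

In the framework flow, callers are expected to go through the prepare/convert API shown earlier rather than instantiating this class themselves; gptq_entry below does the wiring.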

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 31 additions & 12 deletions
@@ -33,11 +33,19 @@
 from neural_compressor.torch.utils import Mode, logger, register_algo


+class CurrentQuantizer:
+    quantizer = None
+
+
 ###################### RTN Algo Entry ##################################
 @register_algo(RTN)
 @torch.no_grad()
 def rtn_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], RTNConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], RTNConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
     from neural_compressor.torch.algorithms.weight_only.rtn import rtn_quantize
@@ -72,10 +80,14 @@ def rtn_entry(
 @register_algo(GPTQ)
 @torch.no_grad()
 def gptq_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], GPTQConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the GPTQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
+    from neural_compressor.torch.algorithms.weight_only.gptq import INCGPTQQuantizer

     # rebuild weight_config for gptq_quantize function
     weight_config = {}
@@ -106,12 +118,11 @@ def gptq_entry(
         }
     )
     kwargs.pop("example_inputs")
-    kwargs.pop("mode")  # TODO: will be removed after GPTQ refactoring
-
     logger.warning("lm_head in transformer model is skipped by GPTQ")
-    model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs)
-    # Assign the gptq config as an attribute of model
-    model._gptq_quantization_perm = quantization_perm
+
+    if CurrentQuantizer.quantizer is None or mode in [Mode.PREPARE, Mode.QUANTIZE]:
+        CurrentQuantizer.quantizer = INCGPTQQuantizer(quant_config=weight_config)
+    model = CurrentQuantizer.quantizer.execute(model, mode=mode, *args, **kwargs)
     return model


@@ -123,7 +134,7 @@ def static_quant_entry(
     configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the static quant algorithm.")
     from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
@@ -223,7 +234,11 @@ def smooth_quant_entry(
 @register_algo(name=AWQ)
 @torch.no_grad()
 def awq_quantize_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], AWQConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], AWQConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the AWQ algorithm.")
     from neural_compressor.torch.algorithms.weight_only.awq import awq_quantize
@@ -333,7 +348,7 @@ def autoround_quantize_entry(
     configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

@@ -406,7 +421,11 @@ def autoround_quantize_entry(
 @register_algo(name=HQQ)
 @torch.no_grad()
 def hqq_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, Callable], HQQConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, Callable], HQQConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.hqq import hqq_quantize
