Skip to content

Gptq refactor #1770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 61 additions & 71 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
return scale * (q - zero)


class GPTQuantizer(object):
class RAWGPTQuantizer(object):
"""Main API for GPTQ algorithm.

Please refer to:
Expand All @@ -195,15 +195,14 @@ def __init__(
self,
model,
weight_config={},
dataloader=None,
nsamples=128,
use_max_length=True,
max_seq_length=2048,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path="",
run_fn=None,
dataloader=None,
*args,
**kwargs,
):
Expand All @@ -226,7 +225,6 @@ def __init__(
export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
model_path (str): Model path that is used to load state_dict per layer.
run_fn: a function to run model inference for collecting input information.
device: cpu or cuda
"""
# model
Expand Down Expand Up @@ -271,9 +269,7 @@ def __init__(
self.dataloader_original = dataloader
self.dataloader = []
self.nsamples = nsamples
self.run_fn = run_fn
self.run_args = kwargs.get("run_args", None)
if run_fn is None:
if dataloader is not None:
self.prepare_dataloader()

def prepare_dataloader(self):
Expand Down Expand Up @@ -489,7 +485,7 @@ def track_hidden_states(self, data):
return data[0]

@torch.no_grad()
def pre_quantization(self):
def prepare_for_calibration(self):
"""Prepare input calibration data and other attributes which are critical for gptq execution."""
try:
self.cache_key_arguments = {
Expand Down Expand Up @@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
# Step2: modify the first transformer block's forward function to obtain inputs for calibration
if not self.use_layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
self.gptq_related_blocks["transformers"][0].forward = partial(
forward, self.gptq_related_blocks["transformers"][0]
)

# Step3: run forward to obtain calibration datasets
logger.info("Collecting calibration inputs...")
logger.info("Collecting calibration inputs by running the run_fn provided by user.")
if self.run_fn:
if self.run_args:
self.run_fn(self.model, *self.run_args)
accelerator.mark_step()
else:
self.run_fn(self.model)
accelerator.mark_step()
else:
for batch in tqdm(self.dataloader):
if not self.use_layer_wise:
batch = move_input_to_device(batch, self.device)
try:
if isinstance(batch, tuple) or isinstance(batch, list):
self.model(batch[0])
elif isinstance(batch, dict):
self.model(**batch)
else:
self.model(batch)
except ValueError:
pass
@torch.no_grad()
def remove_prepare_for_calibration(self):
# output inp data shape
logger.info("All calibration data's shape =>")
# check all hidden_states shape
Expand All @@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
logger.info("Done.")

# Step 4: restore original forward function, relocate layers back to cpu.
self.gptq_related_blocks["transformers"][0].forward = forward_cache
self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
if not self.use_layer_wise:
self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
Expand Down Expand Up @@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):
# Step1: prepare quantization (calibration datasets)

logger.info("Begin ====>")
self.pre_quantization()
model_path = self.model_path

# Step2: run gptq quantization in a transformer block-wise manner.
Expand Down Expand Up @@ -1144,41 +1118,57 @@ def ready(self):
return torch.all(self.scale != 0)


def gptq_quantize(
model,
weight_config={},
dataloader=None,
nsamples=128,
max_seq_length=2048,
use_max_length=True,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path=None,
run_fn=None,
run_args=None,
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
if use_layer_wise:
assert model_path is not None, "model_path should not be None when use layer wise mode"
from .gptq import GPTQuantizer

gptq_quantizer = GPTQuantizer(
from neural_compressor.torch.algorithms import Quantizer as INCQuantizer


class GPTQuantizer(INCQuantizer):
def __init__(self, quant_config={}):
"""Init a RTNQuantizer object.

Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
"""
super().__init__(quant_config)

@torch.no_grad()
def prepare(
self,
model,
weight_config,
dataloader,
nsamples,
use_max_length,
max_seq_length,
device,
export_compressed_model=export_compressed_model,
use_layer_wise=use_layer_wise,
model_path=model_path,
run_fn=run_fn,
run_args=run_args,
)
fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
logger.info("GPTQ quantizing done.")
return fp32_modified_model, gptq_config
nsamples=128,
max_seq_length=2048,
use_max_length=True,
device=None,
export_compressed_model=False,
use_layer_wise=False,
model_path=None,
*args,
**kwargs,
):
"""Run weight-only quantization with."""
# TODO: unify weight_config keys, add docstring, and support default config
assert isinstance(model, torch.nn.Module), "only support torch module"
if use_layer_wise:
assert model_path is not None, "model_path should not be None when use layer wise mode"

self.gptq_quantizer = RAWGPTQuantizer(
model,
weight_config=self.quant_config,
nsamples=nsamples,
use_max_length=use_max_length,
max_seq_length=max_seq_length,
device=device,
export_compressed_model=export_compressed_model,
use_layer_wise=use_layer_wise,
model_path=model_path,
)
self.gptq_quantizer.prepare_for_calibration()
return self.gptq_quantizer.model

@torch.no_grad()
def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()
q_model, gptq_config = self.gptq_quantizer.execute_quantization()
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
return q_model
26 changes: 17 additions & 9 deletions neural_compressor/torch/quantization/algorithm_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,14 @@ def rtn_entry(
@register_algo(GPTQ)
@torch.no_grad()
def gptq_entry(
model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
model: torch.nn.Module,
configs_mapping: Dict[Tuple[str, callable], GPTQConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs,
) -> torch.nn.Module:
logger.info("Quantize model with the GPTQ algorithm.")
from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

# rebuild weight_config for gptq_quantize function
weight_config = {}
Expand Down Expand Up @@ -106,12 +110,16 @@ def gptq_entry(
}
)
kwargs.pop("example_inputs")
kwargs.pop("mode") # TODO: will be removed after GPTQ refactoring

logger.warning("lm_head in transformer model is skipped by GPTQ")
model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs)
# Assign the gptq config as an attribute of model
model._gptq_quantization_perm = quantization_perm
if getattr(model, "quantizer", False):
quantizer = model.quantizer
else:
quantizer = GPTQuantizer(quant_config=weight_config)
model = quantizer.execute(model, mode=mode, *args, **kwargs)
if getattr(model, "quantizer", False):
del model.quantizer
else:
model.quantizer = quantizer
return model


Expand All @@ -123,7 +131,7 @@ def static_quant_entry(
configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs
**kwargs,
) -> torch.nn.Module:
logger.info("Quantize model with the static quant algorithm.")
from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
Expand Down Expand Up @@ -333,7 +341,7 @@ def autoround_quantize_entry(
configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
mode: Mode = Mode.QUANTIZE,
*args,
**kwargs
**kwargs,
) -> torch.nn.Module:
from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

Expand Down
Loading
Loading