diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py
index 53bee017076..c8c8d9a4ec6 100644
--- a/neural_compressor/torch/algorithms/weight_only/gptq.py
+++ b/neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
     return scale * (q - zero)


-class GPTQuantizer(object):
+class RAWGPTQuantizer(object):
     """Main API for GPTQ algorithm.

     Please refer to:
@@ -195,7 +195,6 @@ def __init__(
         self,
         model,
         weight_config={},
-        dataloader=None,
         nsamples=128,
         use_max_length=True,
         max_seq_length=2048,
@@ -203,7 +202,7 @@ def __init__(
         export_compressed_model=False,
         use_layer_wise=False,
         model_path="",
-        run_fn=None,
+        dataloader=None,
         *args,
         **kwargs,
     ):
@@ -226,7 +225,6 @@ def __init__(
             export_compressed_model (bool, optional): Choose return fp32 or int32 model. Defaults to False.
             use_layer_wise (bool): Enables quantize model per layer. Defaults to False.
             model_path (str): Model path that is used to load state_dict per layer.
-            run_fn: a function to run model inference for collecting input information.
             device: cpu or cuda
         """
         # model
@@ -271,9 +269,7 @@ def __init__(
         self.dataloader_original = dataloader
         self.dataloader = []
         self.nsamples = nsamples
-        self.run_fn = run_fn
-        self.run_args = kwargs.get("run_args", None)
-        if run_fn is None:
+        if dataloader is not None:
             self.prepare_dataloader()

     def prepare_dataloader(self):
@@ -489,7 +485,7 @@ def track_hidden_states(self, data):
         return data[0]

     @torch.no_grad()
-    def pre_quantization(self):
+    def prepare_for_calibration(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
         try:
             self.cache_key_arguments = {
@@ -532,34 +528,13 @@ def forward(layer, *args, **kwargs):
         # Step2: modify the first transformer block's forward function to obtain inputs for calibration
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].to(self.device)
-        forward_cache = self.gptq_related_blocks["transformers"][0].forward
+        self.forward_cache = self.gptq_related_blocks["transformers"][0].forward
         self.gptq_related_blocks["transformers"][0].forward = partial(
             forward, self.gptq_related_blocks["transformers"][0]
         )

-        # Step3: run forward to obtain calibration datasets
-        logger.info("Collecting calibration inputs...")
-        logger.info("Collecting calibration inputs by running the run_fn provided by user.")
-        if self.run_fn:
-            if self.run_args:
-                self.run_fn(self.model, *self.run_args)
-                accelerator.mark_step()
-            else:
-                self.run_fn(self.model)
-                accelerator.mark_step()
-        else:
-            for batch in tqdm(self.dataloader):
-                if not self.use_layer_wise:
-                    batch = move_input_to_device(batch, self.device)
-                try:
-                    if isinstance(batch, tuple) or isinstance(batch, list):
-                        self.model(batch[0])
-                    elif isinstance(batch, dict):
-                        self.model(**batch)
-                    else:
-                        self.model(batch)
-                except ValueError:
-                    pass
+    @torch.no_grad()
+    def remove_prepare_for_calibration(self):
         # output inp data shape
         logger.info("All calibration data's shape =>")
         # check all hidden_states shape
@@ -571,7 +546,7 @@ def forward(layer, *args, **kwargs):
         logger.info("Done.")

         # Step 4: restore original forward function, relocate layers back to cpu.
-        self.gptq_related_blocks["transformers"][0].forward = forward_cache
+        self.gptq_related_blocks["transformers"][0].forward = self.forward_cache
         if not self.use_layer_wise:
             self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu()
         for embedding_name, embedding_layer in self.gptq_related_blocks["embeddings"].items():
@@ -606,7 +581,6 @@ def execute_quantization(self, means=None, stds=None):

         # Step1: prepare quantization (calibration datasets)
         logger.info("Begin ====>")
-        self.pre_quantization()
         model_path = self.model_path

         # Step2: run gptq quantization in a transformer block-wise manner.
@@ -1144,41 +1118,57 @@ def ready(self):
         return torch.all(self.scale != 0)


-def gptq_quantize(
-    model,
-    weight_config={},
-    dataloader=None,
-    nsamples=128,
-    max_seq_length=2048,
-    use_max_length=True,
-    device=None,
-    export_compressed_model=False,
-    use_layer_wise=False,
-    model_path=None,
-    run_fn=None,
-    run_args=None,
-):
-    """Run weight-only quantization with."""
-    # TODO: unify weight_config keys, add docstring, and support default config
-    assert isinstance(model, torch.nn.Module), "only support torch module"
-    if use_layer_wise:
-        assert model_path is not None, "model_path should not be None when use layer wise mode"
-    from .gptq import GPTQuantizer
-
-    gptq_quantizer = GPTQuantizer(
+from neural_compressor.torch.algorithms import Quantizer as INCQuantizer
+
+
+class GPTQuantizer(INCQuantizer):
+    def __init__(self, quant_config={}):
+        """Init a GPTQuantizer object.
+
+        Args:
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+        """
+        super().__init__(quant_config)
+
+    @torch.no_grad()
+    def prepare(
+        self,
         model,
-        weight_config,
-        dataloader,
-        nsamples,
-        use_max_length,
-        max_seq_length,
-        device,
-        export_compressed_model=export_compressed_model,
-        use_layer_wise=use_layer_wise,
-        model_path=model_path,
-        run_fn=run_fn,
-        run_args=run_args,
-    )
-    fp32_modified_model, gptq_config = gptq_quantizer.execute_quantization()
-    logger.info("GPTQ quantizing done.")
-    return fp32_modified_model, gptq_config
+        nsamples=128,
+        max_seq_length=2048,
+        use_max_length=True,
+        device=None,
+        export_compressed_model=False,
+        use_layer_wise=False,
+        model_path=None,
+        *args,
+        **kwargs,
+    ):
+        """Prepare the model for GPTQ calibration."""
+        # TODO: unify weight_config keys, add docstring, and support default config
+        assert isinstance(model, torch.nn.Module), "only support torch module"
+        if use_layer_wise:
+            assert model_path is not None, "model_path should not be None when use layer wise mode"
+
+        self.gptq_quantizer = RAWGPTQuantizer(
+            model,
+            weight_config=self.quant_config,
+            nsamples=nsamples,
+            use_max_length=use_max_length,
+            max_seq_length=max_seq_length,
+            device=device,
+            export_compressed_model=export_compressed_model,
+            use_layer_wise=use_layer_wise,
+            model_path=model_path,
+        )
+        self.gptq_quantizer.prepare_for_calibration()
+        return self.gptq_quantizer.model
+
+    @torch.no_grad()
+    def convert(self, model, *args, **kwargs):
+        self.gptq_quantizer.model = model
+        self.gptq_quantizer.remove_prepare_for_calibration()
+        q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model.gptq_config = gptq_config
+        logger.info("GPTQ quantizing done.")
+        return q_model
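With the calibration loop removed from `pre_quantization()`, the new `GPTQuantizer` wrapper splits GPTQ into an explicit prepare / calibrate / convert sequence. A minimal sketch of how the wrapper is meant to be driven; `model` and `weight_config` are placeholders (the real `weight_config` follows the per-op schema that `gptq_entry` rebuilds in the next file), and the calibration tensor is purely illustrative:

```python
import torch

from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

model = ...          # assumption: a causal LM, e.g. a tiny GPT-J as in the tests below
weight_config = {}   # assumption: per-op config dict, built as in gptq_entry (next file)

quantizer = GPTQuantizer(quant_config=weight_config)
model = quantizer.prepare(model)                       # hooks the first block's forward via prepare_for_calibration()
model(torch.tensor([[10, 20, 30]], dtype=torch.long))  # caller-driven calibration forward pass(es)
q_model = quantizer.convert(model)                     # restores the forward hook and runs execute_quantization()
print(q_model.gptq_config)                             # per-layer GPTQ metadata attached by convert()
```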
diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py
index e9c17474f8a..3e4a75393ca 100644
--- a/neural_compressor/torch/quantization/algorithm_entry.py
+++ b/neural_compressor/torch/quantization/algorithm_entry.py
@@ -72,10 +72,14 @@ def rtn_entry(
 @register_algo(GPTQ)
 @torch.no_grad()
 def gptq_entry(
-    model: torch.nn.Module, configs_mapping: Dict[Tuple[str, callable], GPTQConfig], *args, **kwargs
+    model: torch.nn.Module,
+    configs_mapping: Dict[Tuple[str, callable], GPTQConfig],
+    mode: Mode = Mode.QUANTIZE,
+    *args,
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the GPTQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only.gptq import gptq_quantize
+    from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

     # rebuild weight_config for gptq_quantize function
     weight_config = {}
@@ -106,12 +110,16 @@ def gptq_entry(
             }
         )
     kwargs.pop("example_inputs")
-    kwargs.pop("mode")  # TODO: will be removed after GPTQ refactoring
-
     logger.warning("lm_head in transformer model is skipped by GPTQ")
-    model, quantization_perm = gptq_quantize(model=model, weight_config=weight_config, *args, **kwargs)
-    # Assign the gptq config as an attribute of model
-    model._gptq_quantization_perm = quantization_perm
+    if getattr(model, "quantizer", False):
+        quantizer = model.quantizer
+    else:
+        quantizer = GPTQuantizer(quant_config=weight_config)
+    model = quantizer.execute(model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        del model.quantizer
+    else:
+        model.quantizer = quantizer
     return model


@@ -123,7 +131,7 @@ def static_quant_entry(
     configs_mapping: Dict[Tuple[str, callable], StaticQuantConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the static quant algorithm.")
     from neural_compressor.torch.algorithms.static_quant import StaticQuantQuantizer
@@ -333,7 +341,7 @@ def autoround_quantize_entry(
     configs_mapping: Dict[Tuple[str, callable], AutoRoundConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.autoround import AutoRoundQuantizer

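Threading `mode` through `gptq_entry` lets the same entry serve the one-shot `Mode.QUANTIZE` path and a two-pass flow in which the `GPTQuantizer` instance is parked on `model.quantizer` between calls. A hedged sketch of that hand-off, assuming `Mode.PREPARE` and `Mode.CONVERT` exist alongside `Mode.QUANTIZE` (this is the sequence the top-level `prepare`/`convert` APIs in the tests below drive); `model` and `configs_mapping` (the `{(op_name, op_type): GPTQConfig}` mapping the framework passes in) are placeholders:

```python
# First pass: build the quantizer, hook calibration, park the instance on the model.
model = gptq_entry(model, configs_mapping, mode=Mode.PREPARE)
assert hasattr(model, "quantizer")  # stashed so the convert pass reuses the same GPTQuantizer

# ... caller runs calibration forwards on `model` here ...

# Second pass: reuse the stashed quantizer, finish quantization, drop the handle.
model = gptq_entry(model, configs_mapping, mode=Mode.CONVERT)
assert not hasattr(model, "quantizer")  # handle removed once quantization is done
```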
diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py
index 6325269c2c2..ea1ce5a8878 100644
--- a/test/3x/torch/quantization/weight_only/test_gptq.py
+++ b/test/3x/torch/quantization/weight_only/test_gptq.py
@@ -5,7 +5,19 @@
 import transformers

 from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear
-from neural_compressor.torch.quantization import GPTQConfig, get_default_gptq_config, get_default_rtn_config, quantize
+from neural_compressor.torch.quantization import (
+    GPTQConfig,
+    convert,
+    get_default_gptq_config,
+    get_default_rtn_config,
+    prepare,
+    quantize,
+)
+
+
+def run_fn_for_rtn(model):
+    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+    model(torch.tensor([[40, 50, 60]], dtype=torch.long))


 def run_fn(model):
@@ -32,18 +44,40 @@ def test_accuracy_improvement(self):
         # test_default_rtn_config
         model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_rtn_config()
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn_for_rtn(model)
+        model = convert(model)
         rtn_label = model(self.example_inputs)[0]
         rtn_atol = (rtn_label - self.label).amax()
         # test_default_gptq_config
         model = copy.deepcopy(self.tiny_gptj)
         quant_config = get_default_gptq_config()
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         gptq_label = model(self.example_inputs)[0]
         gptq_atol = (gptq_label - self.label).amax()
         # 0.05 VS 0.08
         assert gptq_atol < rtn_atol, "GPTQ should have lower atol than RTN, please double check."

+    def test_quantize_API(self):
+        # test_default_gptq_config
+        model = copy.deepcopy(self.tiny_gptj)
+        quant_config = get_default_gptq_config()
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
+        gptq_label = model(self.example_inputs)[0]
+        gptq_atol_1 = (gptq_label - self.label).amax()
+        # quantize API
+        model = copy.deepcopy(self.tiny_gptj)
+        quant_config = get_default_gptq_config()
+        model = quantize(model, quant_config, run_fn=run_fn)
+        gptq_label = model(self.example_inputs)[0]
+        gptq_atol_2 = (gptq_label - self.label).amax()
+        # 0.05 VS 0.08
+        assert gptq_atol_1 == gptq_atol_2, "The quantize API and prepare/convert API should produce the same result."
+
     @pytest.mark.parametrize(
         "bits, use_sym, group_size",
         [
@@ -62,7 +96,9 @@ def test_int_params(self, bits, use_sym, group_size):
             use_sym=use_sym,
             group_size=group_size,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         assert (out != self.label).all(), "WOQ output should be different with raw output"
         if (bits, use_sym, group_size) == (8, True, 128):
@@ -78,7 +114,9 @@ def test_mse_search(self):
         quant_config = GPTQConfig(
             use_mse_search=False,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         atol_false = (out - self.label).amax()
         # use_mse_search=True
@@ -86,7 +124,9 @@ def test_mse_search(self):
         quant_config = GPTQConfig(
             use_mse_search=True,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         atol_true = (out - self.label).amax()
         # compare atol, this case is an ideal case.
@@ -110,7 +150,9 @@ def test_export_compressed_model(self, dtype):
             dtype=dtype,
             export_compressed_model=False,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out1 = model(self.example_inputs)[0]
         # export_compressed_model = True
         model = copy.deepcopy(self.tiny_gptj)
@@ -118,7 +160,9 @@ def test_export_compressed_model(self, dtype):
             dtype=dtype,
             export_compressed_model=True,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out2 = model(self.example_inputs)[0]
         assert isinstance(model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "Exporting compressed model failed."

@@ -139,7 +183,9 @@ def test_dtype_params(self, dtype):
         quant_config = GPTQConfig(
             dtype=dtype,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         atol = (out - self.label).amax()
         assert atol < 0.12, "Accuracy gap atol > 0.12 is unexpected. Please double check."
@@ -159,7 +205,9 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_
             double_quant_use_sym=False,
             double_quant_group_size=double_quant_group_size,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         atol_false = (out - self.label).amax()
         model = copy.deepcopy(self.tiny_gptj)
@@ -171,7 +219,9 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_
             double_quant_use_sym=True,
             double_quant_group_size=double_quant_group_size,
         )
-        model = quantize(model, quant_config, run_fn=run_fn)
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
         out = model(self.example_inputs)[0]
         atol_true = (out - self.label).amax()
         # compare atol, this case is not an ideal case.
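Taken together, the test changes show the migration path for callers: the one-shot `quantize(model, config, run_fn=...)` entry still works, while the new `prepare`/`convert` pair makes the calibration step explicit. A usage sketch mirroring the tests above (the tiny checkpoint id is illustrative and any causal LM works; the calibration tensors copy the style of `run_fn_for_rtn`):

```python
import copy

import torch
import transformers

from neural_compressor.torch.quantization import convert, get_default_gptq_config, prepare, quantize


def run_fn(model):
    # Calibration forwards, in the style of run_fn_for_rtn above.
    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
    model(torch.tensor([[40, 50, 60]], dtype=torch.long))


fp32_model = transformers.AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

# New explicit flow: prepare -> caller-driven calibration -> convert.
model = prepare(copy.deepcopy(fp32_model), get_default_gptq_config())
run_fn(model)
q_model_1 = convert(model)

# One-shot flow kept for convenience; run_fn is used for calibration internally.
q_model_2 = quantize(copy.deepcopy(fp32_model), get_default_gptq_config(), run_fn=run_fn)

# test_quantize_API above asserts both paths end up with the same accuracy gap vs. fp32.
out_1 = q_model_1(torch.tensor([[10, 20, 30]], dtype=torch.long))[0]
out_2 = q_model_2(torch.tensor([[10, 20, 30]], dtype=torch.long))[0]
assert torch.allclose(out_1, out_2)
```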