Commit fd2adde (1 parent: 9bba089)

update quantizer and model relationship

Signed-off-by: xin3he <xin3.he@intel.com>

File tree: 2 files changed (+13 -9 lines)

neural_compressor/torch/algorithms/weight_only/gptq.py (3 additions, 4 deletions)

@@ -183,7 +183,7 @@ def quantize(x, scale, zero, maxq):
     return scale * (q - zero)


-class GPTQuantizer(object):
+class RAWGPTQuantizer(object):
     """Main API for GPTQ algorithm.

     Please refer to:
@@ -1121,7 +1121,7 @@ def ready(self):
 from neural_compressor.torch.algorithms import Quantizer as INCQuantizer


-class INCGPTQQuantizer(INCQuantizer):
+class GPTQuantizer(INCQuantizer):
     def __init__(self, quant_config={}):
         """Init a RTNQuantizer object.

@@ -1149,9 +1149,8 @@ def prepare(
         assert isinstance(model, torch.nn.Module), "only support torch module"
         if use_layer_wise:
             assert model_path is not None, "model_path should not be None when use layer wise mode"
-        from .gptq import GPTQuantizer

-        self.gptq_quantizer = GPTQuantizer(
+        self.gptq_quantizer = RAWGPTQuantizer(
             model,
             weight_config=self.quant_config,
             nsamples=nsamples,
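This commit swaps the two names: the low-level GPTQ engine becomes RAWGPTQuantizer, and GPTQuantizer now denotes the framework-facing wrapper derived from neural_compressor.torch.algorithms.Quantizer. A minimal sketch of the resulting relationship, with simplified constructor arguments (the real classes take many more):

    # Illustrative sketch only; argument lists are trimmed for brevity.
    import torch

    class RAWGPTQuantizer:
        """Low-level GPTQ engine (formerly named GPTQuantizer)."""
        def __init__(self, model, weight_config=None, nsamples=128):
            self.model = model
            self.weight_config = weight_config or {}
            self.nsamples = nsamples

    class Quantizer:
        """Stand-in for neural_compressor.torch.algorithms.Quantizer."""
        def __init__(self, quant_config=None):
            self.quant_config = quant_config or {}

    class GPTQuantizer(Quantizer):
        """Framework wrapper (formerly named INCGPTQQuantizer)."""
        def prepare(self, model, nsamples=128):
            assert isinstance(model, torch.nn.Module), "only support torch module"
            # As in the diff above, the wrapper delegates to the raw engine,
            # which now lives in the same module, so no import is needed.
            self.gptq_quantizer = RAWGPTQuantizer(
                model, weight_config=self.quant_config, nsamples=nsamples
            )
            return self.gptq_quantizer.model

The wrapper keeps the public name stable while the numerics stay in the raw class, which is also why the redundant "from .gptq import GPTQuantizer" inside prepare could be dropped.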

neural_compressor/torch/quantization/algorithm_entry.py (10 additions, 5 deletions)

@@ -87,7 +87,7 @@ def gptq_entry(
     **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the GPTQ algorithm.")
-    from neural_compressor.torch.algorithms.weight_only.gptq import INCGPTQQuantizer
+    from neural_compressor.torch.algorithms.weight_only.gptq import GPTQuantizer

     # rebuild weight_config for gptq_quantize function
     weight_config = {}
@@ -119,10 +119,15 @@
     )
     kwargs.pop("example_inputs")
     logger.warning("lm_head in transformer model is skipped by GPTQ")
-
-    if CurrentQuantizer.quantizer is None or mode in [Mode.PREPARE, Mode.QUANTIZE]:
-        CurrentQuantizer.quantizer = INCGPTQQuantizer(quant_config=weight_config)
-    model = CurrentQuantizer.quantizer.execute(model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        quantizer = model.quantizer
+    else:
+        quantizer = GPTQuantizer(quant_config=weight_config)
+    model = quantizer.execute(model, mode=mode, *args, **kwargs)
+    if getattr(model, "quantizer", False):
+        del model.quantizer
+    else:
+        model.quantizer = quantizer
     return model

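The second hunk replaces the global CurrentQuantizer singleton with state carried on the model itself: the PREPARE pass creates a GPTQuantizer and stashes it as model.quantizer, the CONVERT pass finds it there, reuses its calibration state, and deletes the attribute when done. A runnable sketch of that handoff, assuming simplified stand-ins for Mode, execute, and the model:

    # Sketch of the quantizer/model handoff; Mode, execute, and Toy are
    # simplified stand-ins for the real INC types.
    from enum import Enum

    class Mode(Enum):
        PREPARE = "prepare"
        CONVERT = "convert"

    class GPTQuantizer:
        def __init__(self, quant_config=None):
            self.quant_config = quant_config or {}

        def execute(self, model, mode):
            # Real code calibrates on PREPARE and quantizes on CONVERT.
            return model

    def gptq_entry(model, weight_config, mode):
        # Reuse the quantizer stashed by a previous PREPARE pass, otherwise
        # build a fresh one (first PREPARE, or a one-shot QUANTIZE).
        if getattr(model, "quantizer", False):
            quantizer = model.quantizer
        else:
            quantizer = GPTQuantizer(quant_config=weight_config)
        model = quantizer.execute(model, mode=mode)
        if getattr(model, "quantizer", False):
            del model.quantizer          # CONVERT: state consumed, drop it
        else:
            model.quantizer = quantizer  # PREPARE: keep it for CONVERT
        return model

    class Toy:  # stand-in for a torch.nn.Module
        pass

    m = gptq_entry(Toy(), {}, Mode.PREPARE)
    assert hasattr(m, "quantizer")       # quantizer rides on the model
    m = gptq_entry(m, {}, Mode.CONVERT)
    assert not hasattr(m, "quantizer")   # removed after conversion

Tying the quantizer to the model instead of a module-level singleton lets several models be prepared and converted independently without clobbering each other's state.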
