Commit 044e6db

Authored by yiliu30, yuwenzho, and pre-commit-ci[bot]
Refine base Quantizer (#1760)
Refine base Quantizer class

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: yuwenzho <yuwen.zhou@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 95e67ea commit 044e6db

File tree

3 files changed (+51, -48 lines)

neural_compressor/torch/algorithms/base_algorithm.py

Lines changed: 46 additions & 26 deletions
```diff
@@ -14,45 +14,41 @@
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any
+from typing import Any, Optional
 
 import torch
 
 from neural_compressor.torch.utils import Mode
 
 
 class Quantizer(ABC):
-    """The base quantizer for all algorithm quantizers."""
+    """The base quantizer for all algorithm quantizers.
 
-    def __init__(self, tune_cfg: OrderedDict = {}):
-        """Init a Quantizer object.
+    The `Quantizer` unifies the interfaces across various quantization algorithms, including GPTQ, RTN, etc.
+    Given a float model, `Quantizer` applies the quantization algorithm to the model according to `quant_config`.
 
-        Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
-                Take weight-only quantization as an example,
-                    tune_cfg={
-                        'fc2':
-                            {
-                                'dtype': 'int',
-                                'bits': 4,
-                                'group_size': 32,
-                                'scheme': 'sym'
-                            }
-                    }
-        """
-        self.tune_cfg = tune_cfg
+    To implement a new quantization algorithm, inherit from `Quantizer` and implement the following methods:
+        - `prepare`: prepare a given model for convert.
+        - `convert`: convert a prepared model to a quantized model.
+    Note: `quantize` and `execute` are optional for new quantization algorithms.
+    """
 
-    @abstractmethod
-    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
-        """Quantizes a given torch model.
+    def __init__(self, quant_config: Optional[Any] = None):
+        """Init a Quantizer object.
 
         Args:
-            model (torch.nn.Module): The torch model to be quantized.
-
-        Returns:
-            A quantized model.
+            quant_config: Specifies how to apply the algorithm on the given model.
+                The format of `quant_config` can be defined by the `Quantizer` itself.
+                For example, `quant_config` can be a dictionary as below:
+                    quant_config={
+                        'fc2': {
+                            'dtype': 'int',
+                            'bits': 4,
+                            'group_size': 32,
+                            'scheme': 'sym'
+                        }}
         """
-        raise NotImplementedError("{} doesn't implement `quantize` function.".format(self.__class__.__name__))
+        self.quant_config = quant_config
 
     @abstractmethod
     def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
@@ -80,6 +76,30 @@ def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
         """
         raise NotImplementedError("{} doesn't implement `convert` function. ".format(self.__class__.__name__))
 
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
+        """Quantizes a given float model.
+
+        Args:
+            model (torch.nn.Module): The float model to be quantized.
+
+        Returns:
+            A quantized model.
+        """
+        run_fn = kwargs.get("run_fn", None)
+        run_args = kwargs.get("run_args", None)
+        assert run_fn is not None, (
+            "Can't find run_fn. Please provide run_fn to the quantize API "
+            "or overwrite the quantize member function in your Quantizer class."
+        )
+
+        model = self.prepare(model, *args, **kwargs)
+        if run_args:
+            run_fn(model, *run_args)
+        else:
+            run_fn(model)
+        model = self.convert(model, *args, **kwargs)
+        return model
+
     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
         """Execute according to mode.
```
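With this refinement, `quantize` becomes a default template method on the base class: `prepare`, calibrate via `run_fn`, then `convert`. The following is a minimal sketch, not part of this commit, of how a new algorithm quantizer might be written against the refined interface; `MyQuantizer` and its placeholder method bodies are hypothetical, and the import path follows this commit's file tree.

```python
import torch

from neural_compressor.torch.algorithms.base_algorithm import Quantizer


class MyQuantizer(Quantizer):
    """Hypothetical subclass: only `prepare` and `convert` are implemented."""

    def prepare(self, model: torch.nn.Module, *args, **kwargs):
        # A real algorithm would attach observers or collect statistics here.
        model.prepared = True
        return model

    def convert(self, model: torch.nn.Module, *args, **kwargs):
        # A real algorithm would swap float modules for quantized ones here.
        model.converted = True
        return model


model = torch.nn.Linear(4, 4)
quantizer = MyQuantizer(
    quant_config={"fc2": {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "sym"}}
)
# The inherited `quantize` asserts that a calibration function `run_fn` is
# supplied, then drives prepare -> run_fn -> convert.
qmodel = quantizer.quantize(model, run_fn=lambda m: m(torch.randn(2, 4)))
assert qmodel.prepared and qmodel.converted
```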

neural_compressor/torch/algorithms/static_quant/static_quant.py

Lines changed: 4 additions & 21 deletions
```diff
@@ -46,13 +46,13 @@
 
 
 class StaticQuantQuantizer(Quantizer):
-    def __init__(self, tune_cfg: OrderedDict = {}):
+    def __init__(self, quant_config: OrderedDict = {}):
         """Init a StaticQuantQuantizer object.
 
         Args:
-            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
+            quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
         """
-        super().__init__(tune_cfg)
+        super().__init__(quant_config)
 
     def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
         """Prepares a given model for quantization.
@@ -71,7 +71,7 @@ def prepare(self, model, example_inputs, inplace=True, *args, **kwargs):
             model, example_inputs
         )
         # update json file in ipex_config_path; map ipex op_name to pt op_name
-        user_cfg = cfg_to_qconfig(self.tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
+        user_cfg = cfg_to_qconfig(self.quant_config, cfgs, op_infos_from_cfgs, output_tensor_id_op_name)
         model.eval()
 
         # Check save_qconf_summary part is a workaround for IPEX bug.
@@ -126,23 +126,6 @@ def convert(self, model, example_inputs, inplace=True, *args, **kwargs):
         model.save = MethodType(save, model)
         return model
 
-    def quantize(self, model, example_inputs, run_fn, inplace=True, *args, **kwargs):
-        """Quantizes a given torch model.
-
-        Args:
-            model: A float model to be quantized.
-            example_inputs: Used to trace torch model.
-            run_fn: a calibration function for calibrating the model.
-            inplace: Whether to carry out model transformations in-place. Defaults to True.
-
-        Returns:
-            A quantized model.
-        """
-        model = self.prepare(model, example_inputs=example_inputs, inplace=inplace)
-        run_fn(model)
-        model = self.convert(model, example_inputs=example_inputs, inplace=inplace)
-        return model
-
 
 def _ipex_post_quant_process(model, example_inputs, inplace=False):
     """Convert to a jit model.
```

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -155,7 +155,7 @@ def static_quant_entry(
     inplace = kwargs.get("inplace", True)
     assert example_inputs is not None, "Please provide example_inputs for static quantization."
 
-    quantizer = StaticQuantQuantizer(tune_cfg=quant_config_mapping)
+    quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)
     model = quantizer.execute(model, mode=mode, run_fn=run_fn, example_inputs=example_inputs, inplace=inplace)
     return model
```
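For callers, the visible change is the constructor keyword. A sketch of the updated call site, assuming an IPEX-enabled installation; `quant_config_mapping` here is an empty stand-in rather than a real config mapping.

```python
from collections import OrderedDict

from neural_compressor.torch.algorithms.static_quant.static_quant import StaticQuantQuantizer

quant_config_mapping = OrderedDict()  # stand-in for the real op-level config
quantizer = StaticQuantQuantizer(quant_config=quant_config_mapping)  # new keyword
# StaticQuantQuantizer(tune_cfg=quant_config_mapping) would now raise a TypeError.
```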
