14 | 14 |
15 | 15 | from abc import ABC, abstractmethod
16 | 16 | from collections import OrderedDict
17 |    | -from typing import Any
   | 17 | +from typing import Any, Optional
18 | 18 |
19 | 19 | import torch
20 | 20 |
21 | 21 | from neural_compressor.torch.utils import Mode
22 | 22 |
23 | 23 |
24 | 24 | class Quantizer(ABC):
25 |    | -    """The base quantizer for all algorithm quantizers."""
   | 25 | +    """The base quantizer for all algorithm quantizers.
26 | 26 |
27 |    | -    def __init__(self, tune_cfg: OrderedDict = {}):
28 |    | -        """Init a Quantizer object.
   | 27 | +    The `Quantizer` unifies the interfaces across various quantization algorithms, including GPTQ, RTN, etc.
   | 28 | +    Given a float model, `Quantizer` applies the quantization algorithm to the model according to `quant_config`.
29 | 29 |
30 |    | -        Args:
31 |    | -            tune_cfg (OrderedDict, optional): quantization config for ops. Defaults to {}.
32 |    | -                Take weight-only quantization as an example,
33 |    | -                  tune_cfg={
34 |    | -                      'fc2':
35 |    | -                          {
36 |    | -                              'dtype': 'int',
37 |    | -                              'bits': 4,
38 |    | -                              'group_size': 32,
39 |    | -                              'scheme': 'sym'
40 |    | -                          }
41 |    | -                  }
42 |    | -        """
43 |    | -        self.tune_cfg = tune_cfg
   | 30 | +    To implement a new quantization algorithm, inherit from `Quantizer` and implement the following methods:
   | 31 | +        - `prepare`: prepare a given model for conversion.
   | 32 | +        - `convert`: convert a prepared model to a quantized model.
   | 33 | +    Note: `quantize` and `execute` are optional for new quantization algorithms.
   | 34 | +    """
44 | 35 |
45 |    | -    @abstractmethod
46 |    | -    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
47 |    | -        """Quantizes a given torch model.
   | 36 | +    def __init__(self, quant_config: Optional[Any] = None):
   | 37 | +        """Init a Quantizer object.
48 | 38 |
49 | 39 |         Args:
50 |    | -            model (torch.nn.Module): The torch model to be quantized.
51 |    | -
52 |    | -        Returns:
53 |    | -            A quantized model.
   | 40 | +            quant_config: Specifies how to apply the algorithm on the given model.
   | 41 | +                The format of `quant_config` can be defined by the `Quantizer` subclass itself.
   | 42 | +                For example, `quant_config` can be a dictionary as below:
   | 43 | +                    quant_config={
   | 44 | +                        'fc2': {
   | 45 | +                            'dtype': 'int',
   | 46 | +                            'bits': 4,
   | 47 | +                            'group_size': 32,
   | 48 | +                            'scheme': 'sym'
   | 49 | +                        }}
54 | 50 |         """
55 |    | -        raise NotImplementedError("{} doesn't implement `quantize` function.".format(self.__class__.__name__))
   | 51 | +        self.quant_config = quant_config
56 | 52 |
57 | 53 |     @abstractmethod
58 | 54 |     def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
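(Reviewer note, not part of the patch: under the new contract a subclass only has to implement `prepare` and `convert`. Below is a minimal sketch, assuming the `Quantizer` base class above is in scope; the `RoundingQuantizer` name and its per-tensor round-to-nearest logic are made up for illustration.)

```python
import torch

# Illustrative only: a toy Quantizer subclass satisfying the prepare/convert contract.
class RoundingQuantizer(Quantizer):
    """Toy per-tensor symmetric round-to-nearest on configured Linear layers."""

    def prepare(self, model: torch.nn.Module, *args, **kwargs):
        # Plain RTN needs no calibration state, so the model passes through unchanged.
        return model

    def convert(self, model: torch.nn.Module, *args, **kwargs):
        for name, module in model.named_modules():
            cfg = (self.quant_config or {}).get(name)
            if cfg is None or not isinstance(module, torch.nn.Linear):
                continue
            # Per-tensor symmetric scheme; `group_size` is ignored in this toy.
            qmax = 2 ** (cfg["bits"] - 1) - 1
            scale = module.weight.detach().abs().amax().clamp(min=1e-12) / qmax
            q = torch.clamp(torch.round(module.weight / scale), -qmax - 1, qmax)
            module.weight.data = q * scale  # store dequantized weights for simplicity
        return model
```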
@@ -80,6 +76,30 @@ def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
80 | 76 |         """
81 | 77 |         raise NotImplementedError("{} doesn't implement `convert` function. ".format(self.__class__.__name__))
82 | 78 |
   | 79 | +    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any):
   | 80 | +        """Quantizes a given float model.
   | 81 | +
   | 82 | +        Args:
   | 83 | +            model (torch.nn.Module): The float model to be quantized.
   | 84 | +
   | 85 | +        Returns:
   | 86 | +            A quantized model.
   | 87 | +        """
   | 88 | +        run_fn = kwargs.get("run_fn", None)
   | 89 | +        run_args = kwargs.get("run_args", None)
   | 90 | +        assert run_fn is not None, (
   | 91 | +            "Can't find `run_fn`. Please provide a `run_fn` to the quantize API "
   | 92 | +            "or override the `quantize` method in your Quantizer subclass."
   | 93 | +        )
   | 94 | +
   | 95 | +        model = self.prepare(model, *args, **kwargs)
   | 96 | +        if run_args:
   | 97 | +            run_fn(model, *run_args)
   | 98 | +        else:
   | 99 | +            run_fn(model)
   | 100 | +        model = self.convert(model, *args, **kwargs)
   | 101 | +        return model
   | 102 | +
83 | 103 |     def execute(self, model: torch.nn.Module, mode, *args: Any, **kwargs: Any):  # pragma: no cover
84 | 104 |         """Execute according to mode.
85 | 105 |
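(Reviewer note: the template `quantize` added above chains `prepare` -> `run_fn` -> `convert`, with `run_fn` and optional `run_args` taken from kwargs. A usage sketch reusing the illustrative `RoundingQuantizer` from the earlier note; the model and `calibrate` function are hypothetical.)

```python
from collections import OrderedDict

import torch


def calibrate(model, num_batches=2):
    # The `run_fn`: pushes a few batches through the prepared model. RTN ignores
    # the data, but calibration-based algorithms would collect statistics here.
    model.eval()
    with torch.no_grad():
        for _ in range(num_batches):
            model(torch.randn(4, 16))


model = torch.nn.Sequential(OrderedDict([
    ("fc1", torch.nn.Linear(16, 16)),
    ("fc2", torch.nn.Linear(16, 4)),
]))
quantizer = RoundingQuantizer(
    quant_config={"fc2": {"dtype": "int", "bits": 4, "group_size": 32, "scheme": "sym"}}
)
# `quantize` runs prepare -> run_fn(model, *run_args) -> convert; `run_args`
# is unpacked positionally after the model.
q_model = quantizer.quantize(model, run_fn=calibrate, run_args=(2,))
```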