
Commit e2ee1e8

wenhuach21 authored
[Feature] Add support for models quantized with AutoRound (vllm-project#17850)
Signed-off-by: wenhuach21 <wenhua.cheng@intel.com>
1 parent 20d8ce8 commit e2ee1e8
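
For context, a minimal offline-inference sketch of what this commit enables. It is not part of the diff: the model ID is borrowed from the new test below, and it assumes vLLM picks the quantization settings up from the checkpoint's embedded quantization config, so no extra flags are passed.

# Sketch only, not part of this commit: load an AutoRound-quantized checkpoint
# (model ID taken from the new test file) and run a greedy completion.
from vllm import LLM, SamplingParams

llm = LLM(model="OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc")
params = SamplingParams(temperature=0.0, max_tokens=8)  # greedy, as in the test
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)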

File tree

tests/quantization/test_auto_round.py
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/auto_round.py

3 files changed: +339 -0 lines changed


tests/quantization/test_auto_round.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
"""Test model set-up and inference for quantized HF models supported
by AutoRound.

Validates the configuration and prints results for manual checking.

Run `pytest tests/quantization/test_auto_round.py`.
"""

import pytest

from vllm.platforms import current_platform

MODELS = [
    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc",  # auto_round:auto_gptq
    "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound"  # auto_round:auto_awq
]


@pytest.mark.skipif(not current_platform.is_cpu()
                    and not current_platform.is_xpu()
                    and not current_platform.is_cuda(),
                    reason="only supports CPU/XPU/CUDA backend.")
@pytest.mark.parametrize("model", MODELS)
def test_auto_round(vllm_runner, model):
    with vllm_runner(model) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=8)
    assert output
    print(f"{output[0][1]}")

vllm/model_executor/layers/quantization/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,7 @@
     "quark",
     "moe_wna16",
     "torchao",
+    "auto-round",
 ]
 QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

@@ -84,6 +85,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
     from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

     from .aqlm import AQLMConfig
+    from .auto_round import AutoRoundConfig
     from .awq import AWQConfig
     from .awq_marlin import AWQMarlinConfig
     from .bitblas import BitBLASConfig
@@ -138,6 +140,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
         "quark": QuarkConfig,
         "moe_wna16": MoeWNA16Config,
         "torchao": TorchAOConfig,
+        "auto-round": AutoRoundConfig,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
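
A quick way to sanity-check the registration above (a sketch, not part of the commit): the new "auto-round" entry should resolve through the existing helpers in this module.

# Sketch: verify that "auto-round" is registered and maps to AutoRoundConfig.
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

assert "auto-round" in QUANTIZATION_METHODS
print(get_quantization_config("auto-round"))  # expected: AutoRoundConfig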
vllm/model_executor/layers/quantization/auto_round.py

Lines changed: 306 additions & 0 deletions
@@ -0,0 +1,306 @@
# SPDX-License-Identifier: Apache-2.0

from fractions import Fraction
from typing import Any, Optional, Union

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class AutoRoundConfig(QuantizationConfig):
    """Config class for AutoRound.
    Reference: https://arxiv.org/pdf/2309.05516
    """

    SUPPORTED_BITS = {2, 3, 4, 8}
    SUPPORTED_DTYPES = {"int"}
    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
    SUPPORTED_BACKENDS = {
        "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex"
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        sym: bool = True,
        packing_format: str = "auto_round:auto_gptq",
        block_name_to_quantize: Optional[Union[str, list[str]]] = None,
        extra_config: Optional[dict[str, Any]] = None,
        data_type: str = "int",
        backend: str = "auto",
    ) -> None:
        super().__init__()
        if weight_bits not in self.SUPPORTED_BITS:
            raise ValueError(f"Unsupported weight_bits: {weight_bits}, "
                             f"currently only support {self.SUPPORTED_BITS}")
        if data_type not in self.SUPPORTED_DTYPES:
            raise ValueError(
                f"Unsupported data_type: {data_type},"
                f" currently only support {self.SUPPORTED_DTYPES}")
        if packing_format not in self.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported packing_format: {packing_format}, "
                f"currently only support {self.SUPPORTED_FORMATS}")
        if backend not in self.SUPPORTED_BACKENDS:
            raise ValueError(
                f"Unsupported backend: {backend}, "
                f"currently only support {self.SUPPORTED_BACKENDS}")

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.sym = sym
        self.packing_format = packing_format
        self.block_name_to_quantize = (block_name_to_quantize.split(",") if
                                       isinstance(block_name_to_quantize, str)
                                       else block_name_to_quantize)
        self.extra_config = extra_config
        self.data_type = data_type
        self.backend = backend
        self.pack_factor = Fraction(32, weight_bits)

    def __repr__(self) -> str:
        return (f"AutoRoundConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, sym={self.sym})")

    @classmethod
    def get_name(cls):  # use str will trigger preci issue
        return "auto-round"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantization_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
        return cls(
            weight_bits=cls.get_from_keys(config, ["bits"]),
            group_size=cls.get_from_keys(config, ["group_size"]),
            sym=cls.get_from_keys(config, ["sym"]),
            packing_format=cls.get_from_keys_or(config, ["packing_format"],
                                                "auto_round:auto_gptq"),
            block_name_to_quantize=cls.get_from_keys_or(
                config, ["block_name_to_quantize", "to_quant_block_names"],
                None),
            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
            backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"],
                                         "auto"),
        )

    def get_layer_config(self, layer, layer_name: str):
        # Priority: extra_config > block_name_to_quantize > type fallback
        if self.extra_config and layer_name in self.extra_config:
            cfg = self.extra_config[layer_name]
            return cfg.get("bits", self.weight_bits), cfg.get(
                "group_size", self.group_size), cfg.get("sym", self.sym)

        quantized = True
        if self.block_name_to_quantize:
            quantized = any(name in layer_name
                            for name in self.block_name_to_quantize)
        elif isinstance(layer, ParallelLMHead):
            quantized = False

        return (self.weight_bits, self.group_size,
                self.sym) if quantized else (16, -1, True)

    def check_quantized(self, weight_bits: int) -> bool:
        return weight_bits < 16

    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported, check_moe_marlin_supports_layer)

        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            if isinstance(layer, (LinearBase, ParallelLMHead)):
                return UnquantizedLinearMethod()
            else:
                return None

        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
                     prefix, layer.__class__.__name__, weight_bits, group_size,
                     sym)
        if backend == "auto" or "marlin" in backend:
            if isinstance(layer, FusedMoE):
                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
            else:
                AWQ_TYPE_MAP = {
                    4: scalar_types.uint4,
                    8: scalar_types.uint8,
                }
                use_marlin = (weight_bits in AWQ_TYPE_MAP
                              and check_marlin_supported(
                                  AWQ_TYPE_MAP[weight_bits], group_size,
                                  not sym))
        else:
            use_marlin = False
        if use_marlin:
            from vllm.model_executor.layers.quantization.awq_marlin import (
                AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
            quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits,
                                                group_size=group_size,
                                                zero_point=not sym,
                                                lm_head_quantized=False,
                                                full_config={},
                                                modules_to_not_convert=[])
        else:
            from vllm.model_executor.layers.quantization.awq import (
                AWQConfig, AWQLinearMethod)
            quant_args = AWQConfig(
                weight_bits=weight_bits,
                group_size=group_size,
                zero_point=not sym,
            )

        if isinstance(layer, FusedMoE):
            if use_marlin:
                return AWQMoEMethod(quant_args_marlin)
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            config = {
                "linear_quant_method": "awq",
                "weight_bits": weight_bits,
                "group_size": group_size,
                "zero_point": not sym,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return AWQMarlinLinearMethod(quant_args_marlin)
            else:
                return AWQLinearMethod(quant_args)
        return None

    def apply_gptq_quant_layer(self,
                               layer,
                               prefix: str,
                               backend: str = "auto"):
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.quantization.utils.marlin_utils import (
            check_marlin_supported, check_moe_marlin_supports_layer)
        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            if isinstance(layer, (LinearBase, ParallelLMHead)):
                return UnquantizedLinearMethod()
            else:
                return None

        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
                     prefix, layer.__class__.__name__, weight_bits, group_size,
                     sym)
        if backend == "auto" or "marlin" in backend:
            if isinstance(layer, FusedMoE):
                use_marlin = check_moe_marlin_supports_layer(layer, group_size)
            else:
                GPTQ_TYPE_MAP = {
                    (4, True): scalar_types.uint4b8,
                    (8, True): scalar_types.uint8b128,
                }
                use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
                              and check_marlin_supported(
                                  GPTQ_TYPE_MAP[(weight_bits, sym)],
                                  group_size,
                                  has_zp=not sym))
        else:
            use_marlin = False
        if use_marlin:
            from vllm.model_executor.layers.quantization.gptq_marlin import (
                GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
            quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits,
                                                 group_size=group_size,
                                                 is_sym=sym,
                                                 lm_head_quantized=False,
                                                 desc_act=False,
                                                 dynamic={},
                                                 full_config={})
        else:
            from vllm.model_executor.layers.quantization.gptq import (
                GPTQConfig, GPTQLinearMethod)
            quant_args = GPTQConfig(weight_bits=weight_bits,
                                    group_size=group_size,
                                    lm_head_quantized=False,
                                    desc_act=False,
                                    dynamic={})

        if isinstance(layer, FusedMoE):
            if use_marlin:
                return GPTQMarlinMoEMethod(quant_args_marlin)
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            config = {
                "linear_quant_method": "gptq",
                "weight_bits": weight_bits,
                "group_size": group_size,
                "sym": sym,
                "lm_head_quantized": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)

        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if use_marlin:
                return GPTQMarlinLinearMethod(quant_args_marlin)
            else:
                return GPTQLinearMethod(quant_args)

        return None

    def apply_ipex_quant_layer(self, layer, prefix: str):
        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
        if not self.check_quantized(weight_bits):
            if isinstance(layer, (LinearBase, ParallelLMHead)):
                return UnquantizedLinearMethod()
            else:
                return None
        from vllm.model_executor.layers.quantization.ipex_quant import (
            IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
        if isinstance(layer, (LinearBase, ParallelLMHead)):
            if "awq" in self.packing_format:
                config = IPEXConfig(method="awq",
                                    weight_bits=weight_bits,
                                    group_size=group_size)
                return IPEXAWQLinearMethod(config)
            elif "gptq" in self.packing_format:
                config = IPEXConfig(method="gptq",
                                    weight_bits=weight_bits,
                                    group_size=group_size)
                return IPEXGPTQLinearMethod(config)
            else:
                raise ValueError(
                    f"ipex backend only supports awq "
                    f"and gptq formats, but got {self.packing_format}")
        else:
            return None

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        if (current_platform.is_cpu() or current_platform.is_xpu()
                or self.backend == "ipex"):
            return self.apply_ipex_quant_layer(layer, prefix)
        if "gptq" in self.packing_format or "gptq" in self.backend:
            return self.apply_gptq_quant_layer(layer, prefix)
        if "awq" in self.packing_format or "awq" in self.backend:
            return self.apply_awq_quant_layer(layer, prefix)
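
To illustrate how `from_config` above consumes a checkpoint's quantization config, here is a small sketch. The dict keys mirror the ones read in `from_config` (bits, group_size, sym, plus the optional packing_format/data_type/backend); the values are made up for the example.

# Sketch only: a hand-written quantization config of the shape from_config reads.
from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig

quant_cfg = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "packing_format": "auto_round:auto_gptq",
    "data_type": "int",
    "backend": "auto",
}
cfg = AutoRoundConfig.from_config(quant_cfg)
print(cfg)  # AutoRoundConfig(weight_bits=4, group_size=128, sym=True)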
