|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +from fractions import Fraction |
| 4 | +from typing import Any, Optional, Union |
| 5 | + |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.logger import init_logger |
| 9 | +from vllm.model_executor.layers.linear import (LinearBase, |
| 10 | + UnquantizedLinearMethod) |
| 11 | +from vllm.model_executor.layers.quantization.base_config import ( |
| 12 | + QuantizationConfig) |
| 13 | +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead |
| 14 | +from vllm.platforms import current_platform |
| 15 | +from vllm.scalar_type import scalar_types |
| 16 | + |
| 17 | +logger = init_logger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class AutoRoundConfig(QuantizationConfig): |
| 21 | + """Config class for AutoRound. |
| 22 | + Reference: https://arxiv.org/pdf/2309.05516 |
| 23 | + """ |
| 24 | + |
| 25 | + SUPPORTED_BITS = {2, 3, 4, 8} |
| 26 | + SUPPORTED_DTYPES = {"int"} |
| 27 | + SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"} |
| 28 | + SUPPORTED_BACKENDS = { |
| 29 | + "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex" |
| 30 | + } |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + weight_bits: int, |
| 35 | + group_size: int, |
| 36 | + sym: bool = True, |
| 37 | + packing_format: str = "auto_round:auto_gptq", |
| 38 | + block_name_to_quantize: Optional[Union[str, list[str]]] = None, |
| 39 | + extra_config: Optional[dict[str, Any]] = None, |
| 40 | + data_type: str = "int", |
| 41 | + backend: str = "auto", |
| 42 | + ) -> None: |
| 43 | + super().__init__() |
| 44 | + if weight_bits not in self.SUPPORTED_BITS: |
| 45 | + raise ValueError(f"Unsupported weight_bits: {weight_bits}, " |
| 46 | + f"currently only support {self.SUPPORTED_BITS}") |
| 47 | + if data_type not in self.SUPPORTED_DTYPES: |
| 48 | + raise ValueError( |
| 49 | + f"Unsupported data_type: {data_type}," |
| 50 | + f" currently only support {self.SUPPORTED_DTYPES}") |
| 51 | + if packing_format not in self.SUPPORTED_FORMATS: |
| 52 | + raise ValueError( |
| 53 | + f"Unsupported packing_format: {packing_format}, " |
| 54 | + f"currently only support {self.SUPPORTED_FORMATS}") |
| 55 | + if backend not in self.SUPPORTED_BACKENDS: |
| 56 | + raise ValueError( |
| 57 | + f"Unsupported backend: {backend}, " |
| 58 | + f"currently only support {self.SUPPORTED_BACKENDS}") |
| 59 | + |
| 60 | + self.weight_bits = weight_bits |
| 61 | + self.group_size = group_size |
| 62 | + self.sym = sym |
| 63 | + self.packing_format = packing_format |
| 64 | + self.block_name_to_quantize = (block_name_to_quantize.split(",") if |
| 65 | + isinstance(block_name_to_quantize, str) |
| 66 | + else block_name_to_quantize) |
| 67 | + self.extra_config = extra_config |
| 68 | + self.data_type = data_type |
| 69 | + self.backend = backend |
| 70 | + self.pack_factor = Fraction(32, weight_bits) |
| 71 | + |
| 72 | + def __repr__(self) -> str: |
| 73 | + return (f"AutoRoundConfig(weight_bits={self.weight_bits}, " |
| 74 | + f"group_size={self.group_size}, sym={self.sym})") |
| 75 | + |
| 76 | + @classmethod |
| 77 | + def get_name(cls): ## use str will trigger preci issue |
| 78 | + return "auto-round" |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def get_supported_act_dtypes(cls) -> list[torch.dtype]: |
| 82 | + return [torch.half, torch.bfloat16] |
| 83 | + |
| 84 | + @classmethod |
| 85 | + def get_min_capability(cls) -> int: |
| 86 | + return 60 |
| 87 | + |
| 88 | + @classmethod |
| 89 | + def get_config_filenames(cls) -> list[str]: |
| 90 | + return ["quantization_config.json"] |
| 91 | + |
| 92 | + @classmethod |
| 93 | + def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig": |
| 94 | + return cls( |
| 95 | + weight_bits=cls.get_from_keys(config, ["bits"]), |
| 96 | + group_size=cls.get_from_keys(config, ["group_size"]), |
| 97 | + sym=cls.get_from_keys(config, ["sym"]), |
| 98 | + packing_format=cls.get_from_keys_or(config, ["packing_format"], |
| 99 | + "auto_round:auto_gptq"), |
| 100 | + block_name_to_quantize=cls.get_from_keys_or( |
| 101 | + config, ["block_name_to_quantize", "to_quant_block_names"], |
| 102 | + None), |
| 103 | + extra_config=cls.get_from_keys_or(config, ["extra_config"], None), |
| 104 | + data_type=cls.get_from_keys_or(config, ["data_type"], "int"), |
| 105 | + backend=cls.get_from_keys_or(config, ["backend", "vllm_backend"], |
| 106 | + "auto"), |
| 107 | + ) |
| 108 | + |
| 109 | + def get_layer_config(self, layer, layer_name: str): |
| 110 | + # Priority: extra_config > block_name_to_quantize > type fallback |
| 111 | + if self.extra_config and layer_name in self.extra_config: |
| 112 | + cfg = self.extra_config[layer_name] |
| 113 | + return cfg.get("bits", self.weight_bits), cfg.get( |
| 114 | + "group_size", self.group_size), cfg.get("sym", self.sym) |
| 115 | + |
| 116 | + quantized = True |
| 117 | + if self.block_name_to_quantize: |
| 118 | + quantized = any(name in layer_name |
| 119 | + for name in self.block_name_to_quantize) |
| 120 | + elif isinstance(layer, ParallelLMHead): |
| 121 | + quantized = False |
| 122 | + |
| 123 | + return (self.weight_bits, self.group_size, |
| 124 | + self.sym) if quantized else (16, -1, True) |
| 125 | + |
| 126 | + def check_quantized(self, weight_bits: int) -> bool: |
| 127 | + return weight_bits < 16 |
| 128 | + |
| 129 | + def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"): |
| 130 | + from vllm.model_executor.layers.fused_moe import FusedMoE |
| 131 | + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( |
| 132 | + check_marlin_supported, check_moe_marlin_supports_layer) |
| 133 | + |
| 134 | + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) |
| 135 | + if not self.check_quantized(weight_bits): |
| 136 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 137 | + return UnquantizedLinearMethod() |
| 138 | + else: |
| 139 | + return None |
| 140 | + |
| 141 | + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", |
| 142 | + prefix, layer.__class__.__name__, weight_bits, group_size, |
| 143 | + sym) |
| 144 | + if backend == "auto" or "marlin" in backend: |
| 145 | + if isinstance(layer, FusedMoE): |
| 146 | + use_marlin = check_moe_marlin_supports_layer(layer, group_size) |
| 147 | + else: |
| 148 | + |
| 149 | + AWQ_TYPE_MAP = { |
| 150 | + 4: scalar_types.uint4, |
| 151 | + 8: scalar_types.uint8, |
| 152 | + } |
| 153 | + use_marlin = ((weight_bits, sym) in AWQ_TYPE_MAP |
| 154 | + and check_marlin_supported( |
| 155 | + AWQ_TYPE_MAP[(weight_bits)], group_size, |
| 156 | + not sym)) |
| 157 | + else: |
| 158 | + use_marlin = False |
| 159 | + if use_marlin: |
| 160 | + from vllm.model_executor.layers.quantization.awq_marlin import ( |
| 161 | + AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod) |
| 162 | + quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits, |
| 163 | + group_size=group_size, |
| 164 | + zero_point=not sym, |
| 165 | + lm_head_quantized=False, |
| 166 | + full_config={}, |
| 167 | + modules_to_not_convert=[]) |
| 168 | + else: |
| 169 | + from vllm.model_executor.layers.quantization.awq import ( |
| 170 | + AWQConfig, AWQLinearMethod) |
| 171 | + quant_args = AWQConfig( |
| 172 | + weight_bits=weight_bits, |
| 173 | + group_size=group_size, |
| 174 | + zero_point=not sym, |
| 175 | + ) |
| 176 | + |
| 177 | + if isinstance(layer, FusedMoE): |
| 178 | + if use_marlin: |
| 179 | + return AWQMoEMethod(quant_args_marlin) |
| 180 | + from vllm.model_executor.layers.quantization.moe_wna16 import ( |
| 181 | + MoeWNA16Config) |
| 182 | + config = { |
| 183 | + "linear_quant_method": "awq", |
| 184 | + "weight_bits": weight_bits, |
| 185 | + "group_size": group_size, |
| 186 | + "zero_point": not sym, |
| 187 | + } |
| 188 | + return MoeWNA16Config.from_config(config).get_quant_method( |
| 189 | + layer, prefix) |
| 190 | + |
| 191 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 192 | + if use_marlin: |
| 193 | + return AWQMarlinLinearMethod(quant_args_marlin) |
| 194 | + else: |
| 195 | + return AWQLinearMethod(quant_args) |
| 196 | + return None |
| 197 | + |
| 198 | + def apply_gptq_quant_layer(self, |
| 199 | + layer, |
| 200 | + prefix: str, |
| 201 | + backend: str = "auto"): |
| 202 | + from vllm.model_executor.layers.fused_moe import FusedMoE |
| 203 | + from vllm.model_executor.layers.quantization.utils.marlin_utils import ( |
| 204 | + check_marlin_supported, check_moe_marlin_supports_layer) |
| 205 | + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) |
| 206 | + if not self.check_quantized(weight_bits): |
| 207 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 208 | + return UnquantizedLinearMethod() |
| 209 | + else: |
| 210 | + return None |
| 211 | + |
| 212 | + logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s", |
| 213 | + prefix, layer.__class__.__name__, weight_bits, group_size, |
| 214 | + sym) |
| 215 | + if backend == "auto" or "marlin" in backend: |
| 216 | + if isinstance(layer, FusedMoE): |
| 217 | + use_marlin = check_moe_marlin_supports_layer(layer, group_size) |
| 218 | + else: |
| 219 | + GPTQ_TYPE_MAP = { |
| 220 | + (4, True): scalar_types.uint4b8, |
| 221 | + (8, True): scalar_types.uint8b128, |
| 222 | + } |
| 223 | + use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP |
| 224 | + and check_marlin_supported( |
| 225 | + GPTQ_TYPE_MAP[(weight_bits, sym)], |
| 226 | + group_size, |
| 227 | + has_zp=not sym)) |
| 228 | + else: |
| 229 | + use_marlin = False |
| 230 | + if use_marlin: |
| 231 | + from vllm.model_executor.layers.quantization.gptq_marlin import ( |
| 232 | + GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod) |
| 233 | + quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits, |
| 234 | + group_size=group_size, |
| 235 | + is_sym=sym, |
| 236 | + lm_head_quantized=False, |
| 237 | + desc_act=False, |
| 238 | + dynamic={}, |
| 239 | + full_config={}) |
| 240 | + else: |
| 241 | + from vllm.model_executor.layers.quantization.gptq import ( |
| 242 | + GPTQConfig, GPTQLinearMethod) |
| 243 | + quant_args = GPTQConfig(weight_bits=weight_bits, |
| 244 | + group_size=group_size, |
| 245 | + lm_head_quantized=False, |
| 246 | + desc_act=False, |
| 247 | + dynamic={}) |
| 248 | + |
| 249 | + if isinstance(layer, FusedMoE): |
| 250 | + if use_marlin: |
| 251 | + from vllm.model_executor.layers.quantization.moe_wna16 import ( |
| 252 | + MoeWNA16Config) |
| 253 | + config = { |
| 254 | + "linear_quant_method": "gptq", |
| 255 | + "weight_bits": weight_bits, |
| 256 | + "group_size": group_size, |
| 257 | + "sym": sym, |
| 258 | + "lm_head_quantized": False, |
| 259 | + } |
| 260 | + return MoeWNA16Config.from_config(config).get_quant_method( |
| 261 | + layer, prefix) |
| 262 | + return GPTQMarlinMoEMethod(quant_args_marlin) |
| 263 | + |
| 264 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 265 | + if use_marlin: |
| 266 | + return GPTQMarlinLinearMethod(quant_args_marlin) |
| 267 | + else: |
| 268 | + return GPTQLinearMethod(quant_args) |
| 269 | + |
| 270 | + return None |
| 271 | + |
| 272 | + def apply_ipex_quant_layer(self, layer, prefix: str): |
| 273 | + weight_bits, group_size, sym = self.get_layer_config(layer, prefix) |
| 274 | + if not self.check_quantized(weight_bits): |
| 275 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 276 | + return UnquantizedLinearMethod() |
| 277 | + else: |
| 278 | + return None |
| 279 | + from vllm.model_executor.layers.quantization.ipex_quant import ( |
| 280 | + IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod) |
| 281 | + if isinstance(layer, (LinearBase, ParallelLMHead)): |
| 282 | + if "awq" in self.packing_format: |
| 283 | + config = IPEXConfig(method="awq", |
| 284 | + weight_bits=weight_bits, |
| 285 | + group_size=group_size) |
| 286 | + return IPEXAWQLinearMethod(config) |
| 287 | + elif "gptq" in self.packing_format: |
| 288 | + config = IPEXConfig(method="gptq", |
| 289 | + weight_bits=weight_bits, |
| 290 | + group_size=group_size) |
| 291 | + return IPEXGPTQLinearMethod(config) |
| 292 | + else: |
| 293 | + raise ValueError( |
| 294 | + f"ipex backend only supports awq " |
| 295 | + f"and gtpq format,but got {self.packing_format}") |
| 296 | + else: |
| 297 | + return None |
| 298 | + |
| 299 | + def get_quant_method(self, layer: torch.nn.Module, prefix: str): |
| 300 | + if (current_platform.is_cpu() or current_platform.is_xpu() |
| 301 | + or self.backend == "ipex"): |
| 302 | + return self.apply_ipex_quant_layer(layer, prefix) |
| 303 | + if "gptq" in self.packing_format or "gptq" in self.backend: |
| 304 | + return self.apply_gptq_quant_layer(layer, prefix) |
| 305 | + if "awq" in self.packing_format or "awq" in self.backend: |
| 306 | + return self.apply_awq_quant_layer(layer, prefix) |
0 commit comments