@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from copy import deepcopy
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -9,7 +10,8 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -19,7 +21,7 @@
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
-    get_linear_quant_method)
+    get_dynamic_override, get_linear_quant_method, override_config)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, check_moe_marlin_supports_layer,
     marlin_make_workspace_new, marlin_moe_permute_scales,
@@ -35,6 +37,28 @@
 logger = init_logger(__name__)
 
 
+def get_moe_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    moe_method_cls: type,
+):
+    cloned_config = deepcopy(config)
+
+    if isinstance(layer, FusedMoE):
+        # False = skip module, None = no override, else = Positive match
+        if get_dynamic_override(  # noqa: E712
+                cloned_config,  # noqa: E712
+                layer_name=prefix) == False:  # noqa: E712
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return moe_method_cls(cloned_config)
+    return None
+
+
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
@@ -163,7 +188,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                     "Falling back to Moe WNA16 kernels.")
             return MoeWNA16Config.from_config(
                 self.full_config).get_quant_method(layer, prefix)
-            return GPTQMarlinMoEMethod(self)
+            return get_moe_quant_method(self, layer, prefix,
+                                        GPTQMarlinMoEMethod)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
 
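For context, the per-module rules consulted by `get_dynamic_override` and `override_config` come from the model's GPTQ quantization config, so this change lets the same `dynamic` overrides that already applied to linear layers also apply to `FusedMoE` layers. Below is a minimal sketch of what such a config can look like, assuming the GPTQModel-style `dynamic` syntax; the module patterns are hypothetical and only for illustration:

```python
# Hypothetical GPTQ quantization config with "dynamic" rules, matched by
# regex against each layer's prefix. A "-:" rule excludes matching modules
# from quantization: get_dynamic_override returns False, so
# get_moe_quant_method falls back to UnquantizedFusedMoEMethod. Other
# matching rules override fields of the cloned base config (via
# override_config) before GPTQMarlinMoEMethod is constructed.
quantize_config = {
    "bits": 4,
    "group_size": 128,
    "dynamic": {
        # skip quantization for the experts in layer 0 (hypothetical pattern)
        "-:model\\.layers\\.0\\.mlp\\.experts.*": {},
        # use 8-bit weights for the remaining expert layers (hypothetical)
        "+:model\\.layers\\..*\\.mlp\\.experts.*": {"bits": 8},
    },
}
```

Note that the config is deep-copied before any override is applied, so per-layer rules never mutate the shared base config used by other layers.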