
Commit a25f1aa

Feat Dynamic Quantization for MoE Layers in GPTQ Marlin Backend (vllm-project#19395)
1 parent 76c4877 commit a25f1aa

File tree: 1 file changed (+29, -3 lines)

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 29 additions & 3 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from copy import deepcopy
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -9,7 +10,8 @@
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -19,7 +21,7 @@
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
-    get_linear_quant_method)
+    get_dynamic_override, get_linear_quant_method, override_config)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, check_moe_marlin_supports_layer,
     marlin_make_workspace_new, marlin_moe_permute_scales,
@@ -35,6 +37,29 @@
 logger = init_logger(__name__)
 
 
+def get_moe_quant_method(
+    config: QuantizationConfig,
+    layer: torch.nn.Module,
+    prefix: str,
+    moe_method_cls: type,
+):
+    cloned_config = deepcopy(config)
+
+    if isinstance(layer, FusedMoE):
+        # False = skip module, None = no override, else = Positive match
+        if get_dynamic_override(  # noqa: E712
+                cloned_config,  # noqa: E712
+                layer_name=prefix) == False:  # noqa: E712
+            return UnquantizedFusedMoEMethod(layer.moe_config)
+
+        if prefix:
+            # Dynamic per module/layer rules may override base config
+            override_config(cloned_config, prefix=prefix)
+
+        return moe_method_cls(cloned_config)
+    return None
+
+
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
@@ -163,7 +188,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                     "Falling back to Moe WNA16 kernels.")
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
-            return GPTQMarlinMoEMethod(self)
+            return get_moe_quant_method(self, layer, prefix,
+                                        GPTQMarlinMoEMethod)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
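Note: the new get_moe_quant_method helper mirrors get_linear_quant_method in that it consults a per-module dynamic override before deciding how to quantize a FusedMoE layer. The sketch below is a minimal, self-contained approximation of that lookup, written only to illustrate the three-way semantics spelled out in the diff ("False = skip module, None = no override, else = positive match"). The rule format (regex keys with a "-:" prefix for exclusion), the helper name resolve_dynamic_override, and the example layer names are assumptions for illustration, not vLLM's actual get_dynamic_override API.

# Illustrative sketch only -- NOT vLLM's actual get_dynamic_override helper.
# Assumed rule format: regex keys; a "-:" prefix marks a module as excluded
# (returned as False -> stays unquantized), any other matching key returns its
# per-layer config overrides, and no match at all returns None.
import re
from typing import Union


def resolve_dynamic_override(
        dynamic_rules: dict[str, dict],
        layer_name: str) -> Union[bool, dict, None]:
    for pattern, overrides in dynamic_rules.items():
        if pattern.startswith("-:"):
            # Negative rule: a match means "leave this module unquantized".
            if re.match(pattern[2:], layer_name):
                return False
        elif re.match(pattern.removeprefix("+:"), layer_name):
            # Positive rule: a match returns the per-layer config overrides.
            return overrides
    return None  # no rule matched -> keep the base quantization config


# Hypothetical rules: skip the MoE experts in layer 0, override settings for
# layer 1's experts, and leave every other module on the base config.
rules = {
    r"-:model\.layers\.0\.mlp\.experts.*": {},
    r"+:model\.layers\.1\.mlp\.experts.*": {"bits": 8, "group_size": 64},
}

print(resolve_dynamic_override(rules, "model.layers.0.mlp.experts"))  # False
print(resolve_dynamic_override(rules, "model.layers.1.mlp.experts"))  # {'bits': 8, ...}
print(resolve_dynamic_override(rules, "model.layers.2.mlp.experts"))  # None

In the patched code above, a False result routes the FusedMoE layer to UnquantizedFusedMoEMethod, a positive match lets override_config adjust the cloned config, and None leaves the base GPTQ Marlin settings untouched before GPTQMarlinMoEMethod is constructed.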