
Commit 3ea2410

[Feature] Enable inference support for Deepseekr1-w8a8-MTP (#1584)
### What this PR does / why we need it?

1. Support the inference of the Deepseekr1-w8a8-MTP model with statically-quantized shared_head in MTP layers.

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
1 parent 77ff27b commit 3ea2410
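
For context, here is a minimal offline-inference sketch of how a statically-quantized (w8a8) DeepSeek-R1 checkpoint with MTP-based speculative decoding might be driven through vLLM. The model path, tensor-parallel size, quantization backend name, and speculative settings below are illustrative assumptions, not values taken from this PR, and the exact config keys may differ by vLLM version.

```python
# Hypothetical usage sketch (not part of this commit): run a statically
# quantized (w8a8) DeepSeek-R1 checkpoint with its MTP module used for
# speculative decoding. Paths, parallel sizes, and config keys are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/DeepSeek-R1-w8a8",   # assumed local quantized checkpoint
    quantization="ascend",               # assumed Ascend static-quant backend
    tensor_parallel_size=16,             # illustrative parallel layout
    speculative_config={
        # one extra MTP-drafted token per step; key names may vary by version
        "num_speculative_tokens": 1,
    },
)

outputs = llm.generate(["Briefly explain multi-token prediction."],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```

With the shared_head in the MTP layer now built through the quantization config, the draft head can load the same statically-quantized weights as the rest of the checkpoint instead of falling back to float.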

File tree

3 files changed: +48 −4 lines changed

vllm_ascend/models/deepseek_mtp.py

Lines changed: 20 additions & 3 deletions
@@ -28,8 +28,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)

@@ -40,6 +40,20 @@
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer


+class CustomDeepSeekShareHead(SharedHead):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        nn.Module.__init__(self)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.head = ParallelLMHead(config.vocab_size,
+                                   config.hidden_size,
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
+
+
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

     def __init__(

@@ -61,7 +75,10 @@ def __init__(
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = CustomDeepSeekShareHead(config=config,
+                                                   quant_config=quant_config,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "shared_head"))
         self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                       model_config,
                                                       cache_config,
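
The point of threading `prefix` into `CustomDeepSeekShareHead` is that the quantized LM head must be addressable by its checkpoint name. A small sketch of how the names compose is below; the helper mirrors the behaviour of vLLM's `maybe_prefix`, and the MTP layer index is a hypothetical example, not a value from this diff.

```python
# Sketch of the prefix wiring added above. The helper reproduces the usual
# "join with a dot unless the prefix is empty" behaviour; the layer index 61
# is only an illustrative stand-in for wherever the MTP layer actually sits.
def maybe_prefix(prefix: str, name: str) -> str:
    return name if not prefix else f"{prefix}.{name}"

mtp_layer = "model.layers.61"                         # hypothetical MTP layer path
shared_head = maybe_prefix(mtp_layer, "shared_head")  # "model.layers.61.shared_head"
head = maybe_prefix(shared_head, "head")              # "model.layers.61.shared_head.head"

# It is this fully qualified name that the Ascend quant config can match
# against its per-layer description when deciding how to build the head.
print(shared_head, head)
```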

vllm_ascend/models/deepseek_v2.py

Lines changed: 3 additions & 1 deletion
@@ -738,7 +738,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)

vllm_ascend/quantization/quant_config.py

Lines changed: 25 additions & 0 deletions
@@ -34,6 +34,8 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs

@@ -104,6 +106,12 @@ def get_quant_method(self, layer: torch.nn.Module,
                 return AscendUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
+        elif isinstance(layer, VocabParallelEmbedding):
+            if self.is_layer_skipped_ascend(prefix,
+                                            self.packed_modules_mapping):
+                return UnquantizedEmbeddingMethod()
+            return AscendEmbeddingMethod(self, prefix,
+                                         self.packed_modules_mapping)
         return None

     def is_layer_skipped_ascend(

@@ -352,3 +360,20 @@ def apply(
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
+
+
+class AscendEmbeddingMethod(AscendLinearMethod):
+    """Embedding method for Ascend quantization.
+
+    This class calls AscendQuantizer to search for a specific quantization
+    implementation supported on Ascend hardware for embedding methods.
+
+    Args:
+        quant_config: The Ascend quantization config.
+    """
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]) -> None:
+        self.quantizer = AscendQuantizer.get_quantizer(
+            quant_config.quant_description, prefix, packed_modules_mapping)
+        self.quant_method = self.quantizer.build_linear_method()
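
The new `VocabParallelEmbedding` branch means embedding and LM-head layers are no longer left on an unquantized path by default: they go through the Ascend quantizer unless the layer is explicitly skipped. A simplified stand-in of that dispatch is sketched below; the class names and the skip check are placeholders for illustration, not the real vllm-ascend implementations.

```python
# Simplified stand-in for the dispatch added to get_quant_method above.
# The real code consults AscendQuantConfig.is_layer_skipped_ascend and returns
# UnquantizedEmbeddingMethod or AscendEmbeddingMethod; here both are stubs.
class StubQuantizedEmbeddingMethod:
    """Placeholder for the quantized (w8a8) embedding path."""


class StubUnquantizedEmbeddingMethod:
    """Placeholder for the float fallback path."""


def pick_embedding_method(prefix: str, skipped_layers: set) -> object:
    # Skipped layers keep the float method; everything else is quantized.
    if prefix in skipped_layers:
        return StubUnquantizedEmbeddingMethod()
    return StubQuantizedEmbeddingMethod()


# Purely illustrative names: the MTP shared head stays quantized while the
# top-level lm_head is marked as skipped in this made-up configuration.
skipped = {"lm_head"}
for name in ("model.layers.61.shared_head.head", "lm_head"):
    print(name, "->", type(pick_embedding_method(name, skipped)).__name__)
```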
