
Commit 0be96f6

[v0.7.3] Support MTP in DeepSeek w8a8 quant model (#502)
### What this PR does / why we need it?
Add support for MTP in the DeepSeek w8a8 quant model.

### Does this PR introduce _any_ user-facing change?
1. The MTP layer of DeepSeek is not quantized by the current NPU msmodelslim, so the MTP layer in the DeepSeek w8a8 quantized weights is still in bfloat16 format.
2. The description file generated by the current msmodelslim tool does not contain MTP layer information. Please add it manually to `quantization_config` in `config.json` and set the value to `FLOAT` (see the illustrative sketch below).

### How was this patch tested?
Tested locally.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
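Since the MTP weights stay in bfloat16, the msmodelslim layer description inside `config.json` needs an explicit `FLOAT` entry for the MTP layer. The sketch below only illustrates how that edit could be scripted; the key name (index 61 for a 61-layer DeepSeek-V3/R1 checkpoint), the helper name, and the exact description format are assumptions and must be matched against the weight names actually present in your quantized checkpoint.

```python
# Hypothetical helper: mark the MTP layer as un-quantized ("FLOAT") in the
# quantization_config section of config.json. The key naming follows the
# msmodelslim weight description and may differ between checkpoints.
import json


def mark_mtp_layer_as_float(config_path: str, mtp_layer_idx: int) -> None:
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    quant_cfg = config.setdefault("quantization_config", {})
    # Illustrative entry only: DeepSeek-V3 has 61 transformer layers, so the
    # single MTP layer maps to index 61 in the merged weight layout.
    quant_cfg[f"model.layers.{mtp_layer_idx}"] = "FLOAT"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


if __name__ == "__main__":
    mark_mtp_layer_as_float("config.json", mtp_layer_idx=61)
```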
1 parent 5299829 commit 0be96f6

File tree

2 files changed (+185 lines, −0 lines)


vllm_ascend/models/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -13,3 +13,7 @@ def register_model():
     ModelRegistry.register_model(
         "DeepseekV3ForCausalLM",
         "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
+
+    ModelRegistry.register_model(
+        "DeepSeekMTPModel",
+        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

vllm_ascend/models/deepseek_mtp.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@ (new file)
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import \
    VocabParallelEmbedding
from vllm.model_executor.models.deepseek_mtp import (
    DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
    SharedHead)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata

from .deepseek_v2 import CustomDeepseekV2DecoderLayer


class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        nn.Module.__init__(self)
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
        self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                      model_config,
                                                      cache_config,
                                                      quant_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                    torch.zeros_like(inputs_embeds),
                                    inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(positions=positions,
                                                 hidden_states=hidden_states,
                                                 kv_cache=kv_cache,
                                                 attn_metadata=attn_metadata,
                                                 residual=None)
        hidden_states = residual + hidden_states
        return hidden_states


class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx): CustomDeepSeekMultiTokenPredictorLayer(
                config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        })

        # Note: torch._dynamo.exc.Unsupported: builtin: str
        self.layers_list = [
            self.layers[str(idx)]
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)
        ]
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            kv_caches[current_step_idx],
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        mtp_layer = self.layers_list[current_step_idx]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       mtp_layer.shared_head(hidden_states),
                                       sampling_metadata)
        return logits


class CustomDeepSeekMTP(DeepSeekMTP):
    # NOTE 1. The quantized MTP layer of deepseek on the NPU is not quantized;
    # NOTE 2. The description file generated by the current msmodelslim tool does not have
    # MTP layer info. Please manually add it and set the value to FLOAT.
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.model_config.hf_config
        self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
                                                       prefix=maybe_prefix(
                                                           prefix, "model"))

        self.sampler = get_sampler()
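For reference, a rough offline-inference sketch of exercising the registered `CustomDeepSeekMTP` draft model is given below. The checkpoint path is a placeholder, and the `quantization="ascend"`, `speculative_model`, and `num_speculative_tokens` arguments are assumptions based on the v0.7-era vLLM/vllm-ascend interfaces; verify them against the exact versions in use.

```python
# Rough sketch only: checkpoint path and engine arguments are assumptions
# and must be adapted to the vLLM / vllm-ascend version actually installed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/DeepSeek-R1-w8a8",               # placeholder w8a8 checkpoint
    quantization="ascend",                            # vllm-ascend w8a8 quant method
    speculative_model="/path/to/DeepSeek-R1-w8a8",    # MTP weights live in the same checkpoint
    num_speculative_tokens=1,                         # DeepSeek ships a single MTP layer
    trust_remote_code=True,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)
```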
