Commit 8359f4c

[V1][Speculative Decoding] Fix DeepSeek MTP (#20022)
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
1 parent bf51815 commit 8359f4c
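
In brief: each DeepSeek MTP speculative layer previously constructed its own `embed_tokens`, even though, per the DeepSeek-V3 technical report, the MTP modules share the embedding. This commit hoists the embedding to the top-level predictor module, loads the shared weights only once (from the first spec layer) while rewriting their checkpoint names to top level, and updates the V1 drafter in `eagle.py` to index the runner's per-KV-cache-group list of attention metadata builders.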

2 files changed: +21 −9 lines

vllm/model_executor/models/deepseek_mtp.py (20 additions, 8 deletions)

@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)

     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
             if name.endswith(".bias") and name not in params_dict:
                 continue

+            # According to the DeepSeek-V3 Technical Report, MTP modules
+            # share the embedding layer. We only load the first weights.
+            if (spec_layer != self.model.mtp_start_layer_idx
+                    and ".layers" not in name):
+                continue
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
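
To make the renaming rule concrete, here is a minimal standalone sketch of the logic `_rewrite_spec_layer_name` now implements, with illustrative checkpoint names (not taken from a real checkpoint); spec layer index 61 is used because DeepSeek-V3 ships its MTP module as `model.layers.61` after the 61 main decoder layers:

```python
# Minimal sketch of the post-commit renaming rule, assuming one MTP spec
# layer at index 61 (as in DeepSeek-V3 checkpoints).
def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    spec_layer_weight_names = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    shared_weight_names = ["embed_tokens"]
    spec_layer_weight = False
    shared_weight = False
    for weight_name in spec_layer_weight_names:
        if weight_name in name:
            spec_layer_weight = True
            if weight_name in shared_weight_names:
                shared_weight = True
            break
    if not spec_layer_weight:
        # Transformer-block weights move under .mtp_block.
        name = name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    elif shared_weight:
        # Shared weights (the embedding) become top level.
        name = name.replace(f"model.layers.{spec_layer}.", "model.")
    return name

# The shared embedding is renamed to top level:
print(rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight"))
# -> model.embed_tokens.weight

# Per-spec-layer modules like enorm keep their name:
print(rewrite_spec_layer_name(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight

# Everything else is treated as transformer-block weights:
print(rewrite_spec_layer_name(61, "model.layers.61.self_attn.q_proj.weight"))
# -> model.layers.61.mtp_block.self_attn.q_proj.weight
```

Together with the `load_weights` guard above, this means the embedding is materialized once on the predictor and loaded only from the first spec layer's weights; any later spec layers skip their (duplicate) shared weights.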

vllm/v1/spec_decode/eagle.py (1 addition, 1 deletion)

@@ -148,7 +148,7 @@ def propose(
         assert self.runner is not None

         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )
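
For context, the V1 runner keeps a list of attention metadata builders, one per KV cache group, which is presumably why the old singular `attn_metadata_builder` attribute no longer exists; per the in-source FIXME, only the first group is handled for now. A toy sketch of the access pattern, where `FakeBuilder` and `FakeRunner` are hypothetical stand-ins rather than vLLM classes:

```python
# Hypothetical stand-ins illustrating the shape of the fixed call site.
class FakeBuilder:
    def build(self, common_prefix_len, common_attn_metadata):
        # Stand-in for the real builder's build(...) method.
        return {"prefix_len": common_prefix_len, "common": common_attn_metadata}

class FakeRunner:
    def __init__(self):
        # One builder per KV cache group; the drafter uses group 0 only,
        # per the FIXME in the diff above.
        self.attn_metadata_builders = [FakeBuilder()]

runner = FakeRunner()
attn_metadata = runner.attn_metadata_builders[0].build(
    common_prefix_len=0,
    common_attn_metadata=None,
)
print(attn_metadata)  # {'prefix_len': 0, 'common': None}
```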
