
[V1][Speculative Decoding] Fix DeepSeek MTP #20022


Merged: 5 commits, Jun 25, 2025
28 changes: 20 additions & 8 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
Comment on lines +108 to +111
Contributor (medium):
Adding self.embed_tokens here seems redundant, as it's already defined in the DeepSeekMultiTokenPredictorLayer class. Consider if this is truly necessary or if it can be removed to avoid duplication.

Collaborator:

Is this related to the bug? Does the MTP module have a separate vocab embedding?

Contributor Author:

The missing vocab embedding module is what raises the 'DeepSeekMultiTokenPredictor' object has no attribute 'embed_tokens' error, i.e. the first traceback in the linked issue.

Architecture-wise, all the vocab embeddings have the same shape as the target model's, but we do need to keep a vocab embedding for each MTP layer if the target model was trained with multiple MTP layers (not the case for the official DeepSeek R1/V3 families, though) and the user launches the server with PP > 1. There is a similar condition check in the EAGLE weight-loading step (sketched below).
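For context, the condition check being referred to boils down to something like the following sketch. The helper name maybe_share_embed_tokens is made up for illustration; the real logic lives in vLLM's EAGLE draft-model loading path and may differ in detail.

```python
import torch.nn as nn

from vllm.distributed import get_pp_group  # pipeline-parallel process group


def maybe_share_embed_tokens(draft_model: nn.Module,
                             target_model: nn.Module) -> None:
    """Hypothetical helper: reuse the target model's vocab embedding for the
    draft (MTP) model only when both live on the same pipeline stage."""
    if get_pp_group().world_size == 1:
        # PP == 1: target and draft share a stage, so the embedding table can
        # simply be aliased instead of allocating a second copy.
        draft_model.model.embed_tokens = target_model.model.embed_tokens
    # With PP > 1 the target's embedding may sit on a different stage, so the
    # draft keeps and loads its own embed_tokens weights from the checkpoint.
```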

Collaborator:

@cjackal Thanks for the explanation! Can we use the target model's embedding when PP=1 and only allocate the weights when PP > 1?

Contributor Author (cjackal, Jun 25, 2025):

@WoosukKwon Another thought after the discussion: the sole purpose of speculative decoding is to leverage a small draft model for faster generation, so running pipeline parallelism over the draft model weights is rare and somewhat contradictory.

We could simply assume that all the MTP layers sit on the same (last) pipeline stage and always share the MTP layers' vocab embedding with the target model's. NVM, even if the MTP module is not split, there is no guarantee that the target model's embedding is on the same stage. Let me just move the vocab embedding from DeepSeekMultiTokenPredictorLayer to DeepSeekMultiTokenPredictor so it is shared among the MTP layers (see the condensed sketch below), and leave sharing between the target and draft embeddings to the EAGLE draft-model loading stage I linked before.
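For readers skimming the diff, here is a condensed, self-contained sketch of the layout this change produces. The class and argument names are simplified for illustration; the real modules in vllm/model_executor/models/deepseek_mtp.py use VocabParallelEmbedding and carry vLLM config and quantization plumbing.

```python
import torch
import torch.nn as nn


class MTPLayer(nn.Module):
    """One MTP step; after this change it no longer owns a vocab embedding."""

    def forward(self, inputs_embeds: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        assert inputs_embeds is not None  # the caller is responsible for embedding
        inputs_embeds = inputs_embeds.clone()
        inputs_embeds[positions == 0] = 0  # mask position 0, as in the real layer
        return inputs_embeds  # stand-in for enorm/hnorm/eh_proj/mtp_block


class MTPredictor(nn.Module):
    """Owns one shared embedding table and dispatches to the right MTP layer."""

    def __init__(self, vocab_size: int, hidden_size: int, num_mtp_layers: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)  # shared
        self.num_mtp_layers = num_mtp_layers
        self.layers = nn.ModuleDict(
            {str(i): MTPLayer() for i in range(num_mtp_layers)})

    def forward(self, input_ids, positions, inputs_embeds=None, spec_step_idx=0):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)  # embed once, up front
        layer = self.layers[str(spec_step_idx % self.num_mtp_layers)]
        return layer(inputs_embeds, positions)


# Toy usage: hidden_size shrunk for illustration.
predictor = MTPredictor(vocab_size=1000, hidden_size=8, num_mtp_layers=1)
ids = torch.tensor([1, 2, 3])
pos = torch.tensor([0, 1, 2])
out = predictor(ids, pos)  # embeds once, dispatches to MTP layer 0
```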

Collaborator (WoosukKwon, Jun 25, 2025):

@cjackal Got it. Thanks! Could you please re-run the test locally?

Contributor Author:

I need to refactor the draft weight-loading part a bit more; I'll ping you again when it's ready.

Collaborator:

@cjackal Got it. Thanks!

         self.logits_processor = LogitsProcessor(config.vocab_size)
 
     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
+                # According to the DeepSeek-V3 technical report, the MTP
+                # modules share the embedding layer, so we only load the
+                # weights of the first MTP layer.
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
             "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
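To make the renaming behaviour concrete, here is a small illustration of how checkpoint weight names map under the rules above. It condenses _rewrite_spec_layer_name (the original tracks the shared flag inside the loop; this version gives the same results), and the layer index 61 and the particular weight names are assumptions chosen for the example, not taken from this diff.

```python
# Illustrative only: mimics _rewrite_spec_layer_name for spec_layer = 61.
def rewrite(spec_layer: int, name: str) -> str:
    spec = ["embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"]
    shared = ["embed_tokens"]
    if not any(w in name for w in spec):
        # transformer-block weights gain the .mtp_block prefix
        return name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    if any(w in name for w in shared):
        # shared weights (the vocab embedding) become top-level
        return name.replace(f"model.layers.{spec_layer}.", "model.")
    return name


print(rewrite(61, "model.layers.61.embed_tokens.weight"))
# -> model.embed_tokens.weight  (loaded once, shared by all MTP layers)
print(rewrite(61, "model.layers.61.self_attn.q_a_proj.weight"))
# -> model.layers.61.mtp_block.self_attn.q_a_proj.weight
print(rewrite(61, "model.layers.61.enorm.weight"))
# -> model.layers.61.enorm.weight  (spec-layer weight, name unchanged)
```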
2 changes: 1 addition & 1 deletion vllm/v1/spec_decode/eagle.py
@@ -148,7 +148,7 @@ def propose(
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )