From d8a6af759aef80c755929407ac4a3f6861ef23ca Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Tue, 24 Jun 2025 13:53:40 +0000
Subject: [PATCH 1/5] add missing token embedding in
 `DeepSeekMultiTokenPredictor`

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
---
 vllm/model_executor/models/deepseek_mtp.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 6e6e74b0d1d9..60fe3f637462 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -112,7 +112,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
     def forward(

From e54a85a6d9d3f8219b7aec900817e0b89f810972 Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Tue, 24 Jun 2025 13:56:25 +0000
Subject: [PATCH 2/5] fix `EagleProposer` for deepseek MTP single kv cache
 group case

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
---
 vllm/v1/spec_decode/eagle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 153b67fe5714..156f5764e8dc 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -148,7 +148,7 @@ def propose(
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
             common_prefix_len=0,
             common_attn_metadata=common_attn_metadata,
         )

From 07e68f4dd8b7b1c94a333550ca2005c5a3093a0d Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Wed, 25 Jun 2025 00:54:52 +0000
Subject: [PATCH 3/5] mv vocab embedding from MTP layers to MTP module

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
---
 vllm/model_executor/models/deepseek_mtp.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 60fe3f637462..51de15aecb30 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -126,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -258,7 +253,7 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         Add .mtp_block for modules in transformer layer block for spec layer
         """
         spec_layer_weight_names = [
-            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
+            "enorm", "hnorm", "eh_proj", "shared_head"
         ]
         spec_layer_weight = False
         for weight_name in spec_layer_weight_names:

From ba01c4ac269928dd1150caa7f1d584b5b2151930 Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:46:40 +0900
Subject: [PATCH 4/5] refactor MTP weight loading logic

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
---
 vllm/model_executor/models/deepseek_mtp.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 51de15aecb30..c1bd3fcac048 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -240,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
+                # According to the DeepSeek-V3 Technical Report, MTP modules
+                # share the embedding layer. We only load the first copy.
+                if (spec_layer != self.model.mtp_start_layer_idx and
+                        ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -251,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
-            "enorm", "hnorm", "eh_proj", "shared_head"
+            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name

From 949f1bc7261f9b9d53094edaf9b95ba9325ff107 Mon Sep 17 00:00:00 2001
From: cjackal <44624812+cjackal@users.noreply.github.com>
Date: Wed, 25 Jun 2025 15:00:27 +0900
Subject: [PATCH 5/5] make ruff happy

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
---
 vllm/model_executor/models/deepseek_mtp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index c1bd3fcac048..911f0036c2dd 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -242,8 +242,8 @@ def load_weights(self, weights: Iterable[tuple[str,
 
                 # According to the DeepSeek-V3 Technical Report, MTP modules
                 # share the embedding layer. We only load the first copy.
-                if (spec_layer != self.model.mtp_start_layer_idx and
-                        ".layers" not in name):
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
                     continue
 
                 param = params_dict[name]
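
Note on the weight-name rewriting (an illustrative sketch, not part of the
applied series): the rules PATCH 4 installs in `_rewrite_spec_layer_name`
can be exercised in isolation. The standalone Python below mirrors the
patched logic; the module-level constant names, the example checkpoint
names, and the spec layer index 61 (DeepSeek-V3's single MTP layer) are
assumptions for illustration only.

    # Sketch of the PATCH 4 rewrite rules for checkpoint names of the form
    # "model.layers.{spec_layer}.<...>":
    #   - shared weights (embed_tokens) move to the top-level module,
    #   - other spec-layer weights (enorm, hnorm, eh_proj, shared_head)
    #     keep their name,
    #   - everything else gains a ".mtp_block" segment.
    SPEC_LAYER_WEIGHT_NAMES = [
        "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
    ]
    SHARED_WEIGHT_NAMES = ["embed_tokens"]

    def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
        spec_layer_weight = False
        shared_weight = False
        for weight_name in SPEC_LAYER_WEIGHT_NAMES:
            if weight_name in name:
                spec_layer_weight = True
                shared_weight = weight_name in SHARED_WEIGHT_NAMES
                break
        if not spec_layer_weight:
            # weights of the wrapped transformer block go under .mtp_block
            return name.replace(f"model.layers.{spec_layer}.",
                                f"model.layers.{spec_layer}.mtp_block.")
        if shared_weight:
            # shared weights become top-level module weights
            return name.replace(f"model.layers.{spec_layer}.", "model.")
        return name

    # Hypothetical checkpoint names, spec layer 61:
    assert (rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight")
            == "model.embed_tokens.weight")
    assert (rewrite_spec_layer_name(61, "model.layers.61.enorm.weight")
            == "model.layers.61.enorm.weight")
    assert (rewrite_spec_layer_name(
        61, "model.layers.61.self_attn.qkv_proj.weight"
    ) == "model.layers.61.mtp_block.self_attn.qkv_proj.weight")

Combined with the load-time guard from PATCH 4/5 (skip any weight whose
rewritten name lacks ".layers" unless it belongs to the first MTP layer),
every MTP step ends up reading the single top-level `embed_tokens` copy.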