diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 6e6e74b0d1d9..911f0036c2dd 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -52,11 +52,6 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2,
@@ -74,8 +69,6 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
     ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
         inputs_embeds[positions == 0] = 0
@@ -112,7 +105,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             for idx in range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
         })
-
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.logits_processor = LogitsProcessor(config.vocab_size)
 
     def forward(
@@ -123,6 +119,8 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
         return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
             input_ids,
@@ -242,6 +240,12 @@ def load_weights(self, weights: Iterable[tuple[str,
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
+                # According to the DeepSeek-V3 Technical Report, MTP modules
+                # share the embedding layer, so we only load the first copy.
+                if (spec_layer != self.model.mtp_start_layer_idx
+                        and ".layers" not in name):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -253,17 +257,25 @@ def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
         """
         Rewrite the weight name to match the format of the original model.
         Add .mtp_block for modules in transformer layer block for spec layer
+        and rename shared layer weights to be top level.
         """
         spec_layer_weight_names = [
            "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
         ]
+        shared_weight_names = ["embed_tokens"]
         spec_layer_weight = False
+        shared_weight = False
         for weight_name in spec_layer_weight_names:
             if weight_name in name:
                 spec_layer_weight = True
+                if weight_name in shared_weight_names:
+                    shared_weight = True
                 break
         if not spec_layer_weight:
             # treat rest weights as weights for transformer layer block
             name = name.replace(f"model.layers.{spec_layer}.",
                                 f"model.layers.{spec_layer}.mtp_block.")
+        elif shared_weight:
+            # treat shared weights as top level weights
+            name = name.replace(f"model.layers.{spec_layer}.", "model.")
         return name
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 153b67fe5714..156f5764e8dc 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -148,7 +148,7 @@ def propose(
         assert self.runner is not None
 
         # FIXME: need to consider multiple kv_cache_groups
-        attn_metadata = self.runner.attn_metadata_builder.build(
+        attn_metadata = self.runner.attn_metadata_builders[0].build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
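# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the patch above): a
# standalone re-implementation of the rename rules that the updated
# _rewrite_spec_layer_name applies, to show how shared weights are hoisted to
# the top level while transformer-block weights gain the .mtp_block infix.
# The spec-layer index 61 and the sample weight names are made-up examples.
# ---------------------------------------------------------------------------

SPEC_LAYER_WEIGHT_NAMES = ["embed_tokens", "enorm", "hnorm", "eh_proj",
                           "shared_head"]
SHARED_WEIGHT_NAMES = ["embed_tokens"]


def rewrite_spec_layer_name(spec_layer: int, name: str) -> str:
    """Shared weights (embed_tokens) become top-level weights, other
    spec-layer modules keep their names, and everything else is nested
    under .mtp_block."""
    spec_layer_weight = any(w in name for w in SPEC_LAYER_WEIGHT_NAMES)
    shared_weight = any(w in name for w in SHARED_WEIGHT_NAMES)
    if not spec_layer_weight:
        # transformer layer block weights gain the .mtp_block infix
        return name.replace(f"model.layers.{spec_layer}.",
                            f"model.layers.{spec_layer}.mtp_block.")
    if shared_weight:
        # shared weights are renamed to top level and loaded only once
        return name.replace(f"model.layers.{spec_layer}.", "model.")
    return name


if __name__ == "__main__":
    assert (rewrite_spec_layer_name(61, "model.layers.61.embed_tokens.weight")
            == "model.embed_tokens.weight")
    assert (rewrite_spec_layer_name(61, "model.layers.61.enorm.weight")
            == "model.layers.61.enorm.weight")
    assert (rewrite_spec_layer_name(
        61, "model.layers.61.self_attn.q_proj.weight")
            == "model.layers.61.mtp_block.self_attn.q_proj.weight")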