|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# Adapted from vllm/model_executor/models/deepseek_mtp.py |
| 4 | +# Copyright 2023 The vLLM team. |
| 5 | +# |
| 6 | +# This file is a part of the vllm-ascend project. |
| 7 | +# |
| 8 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 9 | +# you may not use this file except in compliance with the License. |
| 10 | +# You may obtain a copy of the License at |
| 11 | +# |
| 12 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 13 | +# |
| 14 | +# Unless required by applicable law or agreed to in writing, software |
| 15 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 16 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 17 | +# See the License for the specific language governing permissions and |
| 18 | +# limitations under the License. |
| 19 | + |
| 20 | +from typing import List, Optional |
| 21 | + |
| 22 | +import torch |
| 23 | +import torch.nn as nn |
| 24 | +from transformers import PretrainedConfig |
| 25 | +from vllm.attention.backends.abstract import AttentionMetadata |
| 26 | +from vllm.config import CacheConfig, ModelConfig, VllmConfig |
| 27 | +from vllm.model_executor.layers.layernorm import RMSNorm |
| 28 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 29 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 30 | +from vllm.model_executor.layers.sampler import get_sampler |
| 31 | +from vllm.model_executor.layers.vocab_parallel_embedding import \ |
| 32 | + VocabParallelEmbedding |
| 33 | +from vllm.model_executor.models.deepseek_mtp import ( |
| 34 | + DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, |
| 35 | + SharedHead) |
| 36 | +from vllm.model_executor.models.utils import maybe_prefix |
| 37 | +from vllm.model_executor.sampling_metadata import SamplingMetadata |
| 38 | + |
| 39 | +from .deepseek_v2 import CustomDeepseekV2DecoderLayer |
| 40 | + |
| 41 | + |
| 42 | +class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): |
| 43 | + |
| 44 | + def __init__( |
| 45 | + self, |
| 46 | + config: PretrainedConfig, |
| 47 | + prefix: str, |
| 48 | + model_config: ModelConfig, |
| 49 | + cache_config: Optional[CacheConfig] = None, |
| 50 | + quant_config: Optional[QuantizationConfig] = None, |
| 51 | + ) -> None: |
| 52 | + nn.Module.__init__(self) |
| 53 | + self.embed_tokens = VocabParallelEmbedding( |
| 54 | + config.vocab_size, |
| 55 | + config.hidden_size, |
| 56 | + ) |
| 57 | + |
| 58 | + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 59 | + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
| 60 | + self.eh_proj = nn.Linear(config.hidden_size * 2, |
| 61 | + config.hidden_size, |
| 62 | + bias=False) |
| 63 | + self.shared_head = SharedHead(config=config, quant_config=quant_config) |
| 64 | + self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix, |
| 65 | + model_config, |
| 66 | + cache_config, |
| 67 | + quant_config) |
| 68 | + |
| 69 | + def forward( |
| 70 | + self, |
| 71 | + input_ids: torch.Tensor, |
| 72 | + positions: torch.Tensor, |
| 73 | + kv_cache: torch.Tensor, |
| 74 | + attn_metadata: AttentionMetadata, |
| 75 | + previous_hidden_states: torch.Tensor, |
| 76 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 77 | + spec_step_index: int = 0, |
| 78 | + ) -> torch.Tensor: |
| 79 | + if inputs_embeds is None: |
| 80 | + inputs_embeds = self.embed_tokens(input_ids) |
| 81 | + assert inputs_embeds is not None |
| 82 | + # masking inputs at position 0, as not needed by MTP |
| 83 | + inputs_embeds = torch.where((positions == 0).unsqueeze(-1), |
| 84 | + torch.zeros_like(inputs_embeds), |
| 85 | + inputs_embeds) |
| 86 | + inputs_embeds = self.enorm(inputs_embeds) |
| 87 | + previous_hidden_states = self.hnorm(previous_hidden_states) |
| 88 | + |
| 89 | + hidden_states = self.eh_proj( |
| 90 | + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) |
| 91 | + |
| 92 | + hidden_states, residual = self.mtp_block(positions=positions, |
| 93 | + hidden_states=hidden_states, |
| 94 | + kv_cache=kv_cache, |
| 95 | + attn_metadata=attn_metadata, |
| 96 | + residual=None) |
| 97 | + hidden_states = residual + hidden_states |
| 98 | + return hidden_states |
| 99 | + |
| 100 | + |
| 101 | +class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): |
| 102 | + |
| 103 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 104 | + nn.Module.__init__(self) |
| 105 | + config = vllm_config.model_config.hf_config |
| 106 | + self.mtp_start_layer_idx = config.num_hidden_layers |
| 107 | + self.num_mtp_layers = config.num_nextn_predict_layers |
| 108 | + # to map the exact layer index from weights |
| 109 | + self.layers = torch.nn.ModuleDict({ |
| 110 | + str(idx): CustomDeepSeekMultiTokenPredictorLayer( |
| 111 | + config, |
| 112 | + f"{prefix}.layers.{idx}", |
| 113 | + model_config=vllm_config.model_config, |
| 114 | + cache_config=vllm_config.cache_config, |
| 115 | + quant_config=vllm_config.quant_config, |
| 116 | + ) |
| 117 | + for idx in range(self.mtp_start_layer_idx, |
| 118 | + self.mtp_start_layer_idx + self.num_mtp_layers) |
| 119 | + }) |
| 120 | + |
| 121 | + # Note: torch._dynamo.exc.Unsupported: builtin: str |
| 122 | + self.layers_list = [ |
| 123 | + self.layers[str(idx)] |
| 124 | + for idx in range(self.mtp_start_layer_idx, |
| 125 | + self.mtp_start_layer_idx + self.num_mtp_layers) |
| 126 | + ] |
| 127 | + self.logits_processor = LogitsProcessor(config.vocab_size) |
| 128 | + |
| 129 | + def forward( |
| 130 | + self, |
| 131 | + input_ids: torch.Tensor, |
| 132 | + positions: torch.Tensor, |
| 133 | + kv_caches: List[torch.Tensor], |
| 134 | + attn_metadata: AttentionMetadata, |
| 135 | + previous_hidden_states: torch.Tensor, |
| 136 | + inputs_embeds: Optional[torch.Tensor] = None, |
| 137 | + spec_step_idx: int = 0, |
| 138 | + ) -> torch.Tensor: |
| 139 | + current_step_idx = (spec_step_idx % self.num_mtp_layers) |
| 140 | + return self.layers_list[current_step_idx]( |
| 141 | + input_ids, |
| 142 | + positions, |
| 143 | + kv_caches[current_step_idx], |
| 144 | + attn_metadata, |
| 145 | + previous_hidden_states, |
| 146 | + inputs_embeds, |
| 147 | + current_step_idx, |
| 148 | + ) |
| 149 | + |
| 150 | + def compute_logits( |
| 151 | + self, |
| 152 | + hidden_states: torch.Tensor, |
| 153 | + sampling_metadata: SamplingMetadata, |
| 154 | + spec_step_idx: int = 0, |
| 155 | + ) -> torch.Tensor: |
| 156 | + current_step_idx = (spec_step_idx % self.num_mtp_layers) |
| 157 | + mtp_layer = self.layers_list[current_step_idx] |
| 158 | + logits = self.logits_processor(mtp_layer.shared_head.head, |
| 159 | + mtp_layer.shared_head(hidden_states), |
| 160 | + sampling_metadata) |
| 161 | + return logits |
| 162 | + |
| 163 | + |
| 164 | +class CustomDeepSeekMTP(DeepSeekMTP): |
| 165 | + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; |
| 166 | + # NOTE 2.The description file generated by the current msmodelslim tool does not have |
| 167 | + # MTP layer info. Please manually add it and set the value to FLOAT. |
| 168 | + packed_modules_mapping = { |
| 169 | + "gate_up_proj": ["gate_proj", "up_proj"], |
| 170 | + "experts": |
| 171 | + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] |
| 172 | + } |
| 173 | + |
| 174 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 175 | + nn.Module.__init__(self) |
| 176 | + self.config = vllm_config.model_config.hf_config |
| 177 | + self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config, |
| 178 | + prefix=maybe_prefix( |
| 179 | + prefix, "model")) |
| 180 | + |
| 181 | + self.sampler = get_sampler() |
0 commit comments