 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -61,11 +60,6 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.routed_scaling_factor = config.routed_scaling_factor
-        if self.tp_size > config.n_routed_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.n_routed_experts}.")
 
         if config.hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -129,6 +123,7 @@ def __init__(
         # DecoderLayers are created with `make_layers` which passes the prefix
         # with the layer's index.
         layer_idx = int(prefix.split(sep='.')[-1])
+        self.layer_idx = layer_idx
         if model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
@@ -171,6 +166,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor
 
 
 class CustomDeepseekV2Model(nn.Module):
@@ -184,8 +180,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
 
-        self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
         if get_pp_group().is_first_rank:
@@ -223,8 +219,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -239,11 +233,8 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        for i in range(self.start_layer, self.end_layer):
-            layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+        for layer in self.layers[self.start_layer:self.end_layer]:
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -272,9 +263,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = CustomDeepseekV2Model(vllm_config=vllm_config,
                                            prefix=maybe_prefix(
                                                prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
+        if get_pp_group().is_last_rank:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size,
+                                          quant_config=quant_config)
+        else:
+            self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
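
For context on the forward-path change in this diff: after it, the model loop no longer threads per-layer kv_caches or AttentionMetadata; each decoder layer is called with only (positions, hidden_states, residual), and the attention layers obtain KV-cache/attention state internally in recent vLLM versions. Below is a minimal, self-contained sketch of that loop pattern in plain PyTorch; ToyDecoderLayer and run_pipeline_stage are illustrative names, not the vLLM implementation.

import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    """Illustrative stand-in for a decoder layer that needs no explicit kv_cache argument."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, positions, hidden_states, residual):
        # Mirrors the (positions, hidden_states, residual) calling convention
        # used by the simplified loop in this diff.
        if residual is None:
            residual = hidden_states
        hidden_states = self.norm(hidden_states + residual)
        residual = hidden_states
        hidden_states = self.proj(hidden_states)
        return hidden_states, residual


def run_pipeline_stage(layers, start_layer, end_layer, positions, hidden_states):
    # Only the slice of layers owned by this pipeline-parallel rank runs here,
    # matching `for layer in self.layers[self.start_layer:self.end_layer]`.
    residual = None
    for layer in layers[start_layer:end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)
    return hidden_states, residual


if __name__ == "__main__":
    layers = nn.ModuleList(ToyDecoderLayer(16) for _ in range(4))
    positions = torch.arange(3)
    hidden, residual = run_pipeline_stage(layers, 0, 2, positions, torch.randn(3, 16))
    print(hidden.shape, residual.shape)

Slicing the layer list keeps the loop body identical on every pipeline-parallel rank; only the start/end bounds differ per rank.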
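The last hunk gates lm_head creation on the pipeline-parallel rank: only the final rank builds a real ParallelLMHead, while earlier ranks install a PPMissingLayer placeholder. A minimal sketch of that gating pattern follows; DummyMissingLayer and build_lm_head are toy single-process stand-ins, not vLLM classes.

from torch import nn


class DummyMissingLayer(nn.Module):
    """Toy stand-in for a pipeline-parallel placeholder on non-final ranks."""

    def forward(self, *args, **kwargs):
        raise RuntimeError("lm_head only exists on the last pipeline-parallel rank")


def build_lm_head(is_last_rank: bool, vocab_size: int, hidden_size: int) -> nn.Module:
    # Mirrors the diff: a real head on the last rank, a placeholder elsewhere,
    # so attribute access and weight loading stay uniform across ranks.
    if is_last_rank:
        return nn.Linear(hidden_size, vocab_size, bias=False)
    return DummyMissingLayer()


if __name__ == "__main__":
    print(build_lm_head(True, vocab_size=32, hidden_size=8))   # real head
    print(build_lm_head(False, vocab_size=32, hidden_size=8))  # placeholder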