@@ -36,7 +36,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.forward_context import get_forward_context
@@ -205,17 +205,16 @@ def __init__(
         )
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
 
-        vllm_config = get_current_vllm_config()
-        self.dp_size = get_dp_group().world_size
-        batch_size = vllm_config.scheduler_config.max_num_seqs
-
-        params_dtype = torch.get_default_dtype()
-        self.final_hidden_states = torch.zeros(
-            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
+        self.params_dtype = torch.get_default_dtype()
+        self.tp_rank_in_group = get_tp_group().rank_in_group
         self.tp_group = get_tp_group().device_group
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         if attn_metadata is None:
             # for profile run
             is_prefill = True
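
The new signature keeps existing call sites working: a caller that already holds the attention metadata (such as the decoder layer changed below) can pass it in and skip a per-layer context lookup, while callers that omit it fall back to the ambient forward context. Not part of the diff, the following is a minimal stand-alone sketch of that fallback pattern; get_context and _FORWARD_CONTEXT are hypothetical stand-ins, not vLLM APIs.

    from typing import Any, Optional

    # Hypothetical stand-in for the ambient forward context.
    _FORWARD_CONTEXT: dict = {"attn_metadata": None}

    def get_context() -> dict:
        # Stand-in for vllm.forward_context.get_forward_context()
        return _FORWARD_CONTEXT

    def moe_forward(x: float, attn_metadata: Optional[Any] = None) -> float:
        # An explicitly passed argument wins; the context is only the fallback.
        if attn_metadata is None:
            attn_metadata = get_context()["attn_metadata"]
        if attn_metadata is None:
            return x  # still None: profile run, take the conservative branch
        return x + 1.0
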
@@ -224,34 +223,36 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+
         if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
-            chunks = torch.chunk(hidden_states,
-                                 get_tp_group().world_size,
-                                 dim=0)
-            hidden_states = chunks[get_tp_group().rank_in_group]
+            chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+            hidden_states = chunks[self.tp_rank_in_group]
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        final_hidden_states = self.experts(
+        hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
 
         if self.tp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
-                dist.all_gather_into_tensor(self.final_hidden_states,
-                                            final_hidden_states, self.tp_group)
-                final_hidden_states = self.final_hidden_states
+                final_hidden_states = torch.zeros([num_tokens, hidden_dim],
+                                                  dtype=self.params_dtype,
+                                                  device="npu")
+                dist.all_gather_into_tensor(final_hidden_states, hidden_states,
+                                            self.tp_group)
+                hidden_states = final_hidden_states
             else:
-                final_hidden_states = tensor_model_parallel_all_reduce(
-                    final_hidden_states)
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
         if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_dim)
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
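
Two things change in the MC2 path: the shared-expert output is computed on the full batch before the tokens are sliced per rank, and the gather target is allocated per call at exactly [num_tokens, hidden_dim] instead of living as a persistent buffer sized by max_num_seqs. Not part of the diff, the following single-process sketch shows the shape contract that makes the per-call buffer correct, using torch.cat as a stand-in for dist.all_gather_into_tensor (so no process group is needed).

    import torch

    tp_size, num_tokens, hidden_dim = 4, 8, 16
    hidden_states = torch.randn(num_tokens, hidden_dim)

    # Each rank keeps one equal slice of the token dimension...
    chunks = torch.chunk(hidden_states, tp_size, dim=0)
    assert all(c.shape[0] == num_tokens // tp_size for c in chunks)

    # ...and the gather concatenates the rank slices back in rank order, so
    # the output buffer must be exactly [num_tokens, hidden_dim].
    gathered = torch.zeros(num_tokens, hidden_dim)
    torch.cat(chunks, dim=0, out=gathered)  # stand-in for all_gather_into_tensor
    assert torch.equal(gathered, hidden_states)
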
@@ -524,7 +525,11 @@ def forward(
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)
 
         if isinstance(
             self.mlp,
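
Only the MoE variant of the MLP accepts the extra metadata argument, so the decoder layer dispatches on the module type and dense layers keep their one-argument call. Not part of the diff, the following is a minimal sketch of that dispatch with hypothetical stand-in modules.

    import torch
    from torch import nn

    class DenseMLP(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    class MoEMLP(nn.Module):
        def forward(self, x: torch.Tensor, attn_metadata=None) -> torch.Tensor:
            return x + 1

    def run_mlp(mlp: nn.Module, x: torch.Tensor, attn_metadata=None):
        if isinstance(mlp, MoEMLP):
            return mlp(x, attn_metadata)  # MoE path threads the metadata through
        return mlp(x)                     # dense path keeps the old signature

    assert torch.equal(run_mlp(DenseMLP(), torch.ones(2)), torch.full((2,), 2.0))
    assert torch.equal(run_mlp(MoEMLP(), torch.ones(2)), torch.full((2,), 2.0))
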