import torch_npu
from torch import nn
from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile

from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel
from vllm.config import CacheConfig, VllmConfig
@@ -22,6 +21,7 @@
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.compilation.decorators import support_torch_compile

from vllm_ascend.multistream.context import (
    advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -35,6 +35,7 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp
from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region
import vllm_ascend.envs as envs_ascend
+from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO

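Note: the import hunks above relocate support_torch_compile and pull in CustomQwen3MoeForCausalLM from vllm_ascend. For context, below is a minimal sketch of how that decorator is normally attached to a model class in vLLM; the class is a placeholder, not the one defined in this file, and the exact requirements of the decorator may differ across vLLM versions.

import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ExampleMoeModel(nn.Module):  # placeholder name, for illustration only
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.prefix = prefix

    # Tensor-annotated parameters are what the decorator inspects to mark
    # dynamic shapes when torch.compile is enabled.
    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return input_ids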
@@ -197,7 +198,7 @@ def _forward_op_grouped_mlp(
            self, dispatched_input, tokens_per_expert
    ):
        return apply_mlp(
-            [dispatched_input],
+            dispatched_input,
            self.mlp.experts.w13_weight,
            self.mlp.experts.w2_weight,
            tokens_per_expert
@@ -207,8 +208,7 @@ def _forward_combine_comm(
            self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes
    ):
        token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id]
-        token_dispatcher.combine_alltoall()
-        final_hidden_states = token_dispatcher.unpermute2()
+        final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states)
        if hasattr(self.mlp, 'routed_scaling_factor'):
            final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor

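Note: the replacement line above assumes the token dispatcher now exposes a single token_unpermutation() call that fuses the combine all-to-all with the un-permutation and returns a (hidden_states, bias) pair. A toy, purely illustrative dispatcher with that shape of interface (not the vllm_ascend implementation) could look like this:

import torch


class ToyTokenDispatcher:
    def __init__(self, permutation: torch.Tensor):
        self.permutation = permutation            # mapping built during dispatch

    def token_unpermutation(self, hidden_states: torch.Tensor):
        # 1) the combine all-to-all would run here (omitted: single process)
        # 2) restore the original token order
        restored = torch.empty_like(hidden_states)
        restored[self.permutation] = hidden_states
        return restored, None                     # (final_hidden_states, bias)


# usage sketch
d = ToyTokenDispatcher(torch.tensor([2, 0, 1]))
out, _ = d.token_unpermutation(torch.arange(3.0).unsqueeze(-1))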
@@ -267,13 +267,10 @@ def discard_tensor(tensor):
        # communication in the previous layer, and the attn computation of microbatch 2
        # can be overlapped with the attn communication of microbatch 1
        for i in range(num_micro_batchs):
-            # wait last layer moe finishing communication
-            ms_metadata.try_wait_event(layer_index - 1, i,
-                                       MSEventKey.MOE_AFTER_COMM)
-
            forward_context = get_forward_context()
            layer_index, ms_metadata, attn_metadata = get_multistream_layer_context(
            )
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
            forward_context.attn_metadata = attn_metadata[i]

            # input layernorm
@@ -309,36 +306,25 @@ def discard_tensor(tensor):
            with torch.npu.stream(dispatch_context.comm_stream):
                dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event)
                token_dispatchers[i].dispatch_alltoall()
+                dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
                dispatch_context.after_comm_event.record()

-            if has_shared_expert:
-                token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce(
-                    token_dispatchers[i].cached_shared_expert_output
-                )
-                ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record()
-
        # print_with_sync('begin experts...', torch.distributed.get_rank())
        # block 4 : Router Experts Computation
        # block 5 : Token Combine Communication
        for i in range(num_micro_batchs):
-
            ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM)
            discard_tensor(hidden_states[i])
-
-            dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2()
            router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i])
            discard_tensor(dispatched_input[i])
-            token_dispatchers[i].unpermute1(router_expert_output[i])
-            if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1:
-                discard_tensor(router_expert_output[i])

            # Launch Combine Comm in a New Stream.
            combine_context = MultiStreamStepMetadata(
                comm_stream=ms_metadata.communicate_stream,
                before_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_BEFORE_COMM],
+                    MSEventKey.FFN_COM_FINISH],
                after_comm_event=ms_metadata.ms_events[layer_index][i][
-                    MSEventKey.MOE_AFTER_COMM],
+                    MSEventKey.FFN_AR_FINISH],
            )
            combine_context.before_comm_event.record()
            ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH)
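Note: the two loops above implement the dual-batch-overlap pattern: the token dispatch and combine all-to-alls are issued on a side communication stream and fenced with events (MOE_AFTER_COMM, FFN_COM_FINISH, FFN_AR_FINISH), so expert computation of one micro-batch overlaps the communication of the other. A generic, self-contained sketch of that pattern follows, written with CUDA streams and events as stand-ins (the diff uses the NPU counterparts from torch_npu); all names are illustrative, not the vllm_ascend API.

import torch


def overlapped_moe_step(micro_batches, dispatch_fn, expert_fn, combine_fn):
    # One side stream for communication, events to order work across streams.
    main = torch.cuda.current_stream()
    comm = torch.cuda.Stream()
    n = len(micro_batches)
    inputs_ready = [torch.cuda.Event() for _ in range(n)]    # ~ before_comm_event
    dispatch_done = [torch.cuda.Event() for _ in range(n)]   # ~ MOE_AFTER_COMM
    combine_done = [torch.cuda.Event() for _ in range(n)]    # ~ FFN_AR_FINISH
    dispatched = [None] * n
    outputs = [None] * n

    # Block 3: launch token dispatch for every micro-batch on the comm stream.
    for i, batch in enumerate(micro_batches):
        inputs_ready[i].record(main)
        with torch.cuda.stream(comm):
            comm.wait_event(inputs_ready[i])
            dispatched[i] = dispatch_fn(batch)      # all-to-all + permute
            dispatch_done[i].record(comm)

    # Blocks 4-5: expert MLP on the main stream, combine on the comm stream.
    for i in range(n):
        main.wait_event(dispatch_done[i])
        outputs[i] = expert_fn(dispatched[i])       # grouped expert computation
        ready = torch.cuda.Event()
        ready.record(main)
        with torch.cuda.stream(comm):
            comm.wait_event(ready)
            outputs[i] = combine_fn(outputs[i])     # all-to-all back + unpermute
            combine_done[i].record(comm)

    # The consumer (here: the next decoder layer) waits on combine_done,
    # mirroring the FFN_AR_FINISH waits added by this diff.
    for ev in combine_done:
        main.wait_event(ev)
    return outputs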
@@ -347,7 +333,7 @@ def discard_tensor(tensor):
            hidden_states[i] = self._forward_combine_comm(
                router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i]
            )
-            combine_context.after_comm_event.record()
+            ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH] = combine_context.comm_stream.record_event()

        return hidden_states, residual

@@ -443,11 +429,10 @@ def forward(
    def can_run_ms(self):
        attn_metadata = get_forward_context().attn_metadata
        # enable prefill overlap
-        with_prefill = getattr(attn_metadata, "with_prefill_across_dp", False)
+        with_prefill = get_forward_context().with_prefill
        if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp:
            return False
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata)
+
        return True

    def _forward_ms_layers(
@@ -465,9 +450,7 @@ def _forward_ms_layers(
        attn_metadata, [positions, hidden_states,
                        residual] = self.ms_pre_layer(
            [positions, hidden_states, residual], )
-        # if torch.distributed.get_rank() == 0:
-        #     print(attn_metadata[0], attn_metadata[1])
-        #     exit()
+        num_micro_batch = len(attn_metadata)
        # the rest layers
        for i in range(moe_start_layer, self.end_layer):
            layer = self.layers[i]
@@ -481,6 +464,11 @@ def _forward_ms_layers(
            )
            advance_step_multistream_layer_context()

+        layer_index, ms_metadata, attn_metadata = get_multistream_layer_context()
+        for i in range(num_micro_batch):
+            ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH)
+
+
        [hidden_states,
         residual] = self.ms_post_layer([hidden_states, residual], )
        return hidden_states, residual
@@ -517,17 +505,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
+
+    def forward(self, *args, **kwargs):
+        if "graph_enable" in kwargs:
+            kwargs.pop('graph_enable')
+        return super().forward(*args, **kwargs)

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        graph_enable: Optional[bool] = True
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states

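Note: the new forward() above is a thin compatibility shim: it drops a graph_enable flag that some call sites may still pass and delegates everything else unchanged to the parent Qwen3MoeForCausalLM.forward. A minimal, self-contained illustration of that kwargs-filtering pattern (class names here are placeholders, not the diff's classes):

class Base:  # stands in for the upstream forward signature
    def forward(self, x, scale=1.0):
        return x * scale


class Compat(Base):  # stands in for the shim added above
    def forward(self, *args, **kwargs):
        kwargs.pop("graph_enable", None)   # tolerate callers that still pass it
        return super().forward(*args, **kwargs)


assert Compat().forward(3, scale=2.0, graph_enable=True) == 6.0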