import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -47,6 +48,8 @@ def fused_experts_with_mc2(
    top_k: int,
    expert_map: torch.Tensor = None,
    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
+    graph_mode: bool = False,
) -> torch.Tensor:
    global_bs = 0
    moe_expert_num = len(expert_map)
@@ -88,6 +91,10 @@ def fused_experts_with_mc2(
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
    w1 = w1.transpose(1, 2)
    expert_token_nums = torch.cumsum(expert_token_nums,
                                     dim=0,
@@ -102,6 +109,11 @@ def fused_experts_with_mc2(
        group_list=group_list,
    )

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_gate_up, expand_x, enabled=graph_mode)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
@@ -145,7 +157,15 @@ def fused_experts_with_mc2(

    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)

-    return hidden_states
+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
+            npu_wait_tensor(shared_act, down_out_list, enabled=graph_mode)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
+    if shared_experts is None:
+        return hidden_states
+    else:
+        return hidden_states, shared_hidden_states


# currently expert parallelism implemented with all2all
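Note: the three added blocks interleave the shared-expert MLP with the routed MC2 pipeline — `gate_up_proj` overlaps processing of the dispatch output, `act_fn` overlaps the grouped matmul, and `down_proj` overlaps the combine. The condensed sketch below is not part of the patch; the function name and tensor arguments are placeholders, and the semantics of `npu_stream_switch`/`npu_wait_tensor` (a context manager selecting a named secondary stream, and a cross-stream ordering barrier, both no-ops when disabled) are assumptions inferred only from how the helpers are used in this diff.

```python
# Hedged sketch of the dual-stream overlap pattern used above.
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def overlapped_shared_mlp(shared_experts: torch.nn.Module,
                          hidden_states: torch.Tensor,
                          dispatch_out: torch.Tensor,
                          routed_down_out: torch.Tensor,
                          graph_mode: bool = False) -> torch.Tensor:
    # Stage 1: shared gate/up projection on the side stream, overlapping the
    # routed experts' dispatch on the default stream.
    with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)

    # Stage 2: once the dispatched tokens exist, run the shared activation
    # while the routed grouped matmul proceeds.
    with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
        npu_wait_tensor(shared_gate_up, dispatch_out, enabled=graph_mode)
        shared_act = shared_experts.act_fn(shared_gate_up)

    # Stage 3: shared down projection, overlapping the routed combine.
    with npu_stream_switch("expert_secondary", 0, enabled=graph_mode):
        npu_wait_tensor(shared_act, routed_down_out, enabled=graph_mode)
        shared_hidden_states, _ = shared_experts.down_proj(shared_act)
    return shared_hidden_states
```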
@@ -587,6 +607,8 @@ def __init__(self, moe: MoEConfig = None):
        self.ep_size = ep_group.world_size
        self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.local_batch_size = self.global_batch_size // self.ep_size
+        self.graph_mode = vllm_config.get("additional_config",
+                                          {}).get("enable_graph_mode", False)

        try:
            device_group = ep_group.device_group
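Note: `self.graph_mode` gates the secondary-stream calls above and is read from the engine's `additional_config`. As a rough usage sketch, assuming a vLLM build whose engine arguments accept `additional_config` (the mechanism vllm-ascend uses for platform-specific switches); the model name is purely illustrative:

```python
from vllm import LLM

# Hypothetical launch that would make the additional_config lookup in this
# diff return True for "enable_graph_mode"; the key name comes from the diff.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          additional_config={"enable_graph_mode": True})
```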
@@ -624,6 +646,7 @@ def apply(
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +687,37 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+                graph_mode=self.graph_mode,
+            )
        elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
        else:
            # The current implementation of deepseek moe splits hidden_states
            # according to tp_size before they are feed into fused_moe module.
            # Therefore, all2all is needed no matter how dp/tp is set so as to
            # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)


class AscendFusedMoE(FusedMoE):
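Note: with `shared_experts` threaded through, `apply()` now has a dual return contract — a single routed-output tensor when no shared experts are supplied, or a `(routed, shared)` pair otherwise (the non-MC2 branches simply run `shared_experts(x)` on the default stream). The caller-side sketch below is hedged: summing the two outputs is an assumption about how a DeepSeek-style MoE layer would consume them, not code from this PR.

```python
from typing import Tuple, Union

import torch


def combine_moe_outputs(
        out: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
) -> torch.Tensor:
    # apply() returns either the routed output alone, or a (routed, shared)
    # pair when shared_experts was passed in.
    if isinstance(out, tuple):
        routed_out, shared_out = out
        # Assumed combination: shared-expert output added to the routed
        # output; any scaling/residual logic lives in the model code.
        return routed_out + shared_out
    return out
```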
@@ -815,7 +847,8 @@ def forward(self,
                router_logits: torch.Tensor,
                is_prefill: bool,
                enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
        assert self.quant_method is not None

        if top_k:
@@ -842,7 +875,9 @@ def forward(self,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill,
-            enable_force_load_balance=enable_force_load_balance)
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=shared_experts,
+        )

        if VLLM_ENABLE_MC2 and not is_prefill:
...