 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
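The new import brings in two helpers from vllm_ascend.utils that the rest of this diff relies on: npu_stream_switch(name) runs the enclosed ops on a named secondary NPU stream, and npu_wait_tensor(a, b) makes the use of `a` wait until `b` has been produced. A minimal sketch of that pattern, assuming exactly those semantics (the helper overlapped_matmul is hypothetical, not part of this change):

import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def overlapped_matmul(a: torch.Tensor, b: torch.Tensor,
                      dep: torch.Tensor) -> torch.Tensor:
    # Run the matmul on the "expert_secondary" stream, but only after `dep`
    # (produced on the default stream) is ready, so the two streams overlap.
    with npu_stream_switch("expert_secondary"):
        npu_wait_tensor(a, dep)
        return torch.matmul(a, b)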
@@ -47,6 +48,7 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[torch.nn.Module] = None,
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
@@ -83,11 +85,20 @@ def fused_experts_with_mc2(
     }
     kwargs.update(stage1_kwargs)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
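The shared_experts module threaded through this path is only touched via three attributes: gate_up_proj and down_proj (each returning an (output, extra) pair, as vLLM's parallel linear layers do) and act_fn. A plain-PyTorch stand-in that satisfies this assumed interface, shown for illustration only:

import torch


class SharedExpertsStub(torch.nn.Module):
    # Hypothetical stand-in exposing the interface used above:
    # gate_up_proj(x) -> (tensor, extra), act_fn(x) -> tensor,
    # down_proj(x) -> (tensor, extra).
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_up = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

    def gate_up_proj(self, x: torch.Tensor):
        return self.gate_up(x), None

    def act_fn(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)  # SwiGLU-style activation
        return torch.nn.functional.silu(gate) * up

    def down_proj(self, x: torch.Tensor):
        return self.down(x), None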
@@ -118,6 +129,11 @@ def fused_experts_with_mc2(
     down_out_list = torch.cat(down_out_list, dim=0)

+    if shared_experts is not None:
+        with npu_stream_switch("expert_secondary"):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+
     # moeCombine
     kwargs = {
         "expand_x": down_out_list,
@@ -145,7 +161,7 @@ def fused_experts_with_mc2(
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)

-    return hidden_states
+    return hidden_states, shared_hidden_states if shared_experts is not None else None


 # currently expert parallelism implemented with all2all
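With this change fused_experts_with_mc2 always returns a 2-tuple: the routed output plus the shared-expert output, the latter being None when no shared_experts was passed. A hedged sketch of the resulting calling convention; the import path and the final summation are assumptions, not part of this diff:

from typing import Optional

import torch

from vllm_ascend.ops.fused_moe import fused_experts_with_mc2  # module path assumed


def run_mc2_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                topk_weights: torch.Tensor, topk_ids: torch.Tensor, top_k: int,
                expert_map: torch.Tensor, group_name: str,
                shared_experts: Optional[torch.nn.Module] = None) -> torch.Tensor:
    router_out, shared_out = fused_experts_with_mc2(
        hidden_states=hidden_states, w1=w1, w2=w2,
        topk_weights=topk_weights, topk_ids=topk_ids, top_k=top_k,
        expert_map=expert_map, moe_all_to_all_group_name=group_name,
        shared_experts=shared_experts)
    # How the two branches are merged is model-specific; DeepSeek-style MoE
    # simply adds them.
    return router_out if shared_out is None else router_out + shared_out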
@@ -624,6 +640,7 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
+        shared_experts: Optional[torch.nn.Module] = None,
         **kwargs,
     ):
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
@@ -664,28 +681,35 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                shared_experts=shared_experts,
+            )
         elif get_ep_group().world_size == 1:
-            return fused_experts(hidden_states=x,
-                                 w1=layer.w13_weight,
-                                 w2=layer.w2_weight,
-                                 topk_weights=topk_weights,
-                                 topk_ids=topk_ids,
-                                 top_k=top_k,
-                                 expert_map=expert_map)
+            router_hidden_states = fused_experts(hidden_states=x,
+                                                 w1=layer.w13_weight,
+                                                 w2=layer.w2_weight,
+                                                 topk_weights=topk_weights,
+                                                 topk_ids=topk_ids,
+                                                 top_k=top_k,
+                                                 expert_map=expert_map)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
             # Therefore, all2all is needed no matter how dp/tp is set so as to
             # dispatch/combine tokens.
-            return fused_experts_with_all2all(hidden_states=x,
-                                              w1=layer.w13_weight,
-                                              w2=layer.w2_weight,
-                                              topk_weights=topk_weights,
-                                              topk_ids=topk_ids,
-                                              top_k=top_k,
-                                              expert_map=expert_map,
-                                              ep_group=get_ep_group())
+            router_hidden_states = fused_experts_with_all2all(hidden_states=x,
+                                                              w1=layer.w13_weight,
+                                                              w2=layer.w2_weight,
+                                                              topk_weights=topk_weights,
+                                                              topk_ids=topk_ids,
+                                                              top_k=top_k,
+                                                              expert_map=expert_map,
+                                                              ep_group=get_ep_group())
+
+        if shared_experts is None:
+            return router_hidden_states
+        else:
+            return router_hidden_states, shared_experts(x)


 class AscendFusedMoE(FusedMoE):
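Note the asymmetry this leaves behind: the MC2 branch returns the 2-tuple from fused_experts_with_mc2 directly (its second element is None without shared experts), while the other two branches return a bare tensor in that case. A small hypothetical helper a caller could use to normalize both shapes:

def unpack_moe_output(out):
    # Hypothetical helper, not part of this diff: apply() may return either a
    # bare tensor or a (router_hidden_states, shared_hidden_states) tuple.
    if isinstance(out, tuple):
        return out
    return out, None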
@@ -815,7 +839,8 @@ def forward(self,
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[torch.nn.Module] = None):
         assert self.quant_method is not None

         if top_k:
@@ -842,7 +867,9 @@ def forward(self,
842
867
scoring_func = self .scoring_func ,
843
868
e_score_correction_bias = self .e_score_correction_bias ,
844
869
is_prefill = is_prefill ,
845
- enable_force_load_balance = enable_force_load_balance )
870
+ enable_force_load_balance = enable_force_load_balance ,
871
+ shared_experts = shared_experts ,
872
+ )
846
873
847
874
if VLLM_ENABLE_MC2 and not is_prefill :
848
875
...
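For completeness, a hedged sketch of how a DeepSeek-style caller might drive the extended forward() signature. The names experts and shared_experts are placeholders, and the assumption that forward() propagates the (routed, shared) pair from apply() is not confirmed by the truncated diff above:

import torch


def moe_layer_forward(experts: torch.nn.Module, shared_experts: torch.nn.Module,
                      hidden_states: torch.Tensor, router_logits: torch.Tensor,
                      is_prefill: bool) -> torch.Tensor:
    # `experts` is assumed to be an AscendFusedMoE instance; passing
    # shared_experts lets the MC2 path overlap the shared MLP on the
    # secondary stream instead of running it afterwards.
    out = experts(hidden_states,
                  router_logits,
                  is_prefill=is_prefill,
                  shared_experts=shared_experts)
    routed, shared = out if isinstance(out, tuple) else (out, None)
    return routed if shared is None else routed + shared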