 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator,
                               get_tensor_model_parallel_world_size,
@@ -47,10 +48,11 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
+    **kwargs
 ) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,13 +83,20 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
+            shared_x = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
@@ -116,10 +125,15 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )

+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, down_out_list)
+            shared_output = shared_experts.down_proj(shared_x)
+
     down_out_list = torch.cat(down_out_list, dim=0)

     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,10 +155,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states

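The two `npu_stream_switch('cv')` blocks added above overlap the shared-expert MLP with the routed-expert path: the shared activation and down projection run on a secondary stream while the MC2 dispatch/combine and the grouped matmul proceed on the default stream. A minimal sketch of that pattern, using only the torchair calls that appear in this diff; `shared_mlp`, `gate_up_out` and `dependency` are illustrative placeholders, not names from the repository:

```python
import torchair as tng  # type: ignore


def run_on_side_stream(shared_mlp, gate_up_out, dependency):
    # Route the enclosed ops onto the secondary stream tagged 'cv',
    # mirroring the usage inside fused_experts_with_mc2 above.
    with tng.scope.npu_stream_switch('cv'):
        # Hold the side stream until `dependency` (e.g. the dispatch output
        # produced on the default stream) is ready before consuming
        # `gate_up_out`, then run the shared-expert activation there.
        tng.scope.npu_wait_tensor(gate_up_out, dependency)
        return shared_mlp.act_fn(gate_up_out)
```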
@@ -664,7 +680,8 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
        elif get_ep_group().world_size == 1:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
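With `**kwargs` now forwarded from `apply`, a caller can pass the optional `shared_experts` module and its precomputed `shared_gate_up` tensor down to `fused_experts_with_mc2`, and gets back a `(hidden_states, shared_output)` tuple when it does. A hedged call-site sketch; the wrapper name and `moe_args` are placeholders, and it assumes `fused_experts_with_mc2` is importable from the module this diff patches:

```python
def routed_plus_shared(shared_experts=None, shared_gate_up=None, **moe_args):
    """Illustrative wrapper: forwards the optional shared-expert inputs as
    keyword arguments and unpacks the tuple return introduced by this change.
    `moe_args` stands in for the existing fused_experts_with_mc2 arguments
    (hidden_states, weights, topk_ids, top_k, expert_map,
    moe_all_to_all_group_name, ...)."""
    extra = {}
    if shared_experts is not None:
        extra = {
            "shared_experts": shared_experts,  # read via kwargs.get('shared_experts')
            "shared_gate_up": shared_gate_up,  # read via kwargs.get('shared_gate_up')
        }
    out = fused_experts_with_mc2(**moe_args, **extra)
    if shared_experts is not None:
        routed_out, shared_out = out  # tuple when shared_experts is provided
        return routed_out, shared_out
    return out, None
```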