 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng  # type: ignore
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator,
                               get_tensor_model_parallel_world_size,

 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM


-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: Optional[str] = None,
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: Optional[str] = None,
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -81,13 +81,20 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_gate_up, expand_x)
+            shared_x = shared_experts.act_fn(shared_gate_up)
+
     w1 = w1.transpose(1, 2)
     expert_token_nums = torch.cumsum(expert_token_nums,
                                      dim=0,
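
Aside: the rename from `kwargs` to `kwargs_mc2` is what makes the new `**kwargs` parameter usable. The old code rebound the name `kwargs` to the MC2 dispatch arguments, which would have hidden anything the caller passed in (such as `shared_experts`). A minimal, self-contained illustration of that collision (all names here are illustrative, not from the diff):

```python
def some_op(**op_kwargs):
    return op_kwargs


def before(**kwargs):
    # Old pattern: rebinding `kwargs` discards the caller's extras, so a later
    # kwargs.get('shared_experts') can never see them.
    kwargs = {"x": 1}
    return kwargs.get("shared_experts"), some_op(**kwargs)


def after(**kwargs):
    # New pattern: the op arguments live in their own dict (kwargs_mc2 in the
    # diff), so the caller's extras stay reachable through `kwargs`.
    kwargs_mc2 = {"x": 1}
    return kwargs.get("shared_experts"), some_op(**kwargs_mc2)


assert before(shared_experts="shared")[0] is None
assert after(shared_experts="shared")[0] == "shared"
```
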
@@ -116,10 +123,15 @@ def fused_experts_with_mc2(
         group_list=group_list,
     )

+    if shared_experts:
+        with tng.scope.npu_stream_switch('cv'):
+            tng.scope.npu_wait_tensor(shared_x, down_out_list)
+            shared_output = shared_experts.down_proj(shared_x)
+
     down_out_list = torch.cat(down_out_list, dim=0)

     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -141,10 +153,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states

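
The shared-expert branches above lean on torchair's multi-stream scoping: work issued inside `tng.scope.npu_stream_switch('cv')` goes to a secondary stream, and `tng.scope.npu_wait_tensor(a, b)` orders that stream after tensor `b` produced on the main stream, so the shared-expert MLP overlaps with the routed-expert dispatch and grouped matmul. A minimal sketch of the pattern, assuming an Ascend environment with `torch_npu` and `torchair` installed and, to my understanding, a model compiled through torchair's graph mode (the module and tensors below are placeholders, not the real shared-expert objects):

```python
import torch
import torch_npu  # noqa: F401  (registers the NPU backend)
import torchair as tng


def run_overlapped(dep: torch.Tensor, x: torch.Tensor,
                   layer: torch.nn.Module) -> torch.Tensor:
    """Run `layer(x)` on the secondary stream tagged 'cv', ordered after `dep`."""
    with tng.scope.npu_stream_switch('cv'):
        # Wait until `dep` (e.g. the MC2 dispatch output) is ready before
        # touching x, then compute on the side stream while the main stream
        # proceeds with the routed experts.
        tng.scope.npu_wait_tensor(x, dep)
        return layer(x)
```
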
@@ -664,7 +678,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
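
Since `apply()` now just forwards `**kwargs`, the shared-expert path is opt-in from the caller's side: pass a `shared_experts` module plus its precomputed `shared_gate_up` projection, and unpack the tuple that comes back. A hypothetical caller-side sketch, assuming `fused_experts_with_mc2` is imported from the module changed in this diff; the `gate_up_proj` attribute and the final addition of the two outputs are assumptions about surrounding model code, not part of the change:

```python
from typing import Optional

import torch


def run_moe(hidden_states: torch.Tensor,
            w13_weight: torch.Tensor,
            w2_weight: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            top_k: int,
            expert_map: torch.Tensor,
            group_name: Optional[str],
            shared_experts: Optional[torch.nn.Module] = None) -> torch.Tensor:
    extra = {}
    if shared_experts is not None:
        # Precompute the shared experts' gate/up projection on the main stream;
        # fused_experts_with_mc2 overlaps the rest on its secondary stream.
        extra["shared_experts"] = shared_experts
        extra["shared_gate_up"] = shared_experts.gate_up_proj(hidden_states)

    out = fused_experts_with_mc2(hidden_states=hidden_states,
                                 w1=w13_weight,
                                 w2=w2_weight,
                                 topk_weights=topk_weights,
                                 topk_ids=topk_ids,
                                 top_k=top_k,
                                 expert_map=expert_map,
                                 moe_all_to_all_group_name=group_name,
                                 **extra)

    if isinstance(out, tuple):
        routed_out, shared_out = out
        return routed_out + shared_out  # combining by addition is an assumption
    return out
```
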