import torch
import torch.distributed as dist
import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce

import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group
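
The new `torchair` import is what drives the multi-stream behavior in the hunks below: `tng.scope.npu_stream_switch` issues work on a secondary NPU stream and `tng.scope.npu_wait_tensor` orders it after its producer, while `tensor_model_parallel_all_reduce` is pulled in for the shared-expert TP reduction. A minimal sketch of that pattern, assuming an Ascend environment with `torch_npu` and `torchair` installed (the helper name `overlapped_side_matmul` is invented for illustration):

```python
import torch
import torch_npu  # noqa: F401  # registers NPU ops/devices
import torchair as tng


def overlapped_side_matmul(main_out: torch.Tensor, side_in: torch.Tensor,
                           side_w: torch.Tensor) -> torch.Tensor:
    # Kernels launched inside npu_stream_switch('1') run on a secondary stream
    # and overlap with whatever the default stream issues next.
    with tng.scope.npu_stream_switch('1'):
        # Order the secondary stream after main_out's producer before touching
        # side_in (mirrors npu_wait_tensor(shared_gate_up, hidden_states) below).
        tng.scope.npu_wait_tensor(side_in, main_out)
        side_out = torch.matmul(side_in, side_w)
    return side_out
```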
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
    """
    apply MLP: gate_up_proj -> swiglu -> down_proj

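
With `**kwargs` added to the signature, the shared-expert inputs can travel through the existing call chain without touching the positional arguments. The keys the body looks for are exactly the ones read via `kwargs.get(...)` in the next hunk; a placeholder sketch of that dictionary (the values here are dummies, not real tensors):

```python
# Keys apply_mlp now inspects via kwargs.get(...); values are placeholders only.
shared_expert_kwargs = {
    "shared_experts": None,        # shared-expert module with quantized gate_up_proj / down_proj
    "shared_gate_up": None,        # shared expert's gate_up projection (quantized activations)
    "shared_dynamic_scale": None,  # per-token activation scale paired with shared_gate_up
}
```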
@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
    else:
        pertoken_scale = dynamic_scale

+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
    # gmm1: gate_up_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
@@ -96,6 +115,21 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
        group_type=0,
        group_list=group_list,
        output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
    return hidden_states

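
`apply_mlp` now has two return shapes: the routed-expert tensor alone, or a `(routed, shared)` pair when the shared expert was scheduled on the secondary stream. A small sketch of how a caller can normalize that, with the helper name `unpack_mlp_output` invented here (the real caller does the equivalent `isinstance` check in the hunks further down):

```python
from typing import Optional, Tuple

import torch


def unpack_mlp_output(out) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # apply_mlp returns (routed_out, shared_out) only when 'shared_experts'
    # was passed in via **kwargs; otherwise it returns a single tensor.
    if isinstance(out, tuple):
        routed_out, shared_out = out
        return routed_out, shared_out
    return out, None
```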
@@ -110,11 +144,12 @@ def fused_experts_with_mc2(
        top_k: int,
        expert_map: torch.Tensor = None,
        moe_all_to_all_group_name: str = "",
+        **kwargs
) -> torch.Tensor:
    global_bs = 0
    moe_expert_num = len(expert_map)
    # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs1 = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
@@ -145,9 +180,9 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage1_kwargs)
+    kwargs1.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs1)
    # comm_stream.wait_stream(torch.npu.current_stream())
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
        0:5]
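
Renaming the local dict from `kwargs` to `kwargs1` (and `kwargs2` for the combine stage below) is what keeps the forwarded `**kwargs`, which carries the shared-expert tensors, intact until it reaches `apply_mlp`. A toy illustration of the shadowing the rename avoids (all names here are illustrative only):

```python
def dispatch_then_apply(**kwargs):
    # Before the change, `kwargs = {...}` here would have overwritten the
    # caller's forwarded keyword arguments; the renamed local leaves them alone.
    kwargs1 = {"x": "dispatch inputs"}
    kwargs1.update({"ep_world_size": 16})
    return kwargs1, kwargs  # original kwargs still available for apply_mlp


print(dispatch_then_apply(shared_experts="shared-expert module"))
```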
@@ -165,10 +200,15 @@ def fused_experts_with_mc2(
        w2,
        w2_scale,
        expert_token_nums,
-        dynamic_scale=dynamic_scale)
+        dynamic_scale=dynamic_scale,
+        **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list

    # moeCombine
-    kwargs = {
+    kwargs2 = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
@@ -192,10 +232,12 @@ def fused_experts_with_mc2(
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
-    kwargs.update(stage3_kwargs)
+    kwargs2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs2)

+    if multi_stream:
+        return hidden_states, shared_output
    return hidden_states

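
When `multi_stream` is set, `fused_experts_with_mc2` hands both tensors back to its caller. How the two are merged is up to the layer above and is not shown in this diff; a plausible sketch, with the wrapper name `finish_moe_layer` invented for illustration:

```python
import torch


def finish_moe_layer(result):
    # fused_experts_with_mc2 may return either a tensor or a
    # (routed_out, shared_out) pair in the multi-stream path.
    if isinstance(result, tuple):
        routed_out, shared_out = result
        return routed_out + shared_out  # assumed merge of the shared-expert output
    return result
```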
@@ -634,7 +676,8 @@ def apply(
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
        elif dp_size == 1:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,