@@ -20,7 +20,8 @@
 import torch
 import torch.distributed as dist
 import torch_npu
-from vllm.distributed import GroupCoordinator
+import torchair as tng  # type: ignore
+from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
 
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import get_ep_group
@@ -36,7 +37,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1) -> torch.Tensor:
+              group_list_type: int = 1,
+              **kwargs) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
 
@@ -68,6 +70,23 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
     else:
         pertoken_scale = dynamic_scale
 
+    shared_experts = kwargs.get('shared_experts', None)
+    if shared_experts:
+        shared_gate_up = kwargs.get('shared_gate_up', None)
+        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
+            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
+                x=shared_gate_up,
+                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
+                activation_scale=shared_dynamic_scale,
+                bias=None,
+                quant_scale=None,
+                quant_offset=None,
+                group_index=None,
+                activate_left=True,
+                quant_mode=1)
+
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
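Note: `npu_stream_switch('1')` routes the shared-expert ops to a secondary stream, and `npu_wait_tensor(shared_gate_up, hidden_states)` orders that stream behind the main one, so the shared-expert SwiGLU can run concurrently with the routed experts' grouped matmuls below. A minimal CUDA-stream analogue of this overlap pattern (an assumption about the torchair semantics, with illustrative names):

```python
import torch
import torch.nn.functional as F

def overlapped(x, w_routed, w_shared, side: torch.cuda.Stream):
    gate_up = x @ w_shared                           # produced on the main stream
    side.wait_stream(torch.cuda.current_stream())    # like npu_wait_tensor
    with torch.cuda.stream(side):                    # like npu_stream_switch('1')
        shared_out = F.silu(gate_up)                 # side-stream work
    routed_out = x @ w_routed                        # overlaps with the side stream
    torch.cuda.current_stream().wait_stream(side)    # join before reading shared_out
    return routed_out, shared_out
```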
@@ -96,25 +115,39 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
         group_type=0,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
+
+    if shared_experts:
+        with tng.scope.npu_stream_switch('1'):
+            tng.scope.npu_wait_tensor(shared_x, hidden_states)
+            shared_output = torch_npu.npu_quant_matmul(
+                shared_x,
+                shared_experts.down_proj.weight,
+                shared_experts.down_proj.weight_scale,
+                pertoken_scale=shared_dynamic_scale,
+                output_dtype=torch.bfloat16,
+            )
+            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
+                shared_output = tensor_model_parallel_all_reduce(shared_output)
+    if shared_experts:
+        return hidden_states, shared_output
     return hidden_states
 
 
-def fused_experts_with_mc2(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    top_k: int,
-    expert_map: torch.Tensor = None,
-    moe_all_to_all_group_name: str = "",
-) -> torch.Tensor:
+def fused_experts_with_mc2(hidden_states: torch.Tensor,
+                           w1: torch.Tensor,
+                           w2: torch.Tensor,
+                           w1_scale: torch.Tensor,
+                           w2_scale: torch.Tensor,
+                           topk_weights: torch.Tensor,
+                           topk_ids: torch.Tensor,
+                           top_k: int,
+                           expert_map: torch.Tensor = None,
+                           moe_all_to_all_group_name: str = "",
+                           **kwargs) -> torch.Tensor:
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -145,9 +178,9 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)
 
-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -165,10 +198,15 @@ def fused_experts_with_mc2(
         w2,
         w2_scale,
         expert_token_nums,
-        dynamic_scale=dynamic_scale)
+        dynamic_scale=dynamic_scale,
+        **kwargs)
+
+    multi_stream = isinstance(down_out_list, tuple)
+    if multi_stream:
+        down_out_list, shared_output = down_out_list
 
     # moeCombine
-    kwargs = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -192,10 +230,12 @@ def fused_experts_with_mc2(
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)
 
-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
 
+    if multi_stream:
+        return hidden_states, shared_output
     return hidden_states
 
 
@@ -633,7 +673,8 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map,
-                moe_all_to_all_group_name=self.moe_all_to_all_group_name)
+                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
+                **kwargs)
         elif self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
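End to end, the `**kwargs` plumbing carries the shared-expert tensors from `apply()` through `fused_experts_with_mc2()` into `apply_mlp()` without widening any intermediate signature. A hypothetical call site (all names below are assumptions, not taken from this diff):

```python
out = quant_method.apply(layer, x, router_logits,
                         top_k=top_k,
                         expert_map=expert_map,
                         shared_experts=shared_experts,          # enables the side stream
                         shared_gate_up=shared_gate_up,
                         shared_dynamic_scale=shared_dynamic_scale)
if isinstance(out, tuple):                                       # multi-stream path taken
    hidden_states, shared_output = out
```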