@@ -147,7 +147,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     global_bs = 0
     moe_expert_num = len(expert_map)
     # hidden_states = hidden_states.bfloat16()
-    kwargs1 = {
+    kwargs_mc2 = {
         "x": hidden_states,
         "expert_ids": topk_ids,
         "expert_shard_type": 0,
@@ -178,9 +178,9 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs1.update(stage1_kwargs)
+    kwargs_mc2.update(stage1_kwargs)

-    output = torch_npu.npu_moe_distribute_dispatch(**kwargs1)
+    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
@@ -206,7 +206,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         down_out_list, shared_output = down_out_list

     # moeCombine
-    kwargs2 = {
+    kwargs_mc2 = {
         "expand_x": down_out_list,
         "expert_ids": topk_ids,
         "expand_idx": expand_idx,
@@ -230,9 +230,9 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
         "tp_world_size": tp_size,
         "tp_rank_id": tp_rank,
     }
-    kwargs2.update(stage3_kwargs)
+    kwargs_mc2.update(stage3_kwargs)

-    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs2)
+    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

     if multi_stream:
         return hidden_states, shared_output
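For readers skimming only this hunk, here is a minimal sketch of the three-stage MC2 flow that the renamed `kwargs_mc2` dict feeds (dispatch, expert compute, combine). It uses only the fields visible in this diff; the elided dispatch/combine kwargs (group handles, scales, ep/tp sizes, etc.) are assumed to be prepared by the caller as in the full file, and `run_experts` is a hypothetical stand-in for the grouped expert computation done between the two collectives.

```python
from typing import Callable, Dict

import torch
import torch_npu  # Ascend NPU extension providing the MC2 dispatch/combine ops


def mc2_moe_forward(run_experts: Callable[[torch.Tensor], torch.Tensor],
                    topk_ids: torch.Tensor,
                    dispatch_kwargs: Dict,
                    combine_kwargs: Dict) -> torch.Tensor:
    # Stage 1: dispatch -- scatter tokens to their experts across EP/TP ranks.
    output = torch_npu.npu_moe_distribute_dispatch(**dispatch_kwargs)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[0:5]

    # Stage 2: run the local experts on the dispatched tokens
    # (the real file does grouped GEMMs + activation here; `run_experts` is a placeholder).
    down_out_list = run_experts(expand_x)

    # Stage 3: combine -- gather expert outputs back into the original token order.
    combine_kwargs.update({
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
    })
    return torch_npu.npu_moe_distribute_combine(**combine_kwargs)
```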