
Commit 150a577

Author: yangcheng (AJ)
Commit message:
fix mc2 bug
1 parent f4219a3 commit 150a577

File tree

1 file changed: +5 -5 lines changed


vllm_ascend/ops/fused_moe.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -120,7 +120,10 @@ def fused_experts_with_mc2(
         moe_all_to_all_group_name: Optional[str] = None,
         shared_experts: Optional[Any] = None
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+    vllm_config = get_current_vllm_config()
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = vllm_config.scheduler_config.max_num_seqs * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +135,8 @@ def fused_experts_with_mc2(
     }
 
     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
 
     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
@@ -204,7 +204,7 @@ def fused_experts_with_mc2(
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
```
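In short, `global_bs` was previously hard-coded to 0; this commit derives it from the scheduler's `max_num_seqs` scaled by the size of the expert-parallel all-to-all group, moves the `ep_group` lookup up so it can be reused, and forwards the computed value as `"global_bs"` in the stage-3 kwargs instead of 0. Below is a minimal standalone sketch of the new computation, not the full MC2 call path; `get_current_vllm_config` and `get_ep_group` are assumed to behave like the helpers already imported in `vllm_ascend/ops/fused_moe.py` and are passed in as arguments here only to keep the sketch self-contained.

```python
import torch


def compute_global_bs(get_current_vllm_config, get_ep_group) -> int:
    """Sketch of the global_bs computation introduced by this commit."""
    vllm_config = get_current_vllm_config()
    # Expert-parallel process group used for the MC2 all-to-all dispatch.
    ep_group = get_ep_group().device_group
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
    # Max sequences a single rank may schedule, times the number of ranks
    # participating in the all-to-all, gives the global batch upper bound.
    return vllm_config.scheduler_config.max_num_seqs * all_to_all_group_size
```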
