Skip to content

Commit d945298

Browse files
author
yangcheng (AJ)
committed
fix mc2 bug
1 parent f4219a3 commit d945298

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,13 @@ def fused_experts_with_mc2(
         top_k: int,
         expert_map: torch.Tensor = None,
         moe_all_to_all_group_name: Optional[str] = None,
-        shared_experts: Optional[Any] = None
+        shared_experts: Optional[Any] = None,
+        global_batch_size: int = 256,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = global_batch_size * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +136,8 @@ def fused_experts_with_mc2(
     }

     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
@@ -204,7 +205,7 @@ def fused_experts_with_mc2(
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
@@ -1037,7 +1038,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                global_batch_size=self.global_batch_size)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,

0 commit comments

Comments
 (0)