
Commit a990949

fix etp rank related accuracy problem
Signed-off-by: whx-sjtu <2952154980@qq.com>
Parent: 5cf9ff1 · Commit: a990949

File tree: 1 file changed (+12, -25 lines)


vllm_ascend/ops/fused_moe.py

Lines changed: 12 additions & 25 deletions
@@ -615,32 +615,19 @@ def __init__(
         self.expert_map = None
         self.activation = activation

-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
-
+        # Create a tensor of size num_experts filled with -1
+        self.local_num_experts, self.expert_map = determine_expert_map(
+            self.ep_size,
+            get_ep_group().rank_in_group, self.global_num_experts)
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.tp_rank = get_etp_group().rank_in_group
+            self.ep_rank = get_ep_group().rank_in_group
         else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+            self.moe_parallel_config.tp_rank = get_etp_group(
+            ).rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
+
+

             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
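
Why dropping the `if self.ep_size > 1:` guard is safe: vLLM's `determine_expert_map` helper already degenerates when there is only one expert-parallel group (to my understanding it returns `(global_num_experts, None)` for `ep_size == 1`), so the rank wiring above can run unconditionally. Below is a minimal sketch of that contract under that assumption; `sketch_determine_expert_map` and the even-split-with-remainder layout are illustrative stand-ins, not the vLLM source.

# Minimal sketch of the contract this commit relies on; NOT the vLLM
# source. The name and the exact split policy are assumptions for
# illustration; the key point is the ep_size == 1 fallback.
from typing import Optional, Tuple

import torch


def sketch_determine_expert_map(
        ep_size: int, ep_rank: int,
        global_num_experts: int) -> Tuple[int, Optional[torch.Tensor]]:
    assert ep_size > 0
    if ep_size == 1:
        # One EP group: every rank owns all experts, no mapping needed.
        # This is why the removed `if self.ep_size > 1` guard (and the
        # deleted else-branch fallback) became redundant.
        return (global_num_experts, None)

    # Even split; the last rank absorbs any remainder.
    local_num_experts = global_num_experts // ep_size
    # Entries left at -1 mean "expert not hosted on this rank".
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        expert_map[ep_rank * local_num_experts:(ep_rank + 1) *
                   local_num_experts] = torch.arange(local_num_experts,
                                                     dtype=torch.int32)
    else:
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts,
                                                       dtype=torch.int32)
    return (local_num_experts, expert_map)


# ep_size == 1 degenerates to "all experts local, expert_map is None",
# matching the old else-branch assignment (self.global_num_experts, None):
assert sketch_determine_expert_map(1, 0, 16) == (16, None)

If that reading is right, the net effect of the commit is that `tp_rank` always comes from the ETP group and `ep_rank` from the EP group, rather than being overridden by the untested DP-attention adjustment in the removed else-branch.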
