rm router logits Improve TTOP 3ms #1407

Merged: 19 commits, Jul 11, 2025
Changes from 17 commits
Commits (19):
72aeb69
rm router logits Improve TTOP 3ms
ttanzhiqiang Jun 24, 2025
04ad4c2
update
ttanzhiqiang Jun 24, 2025
f13442e
update
ttanzhiqiang Jun 25, 2025
db520cd
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jun 25, 2025
4c8954a
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 8, 2025
86df0a2
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
2f77bc9
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
6f18307
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
cb15e05
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
d8755c9
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
e0c36a8
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
9e15f42
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
a595a67
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
eedcd05
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 9, 2025
fa50f6a
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 9, 2025
e4fc29f
Empty submission
ttanzhiqiang Jul 9, 2025
a0be155
Empty submission
ttanzhiqiang Jul 10, 2025
89458f0
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 10, 2025
af900cc
update
ttanzhiqiang Jul 10, 2025
6 changes: 5 additions & 1 deletion vllm_ascend/models/deepseek_v2.py
@@ -367,6 +367,7 @@
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()
self.rm_router_logits = self.experts.rm_router_logits

Codecov (codecov/patch) warning: added line vllm_ascend/models/deepseek_v2.py#L370 was not covered by tests.

def forward(self,
hidden_states: torch.Tensor,
@@ -389,7 +390,9 @@
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = None
if not self.rm_router_logits:
router_logits, _ = self.gate(hidden_states)

Codecov (codecov/patch) warning: added lines vllm_ascend/models/deepseek_v2.py#L393-L395 were not covered by tests.

experts_hidden_states = self.experts(
hidden_states=hidden_states,
@@ -398,6 +401,7 @@
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=self.shared_experts,
gate=self.gate,
replace_allreduce=replace_allreduce)

hidden_states = (
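Net effect of the deepseek_v2.py changes: when rm_router_logits is active, the MoE block no longer runs the gate before dispatch; it passes router_logits=None together with the gate module into self.experts, which computes the logits after any cross-DP gather. Below is a minimal sketch of that control flow, assuming a plain nn.Linear gate and a stub experts layer (MoEExpertsStub and MoEBlockSketch are illustrative stand-ins, not code from this PR):

```python
import torch
import torch.nn as nn


class MoEExpertsStub(nn.Module):
    """Stand-in for the fused-experts layer: when router_logits is None,
    it computes them itself from the (possibly gathered) tokens."""

    def forward(self, hidden_states, router_logits, gate):
        if router_logits is None:
            # Gate runs here, i.e. after any cross-DP gather of hidden_states.
            router_logits = gate(hidden_states)
        return hidden_states, router_logits


class MoEBlockSketch(nn.Module):
    def __init__(self, hidden_size: int, n_experts: int, rm_router_logits: bool):
        super().__init__()
        self.gate = nn.Linear(hidden_size, n_experts, bias=False)
        self.experts = MoEExpertsStub()
        self.rm_router_logits = rm_router_logits

    def forward(self, hidden_states: torch.Tensor):
        router_logits = None
        if not self.rm_router_logits:
            # Old path: gate GEMM on the local tokens before dispatch.
            router_logits = self.gate(hidden_states)
        # New path: hand the gate module to the experts layer instead.
        return self.experts(hidden_states, router_logits, gate=self.gate)
```

In the actual model the gate is a vLLM ReplicatedLinear, which returns a (logits, bias) tuple, hence the `router_logits, _ = self.gate(hidden_states)` unpacking in the diff.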
33 changes: 24 additions & 9 deletions vllm_ascend/ops/fused_moe.py
@@ -42,8 +42,8 @@
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, is_310p, npu_stream_switch,
npu_wait_tensor)
get_fused_moe_state, get_rm_router_logits_state,
is_310p, npu_stream_switch, npu_wait_tensor)

MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER

@@ -1144,6 +1144,10 @@
self.log2phy = None
self.global_redundant_expert_num = 0

is_deepseek_v3_r1 = self.global_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(

Codecov (codecov/patch) warning: added lines vllm_ascend/ops/fused_moe.py#L1147-L1148 were not covered by tests.
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)

ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
if expert_map_path and os.path.exists(expert_map_path):
@@ -1233,7 +1237,9 @@
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False):

assert self.quant_method is not None

if top_k:
Expand Down Expand Up @@ -1269,6 +1275,7 @@
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]

if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
@@ -1281,19 +1288,27 @@
hidden_states,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(

Codecov (codecov/patch) warning: added lines vllm_ascend/ops/fused_moe.py#L1291-L1292 were not covered by tests.
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)

Codecov (codecov/patch) warning: added lines vllm_ascend/ops/fused_moe.py#L1297-L1298 were not covered by tests.
else:
router_logits = get_dp_group().all_gather(router_logits, 0)

Codecov (codecov/patch) warning: added line vllm_ascend/ops/fused_moe.py#L1300 was not covered by tests.

elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)

Codecov (codecov/patch) warning: added lines vllm_ascend/ops/fused_moe.py#L1307-L1308 were not covered by tests.
else:
router_logits = self.naive_multicast(

Codecov (codecov/patch) warning: added line vllm_ascend/ops/fused_moe.py#L1310 was not covered by tests.
router_logits, cu_tokens_across_dp_cpu)

# Matrix multiply.
e_hidden_states = self.quant_method.apply(
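The heart of the fused_moe.py change is the ordering inside the AllGather (and NaiveMulticast) branches: instead of computing router logits per rank and then communicating both hidden_states and router_logits across DP, the layer communicates hidden_states once and runs the gate on the gathered tokens. A rough sketch of the AllGather branch under simplifying assumptions: `all_gather` stands in for get_dp_group().all_gather(tensor, 0), and a plain nn.Linear replaces the ReplicatedLinear gate (so no tuple unpacking); this is illustrative, not the vLLM-Ascend implementation.

```python
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def gather_and_route(hidden_states: torch.Tensor,
                     router_logits: Optional[torch.Tensor],
                     gate: nn.Linear,
                     rm_router_logits: bool,
                     max_num_tokens_across_dp: int,
                     all_gather: Callable[[torch.Tensor], torch.Tensor]):
    """Pad local tokens to the DP-wide maximum, gather hidden_states, then
    either gather precomputed router_logits (old path) or run the gate once
    on the gathered tokens (rm_router_logits path)."""
    num_tokens = hidden_states.shape[0]
    pad = max_num_tokens_across_dp - num_tokens
    if pad > 0:
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad))
        if not rm_router_logits:
            router_logits = F.pad(router_logits, (0, 0, 0, pad))
    hidden_states = all_gather(hidden_states)
    if rm_router_logits:
        # New ordering: one communication + one gate GEMM over all tokens.
        router_logits = gate(hidden_states)
    else:
        # Old ordering: gate already ran per rank, so router_logits needs a
        # second communication of its own.
        router_logits = all_gather(router_logits)
    return hidden_states, router_logits


if __name__ == "__main__":
    # Single-process illustration that fakes dp_size == 2 by concatenation.
    def fake_all_gather(t: torch.Tensor) -> torch.Tensor:
        return torch.cat([t, torch.zeros_like(t)], dim=0)

    gate = nn.Linear(16, 8, bias=False)
    h = torch.randn(3, 16)
    h_g, logits = gather_and_route(h, None, gate, True, 4, fake_all_gather)
    print(h_g.shape, logits.shape)  # torch.Size([8, 16]) torch.Size([8, 8])
```

The AllGatherEP path additionally relies on the fused torch_npu.npu_grouped_matmul_finalize_routing operator, which is why the switch is restricted to DeepSeek v3/r1 (see utils.py below).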
16 changes: 16 additions & 0 deletions vllm_ascend/utils.py
@@ -425,6 +425,22 @@
NaiveMulticast = 4


# TODO(ttanzhiqiang): rm_router_logits
# Only takes effect when dp > 1.
# In theory this optimization only applies to the AllGather and AllGatherEP paths:
# in the DP scenario the previous order was gate + two communications, and it is
# now one communication + gate, which saves some communication time. All MoE
# AllGather/AllGatherEP paths could follow this logic, but the DP paths of other
# MoE models (e.g. qwen3-235b) have not been adjusted yet, so this switch guards
# the behavior to prevent code errors.
def get_rm_router_logits_state(ep_size: int, dp_size: int,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if dp_size > 1:
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1

Codecov (codecov/patch) warning: added lines vllm_ascend/utils.py#L435-L436 were not covered by tests.
and is_deepseek_v3_r1):
return True
elif ep_size == 1 and is_deepseek_v3_r1:
return True
return False

Codecov (codecov/patch) warning: added lines vllm_ascend/utils.py#L438-L441 were not covered by tests.


# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
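For reference, the new switch only turns on for DeepSeek v3/r1 (global_num_experts == 256) with dp > 1, and on the expert-parallel path only when VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP is set. Below is a hypothetical standalone mirror of that condition with the env flag passed in explicitly (parameter names and call pattern are assumptions, not the library API), plus a few illustrative cases:

```python
def rm_router_logits_enabled(ep_size: int, dp_size: int,
                             is_deepseek_v3_r1: bool,
                             fused_experts_allgather_ep: bool) -> bool:
    """Mirrors get_rm_router_logits_state, with the env flag as a parameter."""
    if dp_size <= 1:
        return False
    if fused_experts_allgather_ep and ep_size > 1 and is_deepseek_v3_r1:
        return True  # AllGatherEP path (fused grouped-matmul routing)
    if ep_size == 1 and is_deepseek_v3_r1:
        return True  # AllGather / NaiveMulticast path
    return False


# All cases assume a DeepSeek v3/r1 checkpoint (256 routed experts).
assert rm_router_logits_enabled(1, 4, True, fused_experts_allgather_ep=False)
assert rm_router_logits_enabled(8, 4, True, fused_experts_allgather_ep=True)
assert not rm_router_logits_enabled(8, 4, True, fused_experts_allgather_ep=False)
assert not rm_router_logits_enabled(1, 1, True, fused_experts_allgather_ep=False)
```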