Skip to content

rm router logits Improve TTOP 3ms #1407

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
72aeb69
rm router logits Improve TTOP 3ms
ttanzhiqiang Jun 24, 2025
04ad4c2
update
ttanzhiqiang Jun 24, 2025
f13442e
update
ttanzhiqiang Jun 25, 2025
db520cd
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jun 25, 2025
4c8954a
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 8, 2025
86df0a2
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
2f77bc9
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
6f18307
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
cb15e05
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
d8755c9
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
e0c36a8
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
9e15f42
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
a595a67
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 8, 2025
eedcd05
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 9, 2025
fa50f6a
deepseekv3/r1 support rm_router_logits in [AllGatherEP, AllGather, Na…
ttanzhiqiang Jul 9, 2025
e4fc29f
Empty submission
ttanzhiqiang Jul 9, 2025
a0be155
Empty submission
ttanzhiqiang Jul 10, 2025
89458f0
Merge branch 'main' into rm_router_logits
ttanzhiqiang Jul 10, 2025
af900cc
update
ttanzhiqiang Jul 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/run_dp_attention_etp16.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
export ASCEND_LAUNCH_BLOCKING=0
export VLLM_VERSION=0.9.0
export VLLM_ASCEND_RM_ROUTER_LOGITS=1

nohup python -m vllm.entrypoints.openai.api_server --model=/mnt/deepseek/DeepSeek-R1-W8A8-VLLM \
--quantization ascend \
Expand Down
3 changes: 3 additions & 0 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@
# value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL":
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
# Remove the two communications of get_dp_group().all_gather and change it to one, and do gate after the communication
"VLLM_ASCEND_RM_ROUTER_LOGITS":
lambda: int(os.getenv("VLLM_ASCEND_RM_ROUTER_LOGITS", 0)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from Q3, we'll be careful to add more configuration. please remove it to enable rm_router_logits by default.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only valid in the FusedMoEState.AllGather solution. If other models use gate externally and rm_router_logits internally, an error will be reported, such as deepseek_dbo/qwen3/qwen2

Copy link
Collaborator

@Yikun Yikun Jul 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, could we enable this in some cases automatically, because it's difficult to let users know which models should enable this env or not.

Otherwise, LGTM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's not common, I prefer not to merge, we can wait more.

Or, if we can add more logic check instead of env var, i'm fine as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, currently m is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model. rm_router_logits is not enabled in other scenarios and models. You can add it later if necessary.

}

# end-env-vars-definition
Expand Down
7 changes: 6 additions & 1 deletion vllm_ascend/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
maybe_prefix)
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
Expand Down Expand Up @@ -365,6 +366,7 @@ def __init__(
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()
self.rm_router_logits = envs_ascend.VLLM_ASCEND_RM_ROUTER_LOGITS

def forward(self,
hidden_states: torch.Tensor,
Expand All @@ -387,7 +389,9 @@ def forward(self,
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = None
if not self.rm_router_logits:
router_logits, _ = self.gate(hidden_states)

experts_hidden_states = self.experts(
hidden_states=hidden_states,
Expand All @@ -396,6 +400,7 @@ def forward(self,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=self.shared_experts,
gate=self.gate,
replace_allreduce=replace_allreduce)

hidden_states = (
Expand Down
16 changes: 12 additions & 4 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,7 @@ def __init__(
self.activation = activation
self.log2phy = None
self.global_redundant_expert_num = 0
self.rm_router_logits = envs_ascend.VLLM_ASCEND_RM_ROUTER_LOGITS

ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
Expand Down Expand Up @@ -1212,7 +1213,9 @@ def forward(self,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False):

assert self.quant_method is not None

if top_k:
Expand Down Expand Up @@ -1257,11 +1260,16 @@ def forward(self,
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = get_dp_group().all_gather(router_logits, 0)

# Matrix multiply.
e_hidden_states = self.quant_method.apply(
Expand Down
Loading