
rm router logits Improve TTOP 3ms #1407


Merged
merged 19 commits into vllm-project:main on Jul 11, 2025

Conversation

ttanzhiqiang (Contributor) commented on Jun 24, 2025

What this PR does / why we need it?

The previous code was:

    router_logits, _ = self.gate(hidden_states)
    hidden_states = get_dp_group().all_gather(hidden_states, 0)
    router_logits = get_dp_group().all_gather(router_logits, 0)

This PR merges the two all_gather calls into one, removing one all_gather communication, by running the gate after the gather:

    hidden_states = get_dp_group().all_gather(hidden_states, 0)
    router_logits, _ = self.gate(hidden_states)
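
Below is a minimal sketch of how the old and new paths might sit side by side inside the MoE layer's forward. The helper name, the rm_router_logits flag, and the surrounding structure are illustrative assumptions; only self.gate and get_dp_group come from the snippets above.

    def _gather_and_route(self, hidden_states, rm_router_logits: bool):
        # Illustrative helper: choose between the old two-gather path and the
        # new single-gather path under data parallelism.
        dp_group = get_dp_group()
        if rm_router_logits:
            # New path: all_gather once, then compute router logits on the
            # gathered tokens.
            hidden_states = dp_group.all_gather(hidden_states, 0)
            router_logits, _ = self.gate(hidden_states)
        else:
            # Old path: compute router logits per rank, then all_gather both
            # hidden_states and router_logits (two communications).
            router_logits, _ = self.gate(hidden_states)
            hidden_states = dp_group.all_gather(hidden_states, 0)
            router_logits = dp_group.all_gather(router_logits, 0)
        return hidden_states, router_logits

Because the gate is applied per token and its weights are replicated across DP ranks, gathering first and gating afterwards yields the same router logits while moving only one tensor through the data-parallel group.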

Does this PR introduce any user-facing change?

How was this patch tested?

bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh

gsm8k accuracy verification
[Screenshot: GSM8K accuracy verification results, 2025-06-24]

Signed-off-by: ttanzhiqiang <389825161@qq.com>

codecov bot commented Jun 25, 2025

Codecov Report

Attention: Patch coverage is 5.00000% with 19 lines in your changes missing coverage. Please review.

Project coverage is 54.49%. Comparing base (c30ddb8) to head (af900cc).
Report is 109 commits behind head on main.

Files with missing lines           Patch %   Lines
vllm_ascend/ops/fused_moe.py       0.00%     9 Missing ⚠️
vllm_ascend/utils.py               14.28%    6 Missing ⚠️
vllm_ascend/models/deepseek_v2.py  0.00%     4 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1407       +/-   ##
===========================================
+ Coverage   27.39%   54.49%   +27.10%     
===========================================
  Files          56       80       +24     
  Lines        6191     9984     +3793     
===========================================
+ Hits         1696     5441     +3745     
- Misses       4495     4543       +48     
Flag        Coverage Δ
unittests   54.49% <5.00%> (+27.10%) ⬆️


This pull request has conflicts, please resolve those before we can evaluate the pull request.

@ApsarasX added the ready (read for review) label on Jul 1, 2025
@@ -121,6 +121,9 @@
# value to False to disable the optimized model.
"USE_OPTIMIZED_MODEL":
lambda: bool(int(os.getenv('USE_OPTIMIZED_MODEL', '1'))),
# Replace the two get_dp_group().all_gather communications with one and run the gate after the communication
"VLLM_ASCEND_RM_ROUTER_LOGITS":
lambda: int(os.getenv("VLLM_ASCEND_RM_ROUTER_LOGITS", 0)),
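
As a hedged usage sketch, the flag could be read the same way as the other VLLM_ASCEND_* variables in this file; the attribute-style access through vllm_ascend.envs below is an assumption, not something verified against the repository.

    import os

    # Opt in before the engine is created. The lambda above parses the value
    # as an int, so "1" enables the single-gather path and "0" (the default)
    # keeps the old behavior.
    os.environ["VLLM_ASCEND_RM_ROUTER_LOGITS"] = "1"

    import vllm_ascend.envs as envs_ascend

    rm_router_logits = bool(envs_ascend.VLLM_ASCEND_RM_ROUTER_LOGITS)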
Collaborator
Starting from Q3, we will be careful about adding more configuration options. Please remove this env var so that rm_router_logits is enabled by default.

Contributor Author

This is only valid for the FusedMoEState.AllGather path. If other models, such as deepseek_dbo, qwen3, or qwen2, compute the gate externally while rm_router_logits is applied internally, an error will be reported.

Yikun (Collaborator) commented on Jul 6, 2025

Agreed. Could we enable this automatically in the applicable cases? It's difficult for users to know which models should have this env var enabled.

Otherwise, LGTM.

Contributor Author

In theory, this solution only applies to AllGather and AllGatherEP. In the DP scenario, the previous order was gate + two communications; it is now one communication + gate, which saves some communication time. In principle, all MoE AllGather and AllGatherEP paths could follow this logic, but the DP paths of other MoE models (e.g. qwen3-235b) have not been adjusted yet, so a switch is used to guard against errors.

Collaborator

If it's not a common case, I'd prefer not to merge; we can wait.

Or, if we can add more logic checks instead of an env var, I'm fine with that as well.

Contributor Author

OK. rm_router_logits is currently enabled by default in the AllGather, AllGatherEP, and NaiveMulticast scenarios of the DeepSeek model, and is not enabled in other scenarios or models. It can be added for them later if necessary.
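
A rough sketch of the kind of automatic check described here, using hypothetical names (is_deepseek_model and fused_moe_state are illustrative; only FusedMoEState and the three mode names appear in this thread):

    # Hypothetical gating condition: enable the single-gather path only for
    # the DeepSeek MoE communication modes listed above, leave it off elsewhere.
    enable_rm_router_logits = is_deepseek_model and fused_moe_state in (
        FusedMoEState.AllGather,
        FusedMoEState.AllGatherEP,
        FusedMoEState.NaiveMulticast,
    )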

github-actions bot added the merge-conflicts label and removed the ready (read for review) label on Jul 7, 2025

github-actions bot commented Jul 7, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

…iveMulticast]

Signed-off-by: ttanzhiqiang <389825161@qq.com>
@ttanzhiqiang (Contributor Author)

Updated. @wangxiyuan @Yikun, please take another look.


github-actions bot commented Jul 9, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

…iveMulticast]

Signed-off-by: ttanzhiqiang <389825161@qq.com>

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: ttanzhiqiang <389825161@qq.com>
@wangxiyuan merged commit 9d16c99 into vllm-project:main on Jul 11, 2025
22 checks passed
4 participants