Skip to content

[Feature] Optimize forward metadata collection across dp ranks #1593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

jianzs
Copy link
Collaborator

@jianzs jianzs commented Jul 2, 2025

What this PR does / why we need it?

This PR introduces two optimizations for cases where data parallel size > 1:

  1. Eliminates DP communication in set_forward_context
  2. Implements HCCL for DP metadata communication, resulting in significant performance improvements for large DP configurations
    • Achieves ~20ms latency reduction with DP size of 64

Does this PR introduce any user-facing change?

no

How was this patch tested?

CI passed.

@jianzs
Copy link
Collaborator Author

jianzs commented Jul 2, 2025

@NeverRaR PTAL

@NeverRaR
Copy link
Contributor

NeverRaR commented Jul 2, 2025

lgtm

@jianzs jianzs added ready read for review and removed ready read for review labels Jul 2, 2025
@jianzs jianzs requested a review from Copilot July 2, 2025 14:51
Copilot

This comment was marked as outdated.

@jianzs jianzs requested a review from Copilot July 2, 2025 14:57
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR optimizes how forward-pass metadata is collected and communicated across data-parallel ranks by removing the previous all-reduce and introducing an HCCL-based all-gather approach.

  • Enforce that dummy batch execution only runs under data parallelism and refactor execute_dummy_batch to use per-rank metadata.
  • Replace dist.all_reduce with HCCL all_gather in _get_forward_metadata_across_dp and update callers to handle the Tensor of per-rank token counts.
  • Propagate num_tokens_across_dp through dummy runs and forward contexts, masking sentinel values before the pass.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
vllm_ascend/worker/worker_v1.py Added assertion for dp_size > 1, refactored dummy-run logic to use HCCL per-rank metadata.
vllm_ascend/worker/model_runner_v1.py Swapped all_reduce for all_gather under get_dp_group(), changed method signature and updated callers to handle a Tensor of metadata.
Comments suppressed due to low confidence (1)

vllm_ascend/worker/model_runner_v1.py:622

  • Add unit or integration tests covering the dp_size > 1 aggregation path to verify that all_gather produces the correct combined metadata and that the masked_fill_ logic correctly replaces sentinel values.
            local_forward_metadata)

@jianzs jianzs requested a review from NeverRaR July 3, 2025 02:37
@jianzs jianzs force-pushed the feat/dp-comm-opt branch from 5d90031 to f1ddce2 Compare July 3, 2025 11:54
Copy link

codecov bot commented Jul 3, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 52.36%. Comparing base (c30ddb8) to head (6100e0d).
Report is 75 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1593       +/-   ##
===========================================
+ Coverage   27.39%   52.36%   +24.96%     
===========================================
  Files          56       78       +22     
  Lines        6191     9631     +3440     
===========================================
+ Hits         1696     5043     +3347     
- Misses       4495     4588       +93     
Flag Coverage Δ
unittests 52.36% <ø> (+24.96%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jianzs jianzs added performance-test enable performance test for PR ready-for-test start test by label for PR labels Jul 3, 2025
@jianzs
Copy link
Collaborator Author

jianzs commented Jul 4, 2025

@Yikun @wangxiyuan @ApsarasX @ganyi1996ppo ready to merge.

Comment on lines +625 to +630
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I tried your PR and run offline inference with DP > 1, and I found that line 630 would throw a "dim out of range" error. It seems that you are using all_gather to get forward metadata across all dp ranks, but local_forward_metadata is actually an 1D tensor. Therefore, global_forward_metadata is also an 1D tensor. Perhaps codes should be changed like this:

Suggested change
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata)
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32).unsqueeze(0)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata, dim=0)

@jianzs jianzs added the ready read for review label Jul 4, 2025
max_num_tokens)
runner._dummy_run(max_num_tokens,
else:
num_tokens = 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If graph mode is off, a dummy run only needs to be executed; computational requirements are not a factor.

jianzs added 7 commits July 4, 2025 23:50
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
jianzs added 3 commits July 4, 2025 23:50
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
…rker

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@jianzs jianzs force-pushed the feat/dp-comm-opt branch from ed8b7b5 to 6100e0d Compare July 4, 2025 15:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance-test enable performance test for PR ready read for review ready-for-test start test by label for PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants