[Feature] Optimize forward metadata collection across dp ranks #1593
base: main
Conversation
@NeverRaR PTAL
lgtm
Pull Request Overview
This PR optimizes how forward-pass metadata is collected and communicated across data-parallel ranks by removing the previous all-reduce and introducing an HCCL-based all-gather approach.
- Enforce that dummy batch execution only runs under data parallelism and refactor execute_dummy_batch to use per-rank metadata.
- Replace dist.all_reduce with HCCL all_gather in _get_forward_metadata_across_dp and update callers to handle the Tensor of per-rank token counts.
- Propagate num_tokens_across_dp through dummy runs and forward contexts, masking sentinel values before the pass (a hedged sketch of this flow follows the list).
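The overview above maps to a short flow: build a per-rank [num_tokens, with_prefill] row, all-gather it across the DP group, mask padding sentinels, and read off the values each rank needs. The sketch below mimics that flow on a single process; the dp_size value, CPU device, and -1 sentinel are assumptions, and gather_forward_metadata is a hypothetical helper rather than the PR's actual function.

```python
import torch

# Assumed DP world size for this single-process sketch.
dp_size = 4

def gather_forward_metadata(num_tokens: int, with_prefill: bool) -> torch.Tensor:
    """Build the local [num_tokens, with_prefill] row and emulate an all_gather
    across dp ranks, returning a (dp_size, 2) int32 tensor."""
    local = torch.tensor([num_tokens, int(with_prefill)], dtype=torch.int32)
    # Real code gathers over HCCL via get_dp_group().all_gather; here every rank
    # is simulated with the same row so the sketch runs without a process group.
    return torch.stack([local] * dp_size, dim=0)

gathered = gather_forward_metadata(num_tokens=128, with_prefill=True)

# Mask sentinel entries (assumed to be -1) before using the metadata, then derive
# the per-rank values used for the forward pass.
gathered.masked_fill_(gathered == -1, 0)
num_tokens_across_dp = gathered[:, 0]               # per-rank token counts
max_num_tokens = int(num_tokens_across_dp.max())    # padding target for the batch
global_with_prefill = bool(gathered[:, 1].any())    # any rank still prefilling

print(num_tokens_across_dp.tolist(), max_num_tokens, global_with_prefill)
```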
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| vllm_ascend/worker/worker_v1.py | Added assertion for dp_size > 1, refactored dummy-run logic to use HCCL per-rank metadata. |
| vllm_ascend/worker/model_runner_v1.py | Swapped all_reduce for all_gather under get_dp_group(), changed method signature and updated callers to handle a Tensor of metadata. |
Comments suppressed due to low confidence (1)
vllm_ascend/worker/model_runner_v1.py:622

- Add unit or integration tests covering the dp_size > 1 aggregation path to verify that all_gather produces the correct combined metadata and that the masked_fill_ logic correctly replaces sentinel values (a hedged test sketch follows).

local_forward_metadata)
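Such a test could stay NPU-free by exercising only the post-gather masking step. In the sketch below, the (dp_size, 2) layout and the -1 sentinel are assumptions, and mask_sentinels is a hypothetical stand-in for the masked_fill_ call referenced above, not code from the PR.

```python
import torch

def mask_sentinels(gathered: torch.Tensor, sentinel: int = -1) -> torch.Tensor:
    # Hypothetical helper mirroring the masked_fill_ step on gathered metadata.
    return gathered.masked_fill_(gathered == sentinel, 0)

def test_sentinel_masking():
    # Rank 0 reports 128 tokens with prefill; rank 1 has nothing and sends padding.
    gathered = torch.tensor([[128, 1], [-1, -1]], dtype=torch.int32)
    mask_sentinels(gathered)
    assert gathered[1].tolist() == [0, 0]      # sentinel row zeroed out
    assert int(gathered[:, 0].max()) == 128    # padding target unaffected

test_sentinel_masking()
```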
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1593      +/-   ##
===========================================
+ Coverage   27.39%   52.36%   +24.96%
===========================================
  Files          56       78       +22
  Lines        6191     9631     +3440
===========================================
+ Hits         1696     5043     +3347
- Misses       4495     4588       +93
@Yikun @wangxiyuan @ApsarasX @ganyi1996ppo ready to merge.
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
                                      device="npu",
                                      dtype=torch.int32)
global_forward_metadata = get_dp_group().all_gather(
    local_forward_metadata)
Hello, I tried your PR and ran offline inference with DP > 1, and I found that line 630 throws a "dim out of range" error. It seems that you are using all_gather to get forward metadata across all dp ranks, but local_forward_metadata is actually a 1D tensor, so global_forward_metadata is also a 1D tensor. Perhaps the code should be changed like this:
- local_forward_metadata = torch.tensor([num_tokens, with_prefill],
-                                        device="npu",
-                                        dtype=torch.int32)
- global_forward_metadata = get_dp_group().all_gather(
-     local_forward_metadata)
+ local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                        device="npu",
+                                        dtype=torch.int32).unsqueeze(0)
+ global_forward_metadata = get_dp_group().all_gather(
+     local_forward_metadata, dim=0)
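For intuition, the shape difference can be reproduced locally without an NPU or a process group; the snippet below just mimics the concatenation semantics of an all-gather for an assumed dp_size of 2.

```python
import torch

dp_size = 2
local = torch.tensor([128, 1], dtype=torch.int32)   # [num_tokens, with_prefill]

# Gathering the 1D tensor concatenates along its only dimension: shape (4,),
# so per-rank rows are lost and indexing a second dim raises "dim out of range".
flat = torch.cat([local] * dp_size)

# Gathering after unsqueeze(0) along dim=0 keeps one row per rank: shape (2, 2).
stacked = torch.cat([local.unsqueeze(0)] * dp_size, dim=0)

print(flat.shape, stacked.shape)   # torch.Size([4]) torch.Size([2, 2])
print(stacked[:, 0])               # per-rank token counts, ready for max()/masking
```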
max_num_tokens)
runner._dummy_run(max_num_tokens,
else:
    num_tokens = 1
why is it 1?
If graph mode is off, the dummy run only needs to be executed at all; its computational size is not a factor, so a single token is enough.
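A hedged sketch of that choice (choose_dummy_batch_size and use_graph_mode are hypothetical names for illustration; only runner._dummy_run and max_num_tokens appear in the diff above):

```python
def choose_dummy_batch_size(use_graph_mode: bool, max_num_tokens: int) -> int:
    """Hypothetical helper capturing the branch discussed above."""
    if use_graph_mode:
        # Graph replay pads the dummy batch to the largest captured size.
        return max_num_tokens
    # Eager mode only needs the dummy forward to run once; size is irrelevant,
    # so one token keeps the dummy batch as cheap as possible.
    return 1

assert choose_dummy_batch_size(True, 4096) == 4096
assert choose_dummy_batch_size(False, 4096) == 1
```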
What this PR does / why we need it?
This PR introduces two optimizations for cases where data parallel size > 1: forward metadata is collected across DP ranks with a single HCCL all_gather instead of an all-reduce, and the gathered num_tokens_across_dp is propagated through set_forward_context.
Does this PR introduce any user-facing change?
no
How was this patch tested?
CI passed.