
[BugFix] Fix KVConnector TP worker aggregation #21473


Merged: 1 commit, Jul 24, 2025

Conversation

@njhill (Member) commented Jul 23, 2025

The Ray PD compatibility fix #21072 broke non-Ray TP PD (prefill/decode disaggregation with tensor parallelism).

Discovered by @robertgshaw2-redhat

Signed-off-by: Nick Hill <nhill@redhat.com>
@njhill njhill added the bug Something isn't working label Jul 23, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 23, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes an issue with KVConnector TP worker aggregation for non-Ray prefill/decode (PD) disaggregation setups. The changes primarily adjust the return logic in gpu_worker.py for non-last pipeline-parallel ranks so that KV cache transfer status is handled correctly. A potential issue was identified where the code could modify a shared mutable global constant; a code suggestion was provided to address this and make the implementation more robust.

    empty_output.finished_sending = output.finished_sending
    empty_output.finished_recving = output.finished_recving
    output = empty_output

Suggested change:

    new_output = copy.copy(new_output)
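The shared-constant hazard flagged here can be sketched in isolation. The sketch below is illustrative only: `FakeOutput`, `EMPTY_OUTPUT`, `broken`, and `fixed` are hypothetical stand-ins, not vLLM's actual `ModelRunnerOutput` or `EMPTY_MODEL_RUNNER_OUTPUT`.

```python
import copy
from dataclasses import dataclass, field

@dataclass
class FakeOutput:
    # Stand-in for vLLM's ModelRunnerOutput; attribute names mirror the diff.
    finished_sending: set = field(default_factory=set)
    finished_recving: set = field(default_factory=set)

EMPTY_OUTPUT = FakeOutput()  # shared module-level "constant"

def broken(finished: set) -> FakeOutput:
    out = EMPTY_OUTPUT               # aliases the shared object
    out.finished_sending = finished  # silently corrupts it for every caller
    return out

def fixed(finished: set) -> FakeOutput:
    out = copy.copy(EMPTY_OUTPUT)    # shallow copy: fresh object, safe to set attrs
    out.finished_sending = finished  # rebinds only on the copy
    return out
```

After `broken` runs, the shared constant carries stale request IDs into unrelated code paths; `fixed` leaves it pristine, which is the point of copying before assignment.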
A Collaborator commented:

this needs to be a deepcopy right?

The Collaborator added:

since we updated the value of EMPTY_MODEL_RUNNER_OUTPUT?

@njhill (Member, Author) replied:

I don't completely follow what you mean but I'm pretty sure it doesn't need to be a deep copy. We are just copying because we don't want to modify the shared EMPTY_MODEL_RUNNER_OUTPUT itself.
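The shallow-vs-deep distinction here can be demonstrated with a minimal sketch (`Output` and `EMPTY` are hypothetical stand-ins for `ModelRunnerOutput` and `EMPTY_MODEL_RUNNER_OUTPUT`): assigning an attribute on a shallow copy rebinds it on the copy only, so the shared original is never touched.

```python
import copy

class Output:
    # Minimal stand-in for ModelRunnerOutput (assumed attribute names).
    def __init__(self):
        self.finished_sending = None
        self.finished_recving = None

EMPTY = Output()
clone = copy.copy(EMPTY)            # shallow copy is sufficient here
clone.finished_sending = {"req-1"}  # rebinds the attribute on the clone only

# A deepcopy would only matter if code mutated a *nested* object in place,
# e.g. appending to a list that both EMPTY and clone still share.
```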

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) July 23, 2025 18:06
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 23, 2025
@simon-mo simon-mo added this to the v0.10.0 milestone Jul 24, 2025
@simon-mo simon-mo disabled auto-merge July 24, 2025 03:56
@simon-mo simon-mo merged commit eec6942 into vllm-project:main Jul 24, 2025
75 of 76 checks passed
Comment on lines +347 to +350
         output = new_output

         assert isinstance(output, ModelRunnerOutput)
    -    # return output only from the driver worker
    -    return output if self.is_driver_worker else None
    +    return output
@sdavidbd (Contributor) commented Jul 24, 2025:

This change modifies the original behavior: if no KV connector is present, non-driver workers in the last PP rank now return output. I'm not certain this has any practical impact, though—under MultiprocExecutor, only the worker with output_rank sends its output back via WorkerProc.
Suggested fix:

            return new_output

        assert isinstance(output, ModelRunnerOutput)
        return_output = self.is_driver_worker or has_kv_transfer_group()
        return output if return_output else None 
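The gating in the suggested fix can be expressed as a small pure function. This is a hypothetical sketch: the real code reads `self.is_driver_worker` and calls `has_kv_transfer_group()` from vLLM's worker and distributed utilities rather than taking booleans as parameters.

```python
def should_return_output(is_driver_worker: bool, has_kv_transfer_group: bool) -> bool:
    # Driver workers always return their output; non-driver workers return it
    # only when a KV connector needs their transfer status aggregated.
    return is_driver_worker or has_kv_transfer_group
```

This preserves the pre-#21072 behavior (only the driver returns) when no KV connector is configured, while still letting every TP worker report its transfer status when one is.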
