[BugFix] Fix KVConnector TP worker aggregation #21473

Merged: 1 commit, Jul 24, 2025

Changes from all commits
18 changes: 10 additions & 8 deletions vllm/v1/worker/gpu_worker.py
@@ -15,7 +15,8 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -333,19 +334,20 @@ def execute_model(
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
+            if not has_kv_transfer_group():
+                return None

             # In case of PP with kv transfer, we need to pass through the
             # finished_sending and finished_recving buffers.
-            empty_output = EMPTY_MODEL_RUNNER_OUTPUT
+            new_output = EMPTY_MODEL_RUNNER_OUTPUT
             if output.finished_sending or output.finished_recving:
-                empty_output = copy.copy(empty_output)
-                empty_output.finished_sending = output.finished_sending
-                empty_output.finished_recving = output.finished_recving
-            output = empty_output
+                new_output = copy.copy(new_output)
+                new_output.finished_sending = output.finished_sending
+                new_output.finished_recving = output.finished_recving
+            output = new_output

Inline review thread on the new_output = copy.copy(new_output) line:

Collaborator: this needs to be a deepcopy right?

Collaborator: since we updated the value of EMPTY_MODEL_RUNNER_OUTPUT?

Member (Author): I don't completely follow what you mean, but I'm pretty sure it doesn't need to be a deep copy. We are just copying because we don't want to modify the shared EMPTY_MODEL_RUNNER_OUTPUT itself.
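To illustrate the shallow-copy point above, here is a minimal, self-contained sketch. The SimpleNamespace stand-in is hypothetical (it is not vLLM's ModelRunnerOutput); the point is that the copied object's attributes are reassigned rather than mutated in place, so the shared EMPTY_MODEL_RUNNER_OUTPUT sentinel stays untouched. A deepcopy would only matter if nested objects inside the sentinel were modified in place.

    import copy
    from types import SimpleNamespace

    # Hypothetical stand-in for the shared EMPTY_MODEL_RUNNER_OUTPUT sentinel.
    EMPTY = SimpleNamespace(finished_sending=None, finished_recving=None)

    # Shallow copy, then reassign attributes on the copy only.
    new_output = copy.copy(EMPTY)
    new_output.finished_sending = {"req-1"}
    new_output.finished_recving = {"req-2"}

    # The shared sentinel itself is not modified by the reassignment.
    assert EMPTY.finished_sending is None
    assert EMPTY.finished_recving is None
    assert new_output.finished_sending == {"req-1"}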

         assert isinstance(output, ModelRunnerOutput)
-        # return output only from the driver worker
-        return output if self.is_driver_worker else None
+        return output
Review comment on lines +347 to +350 (Contributor @sdavidbd, Jul 24, 2025):

This change modifies the original behavior: if no KV connector is present, non-driver workers in the last PP rank now return output. I'm not certain this has any practical impact, though; under MultiprocExecutor, only the worker with output_rank sends its output back via WorkerProc.

Suggested fix:

            return new_output

        assert isinstance(output, ModelRunnerOutput)
        return_output = self.is_driver_worker or has_kv_transfer_group()
        return output if return_output else None
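Since every last-PP-rank worker now returns a ModelRunnerOutput whenever a KV connector is configured, the executor can combine the per-TP-rank finished_sending / finished_recving signals. The sketch below is only one way such aggregation could work; the KVTransferAggregator class, its world_size parameter, and the request-ID sets are hypothetical and are not vLLM's actual executor code. The idea is that a request counts as finished only after every TP rank has reported it.

    from collections import Counter
    from typing import Iterable, Optional

    class KVTransferAggregator:
        """Hypothetical executor-side helper: combine the per-TP-rank
        finished_sending / finished_recving reports and mark a request
        as finished only once every rank has reported it."""

        def __init__(self, world_size: int):
            self.world_size = world_size
            self._send_counts: Counter = Counter()
            self._recv_counts: Counter = Counter()

        def _update(self, counts: Counter,
                    req_ids: Optional[Iterable[str]]) -> set:
            finished = set()
            for req_id in req_ids or ():
                counts[req_id] += 1
                if counts[req_id] == self.world_size:
                    # All TP ranks have reported this request.
                    del counts[req_id]
                    finished.add(req_id)
            return finished

        def aggregate(self, per_rank_outputs):
            """per_rank_outputs: one ModelRunnerOutput-like object per TP rank."""
            finished_sending, finished_recving = set(), set()
            for out in per_rank_outputs:
                finished_sending |= self._update(self._send_counts,
                                                 out.finished_sending)
                finished_recving |= self._update(self._recv_counts,
                                                 out.finished_recving)
            return finished_sending, finished_recving

In this sketch the executor would call aggregate() once per step with the output from each TP worker, which is only possible because non-driver workers return their outputs after this change.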


     def profile(self, is_start: bool = True):
         if self.profiler is None: