[P/D] Support CPU Transfer in NixlConnector #18293

Open · wants to merge 35 commits into main

Conversation

@juncgu juncgu commented May 17, 2025

This PR adds TPU support in NixlConnector (#17751) for P/D disaggregated serving.
The high-level idea is to use a buffer in host memory as the KV transfer buffer. The transfer buffer is registered with the nixl agent (as type "DRAM"). The computed KV cache (full blocks) at the prefill instance is saved to the transfer buffer. On the decode side, the remote KV data is read into the transfer buffer and then loaded into device memory.
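
A rough per-layer sketch of this flow (illustration only; the names, shapes, and CPU stand-in device are assumptions, not the PR's actual code): full blocks are copied device-to-host into the transfer buffer on the prefill side, moved between host buffers by nixl, and copied host-to-device on the decode side.

import torch

# Hypothetical per-layer KV cache layout: [num_blocks, block_size, num_kv_heads, head_dim].
num_blocks, block_size, num_kv_heads, head_dim = 8, 16, 2, 64
device = "cpu"  # stand-in for the TPU/xPU device in this sketch

device_kv = torch.randn(num_blocks, block_size, num_kv_heads, head_dim, device=device)
# Host transfer buffer registered with the nixl agent as "DRAM"; it mirrors the
# device cache and has the same number of blocks.
host_xfer_buffer = torch.zeros_like(device_kv, device="cpu")

full_block_ids = torch.tensor([0, 1, 3])  # full computed blocks for one request

# Prefill side: save the computed full blocks from device memory into the host buffer (D2H).
host_xfer_buffer[full_block_ids] = device_kv[full_block_ids].to("cpu")

# ... nixl reads the remote host buffer into the local host buffer on the decode side ...

# Decode side: load the received blocks from the host buffer into device memory (H2D).
device_kv[full_block_ids] = host_xfer_buffer[full_block_ids].to(device)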

Currently, we support the same P/D disaggregated serving scenarios as #17751:

  • support xPyD
  • support homogeneous TP > 1
  • support P->D request flow

We will follow up on updates to NixlConnector and support the upcoming features mentioned in #17751.

How to configure NixlConnector for TPU?

We need to set kv_buffer_device to cpu in kv_transfer_config.
For example:

# launch a prefill instance
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_USE_V1=1 \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve ${MODEL_NAME} \
    --host ${PREFILL_HOST} \
    --port 8100 \
    --tensor-parallel-size 8 \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --disable-log-requests \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu"}'

A simple 1p1d disaggregation example can be found at tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh.

Notes:

  1. The current transfer buffer in host memory has the same number of blocks as the KV cache on device. This simplifies the implementation and is reasonable given that host memory capacity is often much larger than the total HBM capacity within a node (e.g., per the TPU v6e spec). See the sketch after this list.
  2. By design, this implementation should be able to support any xPU whose device memory cannot be registered with the nixl library and which therefore needs to use host memory as the transfer buffer. Currently, only TPU is tested.
  3. The extra time overhead from the nixl-agent handshake and XLA compilation may hit the execute_model time limit in the multiproc_executor. Therefore, we suggest relaxing the timeout by setting the VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S env var.
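
As a rough sketch of note 1 (illustration only; the helper name and shapes are assumptions, not the PR's actual code), the host transfer buffer simply mirrors the device KV cache, with one CPU tensor per layer and the same number of blocks:

import torch

def allocate_host_xfer_buffers(
    device_kv_caches: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    # One host tensor per layer, same shape (and number of blocks) as the device cache.
    return {
        layer_name: torch.zeros(cache.shape, dtype=cache.dtype, device="cpu")
        for layer_name, cache in device_kv_caches.items()
    }

# Example with made-up shapes: 4 blocks of 16 tokens, 2 KV heads, head_dim 64.
fake_caches = {"layer.0": torch.zeros(4, 16, 2, 64)}
host_buffers = allocate_host_xfer_buffers(fake_caches)
assert host_buffers["layer.0"].shape == fake_caches["layer.0"].shape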

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 and tpu (Related to Google TPUs) labels May 17, 2025

mergify bot commented May 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork


juncgu commented May 17, 2025

mergify bot commented May 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@robertgshaw2-redhat robertgshaw2-redhat changed the title from "[P/D] Support TPU in NixlConnector" to "[P/D] Support CPU Transfer in NixlConnector" May 17, 2025
@robertgshaw2-redhat robertgshaw2-redhat self-assigned this May 17, 2025
# cpu kv buffer for xfer
# used when xPU memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = True if _NIXL_SUPPORTED_XPU_TYPE.support(self.kv_buffer_device) else False

@robertgshaw2-redhat robertgshaw2-redhat May 17, 2025


Is there any reason why we cannot allow the GPU to use the host buffer?

Also, it is confusing that the TPU uses kv_buffer_device: "tpu" when it is actually using the host buffer.

I think we should allow the user to specify kv_buffer_device, then have:

_NIXL_SUPPORTED_XPU_TYPE = {
    "cuda": ["cuda", "cpu"],
    "tpu": ["cpu"],
}

if self.kv_buffer_device not in _NIXL_SUPPORTED_XPU_TYPE[current_platform.platform()]:
    raise ValueError(f"{self.kv_buffer_device} is not supported on this platform")

if self.kv_buffer_device == "cuda":
    self.nixl_memory_type = "VRAM"
else:
    assert self.kv_buffer_device == "cpu"
    self.nixl_memory_type = "DRAM"

juncgu (Author)

cuda: cpu is not supported yet.


@yaochengji yaochengji left a comment


Thanks for supporting such an important feature; I left a few comments.

Also, could you add some tests? IMO, the minimum requirements are a unit test for nixl connector on tpu and an e2e accuracy test.

if params.get("do_remote_decode"):
    # NOTE: only need to save / send full computed blocks
    block_ids = blocks.get_block_ids()[0]
    all_full = request.num_tokens % self.block_size == 0
Collaborator

Are there multiple requests in the block_ids? If so, request.num_tokens % self.block_size == 0 cannot guarantee all_full.

Collaborator

Oh, actually even for one request: e.g., with num_tokens = 2 * block_size, it can still use 3 blocks when the blocks are not aligned.

juncgu (Author)

The block_ids is for a single request.

In my understanding, when e.g. prompt_len == 2 * block_size, 2 blocks will be allocated to the request at prefill stage, and the third block will be allocated when its first decode step gets scheduled.

Can you give more information about the case when it's not aligned?

Collaborator

Let's say the page_size is 4 and the prompt_len is also 4. There are two pages:
0, 1, 2, 3 | 4, 5, 6, 7
Is it possible for the prompt to use the indices (2, 3, 4, 5)?

@@ -253,14 +307,27 @@ def build_connector_meta(
        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            _kv_transfer_params = copy.deepcopy(req.kv_transfer_params)
            _kv_transfer_params["do_remote_prefill"] = True
Collaborator

May I ask why we need to make a deepcopy and set do_remote_prefill to True here, when previously we didn't need to?

juncgu (Author)

We rely on the do_remote_prefill/decode attributes in ReqMeta to determine the direction of the data copy (i.e., D2H or H2D).
The two lines here are just for aligning with add_new_req; otherwise, we would need to re-assign the attribute after calling add_new_req.

# e.g.,
meta.add_new_req(
    request_id=req_id,
    local_block_ids=block_ids,
    kv_transfer_params=req.kv_transfer_params,
)
meta.requests[req_id].do_remote_prefill = True

Member

Let's avoid the deepcopies; they add unnecessary overhead. No harm in adding new args (with default values) to add_new_req.
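
For illustration, a minimal sketch of that direction (the extra keyword argument is an assumption for this sketch, not the actual add_new_req signature):

meta.add_new_req(
    request_id=req_id,
    local_block_ids=block_ids,
    kv_transfer_params=req.kv_transfer_params,
    do_remote_prefill=True,  # assumed new optional arg (default False), replacing the deepcopy
)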

mergify bot commented May 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot commented May 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 21, 2025
tpu_cache: torch.Tensor,
tpu_block_indices: torch.Tensor,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True)
Collaborator

I wonder why we do torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) here

meta = self._recving_metadata[req_id]
# local decode only
if not meta.do_remote_prefill:
    return
Collaborator

Should this be continue or return here?

juncgu (Author)

Since only one request is handled by this function, "continue" and "return" would behave the same.

@mergify mergify bot removed the needs-rebase label Jul 8, 2025

juncgu commented Jul 8, 2025

Thanks @juncgu for all the work and patience!

Could you merge in the latest main branch? This should hopefully also address the CI failure.

Thanks, @njhill.
I merged with the main, please take a look and advise.


@njhill njhill left a comment


Thanks @juncgu. It looks good now, but I saw that there's still some unnecessary redundancy which should hopefully be very quick to simplify. Then we can get this merged.

Comment on lines 981 to 982
if not meta.load_remote_cache:
    return
Member

This check should not be needed here; it should never reach this point with this as False (and per the other comment we can eliminate the load_remote_cache field).

Signed-off-by: Juncheng Gu <juncgu@gmail.com>

juncgu commented Jul 9, 2025

Thanks @juncgu. It looks good now, but I saw that there's still some unnecessary redundancy which should hopefully be very quick to simplify. Then we can get this merged.

Thanks for your continuous help, @njhill. Those redundancies have been removed.


@njhill njhill left a comment


Thanks @juncgu for all of your hard work and patience!

@njhill njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) Jul 9, 2025
juncgu added 2 commits July 9, 2025 19:30
Signed-off-by: Juncheng Gu <juncgu@gmail.com>
Signed-off-by: Juncheng Gu <juncgu@gmail.com>

mergify bot commented Jul 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 10, 2025

njhill commented Jul 10, 2025

@juncgu apologies, we merged another PR and this will need some minor modification; I can help with that. Just want to get #20756 merged first since it's a small follow-on fix to the other one.


juncgu commented Jul 10, 2025

@juncgu apologies, we merged another PR and this will need some minor modification; I can help with that. Just want to get #20756 merged first since it's a small follow-on fix to the other one.

Thanks for the notice; please let me know once the other PRs are merged.


njhill commented Jul 10, 2025

@juncgu ok that other PR has been merged now. The changes eliminate a couple of the "mixin" methods but some logic has to be added into the tpu_worker.py execute_model method. See #19555 and #20756.

Signed-off-by: Juncheng Gu <juncgu@gmail.com>
@mergify mergify bot removed the needs-rebase label Jul 11, 2025
Signed-off-by: Juncheng Gu <juncgu@gmail.com>

juncgu commented Jul 11, 2025

@juncgu ok that other PR has been merged now. The changes eliminate a couple of the "mixin" methods but some logic has to be added into the tpu_worker.py execute_model method. See #19555 and #20756.

Hi @njhill, please review the latest version.


njhill commented Jul 11, 2025

Thanks @juncgu, looks good to me. Could you merge in the latest main again? Hopefully that will fix the CI failures, which look unrelated.

Signed-off-by: Juncheng Gu <juncgu@gmail.com>

juncgu commented Jul 11, 2025

Thanks @juncgu, looks good to me. Could you merge in the latest main again? Hopefully that will fix the CI failures, which look unrelated.

Thanks, @njhill. It's clear now after merging the latest main.

mergify bot commented Jul 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 14, 2025
Signed-off-by: Juncheng Gu <juncgu@gmail.com>
@mergify mergify bot removed the needs-rebase label Jul 14, 2025

mergify bot commented Jul 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @juncgu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 16, 2025
Signed-off-by: Juncheng Gu <juncgu@gmail.com>
@mergify mergify bot removed the needs-rebase label Jul 16, 2025
Labels
ci/build · ready (ONLY add when PR is ready to merge/full CI is needed) · tpu (Related to Google TPUs) · v1

6 participants