From b3e71466a904f035969a3c7c06de6765b3ef05f1 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 12 Jun 2025 16:13:01 -0700 Subject: [PATCH 01/19] tpu: Support CPU Transfer in NixlConnector Signed-off-by: Juncheng Gu --- requirements/tpu.txt | 1 + .../run_tpu_disagg_accuracy_test.sh | 162 ++++++++++ .../run_tpu_edge_case_test.sh | 128 ++++++++ .../nixl_integration/test_disagg_accuracy.py | 161 +++++++++ .../nixl_integration/test_edge_cases.py | 9 +- .../nixl_integration/toy_proxy_server.py | 6 +- .../kv_transfer/kv_connector/v1/base.py | 10 +- .../kv_connector/v1/nixl_connector.py | 306 +++++++++++++++--- vllm/envs.py | 5 + vllm/v1/executor/multiproc_executor.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 58 +--- .../worker/kv_connector_model_runner_mixin.py | 69 ++++ vllm/v1/worker/tpu_model_runner.py | 116 ++++++- vllm/v1/worker/tpu_worker.py | 10 +- 14 files changed, 933 insertions(+), 110 deletions(-) create mode 100644 tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh create mode 100644 tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh create mode 100644 tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py create mode 100644 vllm/v1/worker/kv_connector_model_runner_mixin.py diff --git a/requirements/tpu.txt b/requirements/tpu.txt index a26dfd460d8e..86db6aa1b1cf 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -10,6 +10,7 @@ jinja2>=3.1.6 ray[default] ray[data] setuptools==78.1.0 +nixl==0.3.0 # Install torch_xla --pre diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh new file mode 100644 index 000000000000..45779d16914f --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh @@ -0,0 +1,162 @@ +#!/bin/bash +set -xe + +# Hosts / ports +PREFILL_HOST=${PREFILL_HOST:-"localhost"} +PREFILL_PORT=${PREFILL_PORT:-8100} +PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} +DECODE_HOST=${DECODE_HOST:-"localhost"} +DECODE_PORT=${DECODE_PORT:-8200} +PROXY_HOST=${PROXY_HOST:-"localhost"} +PROXY_PORT=${PROXY_PORT:-8192} +BASELINE_HOST=${BASELINE_HOST:-"localhost"} +BASELINE_PORT=${BASELINE_PORT:-9290} + + +# Model to run. +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} +BLOCK_SIZE=${BLOCK_SIZE:-32} + + +# execution env +GIT_ROOT=$(git rev-parse --show-toplevel) +EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" +CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"} +CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"} + +OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + + +# Waits for vLLM server to start. +wait_for_server() { + local host=$1 + local port=$2 + timeout 1200 bash -c " + until curl -s ${host}:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 || true + # pkill -f python || true + echo "Cleanup complete. Exiting." 
+} + +launch_baseline() { + BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${BASELINE_HOST} \ + --port ${BASELINE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --enforce-eager" + echo ${BASELINE_BASE_CMD} + ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" & +} + +launch_pd() { + PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ + VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${PREFILL_HOST} \ + --port ${PREFILL_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + + DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${DECODE_HOST} \ + --port ${DECODE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + echo ${PREFILL_BASE_CMD} + echo ${DECODE_BASE_CMD} + sleep 2 + + # execute on hosts + ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" & + ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" & + sleep 1 + wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} + sleep 1 + wait_for_server ${DECODE_HOST} ${DECODE_PORT} + sleep 1 +} + +launch_pd_proxy(){ + PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + python3 ${EXP_ROOT}/toy_proxy_server.py \ + --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ + --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ + --host=${PROXY_HOST} --port ${PROXY_PORT}" + echo ${PROXY_BASE_CMD} + ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" & +} + +run_tests(){ + local service_url=$1 + local mode=$2 + python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE} +} + + +# run non-disagg. baseline & save outputs +launch_baseline +sleep 2 +wait_for_server ${BASELINE_HOST} ${BASELINE_PORT} +run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline" +cleanup +sleep 10 + + +# run disagg. 
& do exact-match with the outputs from baseline +launch_pd +launch_pd_proxy +sleep 10 +run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg" +echo "-----P/D success----" + +rm ${OUTPUT_FILE} +cleanup + +exit 0 \ No newline at end of file diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh new file mode 100644 index 000000000000..c37c92fdf5d3 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh @@ -0,0 +1,128 @@ +#!/bin/bash +set -xe + +# Hosts / ports +PREFILL_HOST=${PREFILL_HOST:-"localhost"} +PREFILL_PORT=${PREFILL_PORT:-8100} +PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} +DECODE_HOST=${DECODE_HOST:-"localhost"} +DECODE_PORT=${DECODE_PORT:-8200} +PROXY_HOST=${PROXY_HOST:-"localhost"} +PROXY_PORT=${PROXY_PORT:-8192} +BASELINE_HOST=${BASELINE_HOST:-"localhost"} +BASELINE_PORT=${BASELINE_PORT:-9290} + + +# Model to run. +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} +BLOCK_SIZE=${BLOCK_SIZE:-32} + + +# execution env +GIT_ROOT=$(git rev-parse --show-toplevel) +EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" +CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"} +CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"} + +OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM server to start. +wait_for_server() { + local host=$1 + local port=$2 + timeout 1200 bash -c " + until curl -s ${host}:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 || true + # pkill -f python || true + echo "Cleanup complete. Exiting." 
+} + + +launch_pd() { + PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ + VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${PREFILL_HOST} \ + --port ${PREFILL_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + + DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${DECODE_HOST} \ + --port ${DECODE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + echo ${PREFILL_BASE_CMD} + echo ${DECODE_BASE_CMD} + sleep 2 + + # execute on hosts + ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" & + ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" & + sleep 1 + wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} + sleep 1 + wait_for_server ${DECODE_HOST} ${DECODE_PORT} + sleep 1 +} + +launch_pd_proxy(){ + PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + python3 ${EXP_ROOT}/toy_proxy_server.py \ + --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ + --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ + --host=${PROXY_HOST} --port ${PROXY_PORT}" + echo ${PROXY_BASE_CMD} + ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" & +} + + +# run disagg. & do exact-match with the outputs from baseline +launch_pd +launch_pd_proxy +sleep 10 + +PREFILL_HOST=${PREFILL_HOST} \ +PREFILL_PORT=${PREFILL_PORT} \ +DECODE_HOST=${DECODE_HOST} \ +DECODE_PORT=${DECODE_PORT} \ +PROXY_HOST=${PROXY_HOST} \ +PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py \ No newline at end of file diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py new file mode 100644 index 000000000000..fe30d0fbaaec --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import json +import os +import time + +import openai +import requests + +MAX_OUTPUT_LEN = 30 + +SAMPLE_PROMPTS = ( + "Red Hat is the best company in the world to work for because it works on " + "open source software, which means that all the contributions are " + "delivered to the community. 
As a result, when working on projects like " + "vLLM we are able to meet many amazing people from various organizations " + "like AMD, Google, NVIDIA, ", + "We hold these truths to be self-evident, that all men are created equal, " + "that they are endowed by their Creator with certain unalienable Rights, " + "that among these are Life, Liberty and the pursuit of Happiness.--That " + "to secure these rights, Governments are instituted among Men, deriving " + "their just powers from the consent of the governed, ", +) + + +def check_vllm_server(url: str, timeout=5, retries=3) -> bool: + """ + Checks if the vLLM server is ready by sending a GET request to the + /health endpoint. + + Args: + url (str): The base URL of the vLLM server. + timeout (int): Timeout in seconds for the request. + retries (int): Number of retries if the server is not ready. + + Returns: + bool: True if the server is ready, False otherwise. + """ + for attempt in range(retries): + try: + response = requests.get(url, timeout=timeout) + if response.status_code == 200: + return True + else: + print(f"Attempt {attempt + 1}: Server returned status code " + "{response.status_code}") + except requests.exceptions.RequestException as e: + print(f"Attempt {attempt + 1}: Error connecting to server: {e}") + time.sleep(1) # Wait before retrying + return False + + +def run_simple_prompt(base_url: str, model_name: str, + input_prompt: str) -> str: + client = openai.OpenAI(api_key="EMPTY", base_url=base_url) + completion = client.completions.create(model=model_name, + prompt=input_prompt, + max_tokens=MAX_OUTPUT_LEN, + temperature=0.0, + seed=42) + + # print("-" * 50) + # print(f"Completion results for {model_name}:") + # print(completion) + # print("-" * 50) + return completion.choices[0].text + + +def main(): + """ + This script demonstrates how to accept two optional string arguments + ("service_url" and "file_name") from the command line, each with a + default value of an empty string, using the argparse module. + """ + parser = argparse.ArgumentParser(description="vLLM client script") + + parser.add_argument( + "--service_url", # Name of the first argument + type=str, + required=True, + help="The vLLM service URL.") + + parser.add_argument( + "--model_name", # Name of the first argument + type=str, + required=True, + help="model_name", + ) + + parser.add_argument( + "--mode", # Name of the second argument + type=str, + default="baseline", + help="mode: baseline==non-disagg, or disagg", + ) + + parser.add_argument( + "--file_name", # Name of the second argument + type=str, + default=".vllm_output.txt", + help="the file that saves the output tokens ", + ) + + args = parser.parse_args() + + for arg in vars(args): + print(f"{arg}: {getattr(args, arg)}") + + if args.mode == "baseline": + # non-disagg + health_check_url = f"{args.service_url}/health" + else: + # disagg proxy + health_check_url = f"{args.service_url}/healthcheck" + if not os.path.exists(args.file_name): + raise ValueError( + f"In disagg mode, the output file {args.file_name} from " + "non-disagg. 
baseline does not exist.") + + service_url = f"{args.service_url}/v1" + + if not check_vllm_server(health_check_url): + raise RuntimeError( + f"vllm server: {args.service_url} is not ready yet!") + + output_strs = dict() + for prompt in SAMPLE_PROMPTS: + output_str = run_simple_prompt(base_url=service_url, + model_name=args.model_name, + input_prompt=prompt) + print(f"Prompt: {prompt}, output: {output_str}") + output_strs[prompt] = output_str + + if args.mode == "baseline": + # baseline: save outputs + try: + with open(args.file_name, 'w') as json_file: + json.dump(output_strs, json_file, indent=4) + except OSError as e: + print(f"Error writing to file: {e}") + raise + else: + # disagg. verify outputs + baseline_outputs = None + try: + with open(args.file_name) as json_file: + baseline_outputs = json.load(json_file) + except OSError as e: + print(f"Error writing to file: {e}") + raise + assert isinstance(baseline_outputs, dict) + assert len(baseline_outputs) == len(output_strs) + for prompt, output in baseline_outputs.items(): + assert prompt in output_strs, f"{prompt} not included" + assert output == output_strs[prompt], ( + f"baseline_output: {output} != PD output: {output_strs[prompt]}" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py index 95465a25fc9d..8439e30be154 100644 --- a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -4,8 +4,11 @@ import openai +PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost") PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_HOST = os.getenv("DECODE_HOST", "localhost") DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_HOST = os.getenv("PROXY_HOST", "localhost") PROXY_PORT = os.getenv("PROXY_PORT", None) if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: @@ -21,15 +24,15 @@ def test_edge_cases(): # Set the OpenAI API key and base URL decode_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{DECODE_PORT}/v1", + base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1", ) prefill_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{PREFILL_PORT}/v1", + base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1", ) proxy_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{PROXY_PORT}/v1", + base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1", ) # Get the list of models diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index 3d720fe0cafe..a5f409e6cacc 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -3,6 +3,7 @@ import argparse import itertools +import logging import os import uuid from contextlib import asynccontextmanager @@ -11,9 +12,8 @@ from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse -from vllm.logger import init_logger - -logger = init_logger(__name__) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) @asynccontextmanager diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f80b5eba235d..dc005f33536a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -32,7 +32,7 @@ import enum from abc import ABC, 
abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional import torch @@ -124,6 +124,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return + def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, + h2d_copy_blocks: Callable): + """ + Set the xPU-specific ops for copying KV between host and device. + Needed when host buffer is used for kv transfer (e.g., in NixlConnector) + """ + return + @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index bdab4850d4c1..9be7d4dd0d94 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import copy import math import threading import time @@ -8,7 +9,7 @@ from collections import defaultdict from collections.abc import Iterator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional import msgspec import torch @@ -23,7 +24,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.logger import init_logger -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.utils import make_zmq_path, make_zmq_socket, round_down from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -48,6 +49,29 @@ NixlWrapper = None +class _NIXL_SUPPORTED_XPU: + """ + xPUs and the corresponding types of kv transfer buffer + supported by NIXLConnector + """ + # {xPU: tuple of supported kv buffer types} + # TODO: "cpu" xfer buffer for cuda + _support_dict = { + "cuda": ("cuda", ), + "tpu": ("cpu", ), + } + + @classmethod + def is_supported_xpu(cls, device_type: str) -> bool: + return device_type in cls._support_dict + + @classmethod + def is_supported_kv_buffer(cls, device_type: str, + kv_buffer_type: str) -> bool: + return (device_type in cls._support_dict + and kv_buffer_type in cls._support_dict[device_type]) + + class NixlAgentMetadata( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] @@ -69,6 +93,8 @@ class ReqMeta: remote_host: str remote_port: int remote_engine_id: str + do_remote_prefill: bool = False + do_remote_decode: bool = False class NixlConnectorMetadata(KVConnectorMetadata): @@ -88,6 +114,8 @@ def add_new_req( remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], + do_remote_prefill=kv_transfer_params["do_remote_prefill"], + do_remote_decode=kv_transfer_params["do_remote_decode"], ) @@ -98,7 +126,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): self.engine_id = vllm_config.kv_transfer_config.engine_id if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler : Optional[NixlConnectorScheduler] = \ + self.connector_scheduler: Optional[NixlConnectorScheduler] = \ NixlConnectorScheduler(vllm_config, str(self.engine_id)) self.connector_worker: Optional[NixlConnectorWorker] = None elif role == KVConnectorRole.WORKER: @@ -146,6 +174,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert 
self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) + def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, + h2d_copy_blocks: Callable): + assert self.connector_worker is not None + self.connector_worker.set_host_xfer_buffer_ops(d2h_copy_blocks, + h2d_copy_blocks) + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" @@ -168,8 +202,10 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, pass def wait_for_save(self): - """NixlConnector does not save explicitly.""" - pass + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + self.connector_worker.save_kv_to_host(self._connector_metadata) + return class NixlConnectorScheduler: @@ -190,6 +226,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, tuple[Request, list[int]]] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -237,7 +274,22 @@ def update_state_after_alloc(self, request: "Request", "num_external_tokens=%s, kv_transfer_params=%s", num_external_tokens, params) - if params is not None and params.get("do_remote_prefill"): + if not params: + return + if params.get("do_remote_decode"): + # NOTE: figure out full computed blocks to send / save + block_ids = blocks.get_block_ids()[0] + all_full = request.num_tokens % self.block_size == 0 + full_block_ids = (block_ids if all_full else block_ids[:-1]) + # TODO: skip the blocks that are already in the host xfer buffer. + # Currently, the host xfer buffer block is 1-to-1 mapped to device + # kv blocks, so host blocks won't be flushed as long as its device + # block is not overwritten; and it will be safe to skip saving them + # to host xfer buffer. + if full_block_ids: + self._reqs_need_send[request.request_id] = \ + (request, full_block_ids) + elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): @@ -249,6 +301,7 @@ def update_state_after_alloc(self, request: "Request", # Get unhashed blocks to pull from remote. self._reqs_need_recv[request.request_id] = ( request, local_block_ids) + else: logger.warning( "Got invalid KVTransferParams: %s. This " @@ -267,14 +320,27 @@ def build_connector_meta( # Loop through scheduled reqs and convert to ReqMeta. 
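A quick worked example of the "full computed blocks" rule in update_state_after_alloc above (illustrative values; the block IDs and token count are made up, not taken from this patch):

    # With block_size = 32, a prefill request holding 70 prompt tokens spans 3 blocks,
    # but the 3rd block only contains 6 tokens, so it is skipped when saving to the
    # host transfer buffer.
    block_size = 32
    num_tokens = 70
    block_ids = [5, 9, 12]                  # hypothetical allocation
    all_full = num_tokens % block_size == 0
    full_block_ids = block_ids if all_full else block_ids[:-1]
    assert full_block_ids == [5, 9]
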
for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None + _kv_transfer_params = copy.deepcopy(req.kv_transfer_params) + _kv_transfer_params["do_remote_prefill"] = True meta.add_new_req( request_id=req_id, local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, + kv_transfer_params=_kv_transfer_params, + ) + + for req_id, (req, block_ids) in self._reqs_need_send.items(): + assert req.kv_transfer_params is not None + _kv_transfer_params = copy.deepcopy(req.kv_transfer_params) + _kv_transfer_params["do_remote_decode"] = True + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=_kv_transfer_params, ) # Clear the list once workers start the transfers self._reqs_need_recv.clear() + self._reqs_need_send.clear() return meta @@ -324,6 +390,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing NIXL wrapper") logger.info("Initializing NIXL worker %s", engine_id) + self.device_type = current_platform.device_type + if not _NIXL_SUPPORTED_XPU.is_supported_xpu( + device_type=self.device_type): + logger.error("%s is not supported.", self.device_type) + raise RuntimeError(f"{self.device_type} is not supported.") + # Config. self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -347,9 +419,37 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.tp_rank = get_tensor_model_parallel_rank() self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() + self.num_blocks = 0 # KV Caches and nixl tracking data. - self.kv_caches: dict[str, torch.Tensor] = {} + self.kv_buffer_device: str = \ + vllm_config.kv_transfer_config.kv_buffer_device.strip().lower() + if not _NIXL_SUPPORTED_XPU.is_supported_kv_buffer( + device_type=self.device_type, + kv_buffer_type=self.kv_buffer_device): + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported.") + self.device_kv_caches: dict[str, torch.Tensor] = {} + self.device: torch.device = None + self.device_index: int = -1 + + # cpu kv buffer for xfer + # used when xPU memory can not be registered under nixl + self.host_xfer_buffers: dict[str, torch.Tensor] = {} + self.use_host_buffer = self.kv_buffer_device == "cpu" + if self.kv_buffer_device == "cuda": + self.nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + self.nixl_memory_type = "DRAM" + else: + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported.") + + # Note: host xfer buffer ops when use_host_buffer is True + self.d2h_copy_blocks: Optional[Callable] = None + self.h2d_copy_blocks: Optional[Callable] = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. @@ -372,6 +472,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. # [req_id -> list[handle]] + self._recving_metadata: dict[str, ReqMeta] = {} self._recving_transfers = defaultdict[str, list[Transfer]](list) # Complete transfer tracker. 
Used by the rank 0 to track finished @@ -405,6 +506,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.backend_name = backend.get_name() attn_backend = backend_name_to_enum(self.backend_name) self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 + self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1 logger.debug("Detected attention backend %s", self.backend_name) self._tp_size: dict[str, int] = {self.engine_id: self.world_size} @@ -482,53 +584,114 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: path, p_remote_rank) _ = handshake(path, p_remote_rank) + def initialize_host_xfer_buffer( + self, kv_caches: dict[str, torch.Tensor]) -> None: + """Initialize transfer buffer in CPU mem for xPUs (e.g., tpu)""" + xfer_buffers: dict[str, torch.Tensor] = {} + try: + for layer_name, kv_cache in kv_caches.items(): + kv_shape = kv_cache.shape + kv_dtype = kv_cache.dtype + xfer_buffers[layer_name] = torch.zeros(kv_shape, + dtype=kv_dtype, + device="cpu") + except MemoryError as e: + logger.error("NIXLConnectorWorker gets %s.", e) + raise + + self.host_xfer_buffers = xfer_buffers + + def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, + h2d_copy_blocks: Callable): + """Assign copy (d2h, h2d) operations when host buffer is used.""" + assert self.use_host_buffer + self.d2h_copy_blocks = d2h_copy_blocks + self.h2d_copy_blocks = h2d_copy_blocks + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() + if self.use_host_buffer: + self.initialize_host_xfer_buffer(kv_caches=kv_caches) + assert len(self.host_xfer_buffers) == len(kv_caches), ( + f"host_buffer: {len(self.host_xfer_buffers)}, " + f"kv_caches: {len(kv_caches)}") + xfer_buffers = self.host_xfer_buffers + else: + xfer_buffers = kv_caches + assert not self.host_xfer_buffers, ( + "host_xfer_buffer should not be initialized when " + f"kv_buffer_device is {self.kv_buffer_device}") + # TODO(tms): Find a more robust way to detect and handle MLA # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected # KV memory layout is HND, as opposed to the default NHD. Note that it # will only affects the strides. For MLA instead, we make require no # such thing and resort to the standard layout. use_mla = len(first_kv_cache.shape) == 3 - assert use_mla == self.use_mla - - # TODO (NickLucche) not compatible with hybrid allocator. Enforce check - # once it goes live, as a single kv layout is expected for xfers. - if use_mla: - # MLA case. + if self.device_type == "tpu": + assert not use_mla, f"{self.kv_buffer_device} does not support MLA." + assert self._use_pallas_v1, f"attn backend: {self.backend_name}" + # tpu (v1) kv shape per layer: + # (num_blocks, block_size, num_kv_heads * 2, head_size) self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] + block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. + block_size, n_kv_heads_x_2, head_dim = block_shape + self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim + elif self.device_type == "cuda": + # TODO (NickLucche) not compatible with hybrid allocator. 
+ # Enforce check once it goes live, as a single kv layout + # is expected for xfers. + if use_mla: + # MLA case. self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size + # [2 (k and v), num_blocks, ...] + if self._use_flashinfer: + # FlashInfer swaps 2<->num_blocks dimensions. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 4 # [2, block_size, kv_heads, head_dim] + else: + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, n_kv_heads, head_dim = block_shape[-3:] + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + assert block_size == self.block_size + else: + raise RuntimeError(f"{self.device_type} is not supported yet.") + # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) - logger.info( - "Registering KV_Caches: use_mla: %s, num_blocks: %s, " - "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla, - self.num_blocks, block_shape, first_kv_cache.shape) + + logger.debug( + "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " + "use_host_buffer: %s, shape %s", use_mla, self.kv_buffer_device, + self.use_host_buffer, first_kv_cache.shape) + logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) self.dst_num_blocks[self.engine_id] = self.num_blocks - self.kv_caches = kv_caches + self.device_kv_caches = kv_caches + self.device = first_kv_cache.device + # if CPU device has no index + self.device_index = 0 if not hasattr(self.device, "index") else \ + self.device.index + assert self.device + assert self.device_index >= 0, \ + f"cache device {self.device} index is invalid" + kv_caches_base_addr = [] caches_data = [] @@ -540,19 +703,19 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). 
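For intuition, the TPU branch's sizes work out as follows for one plausible configuration (the dtype, KV-head count, and head size are assumptions for illustration, not values from this patch):

    # TPU v1 KV layout per layer: (num_blocks, block_size, num_kv_heads * 2, head_size)
    kv_elem_size = 2          # bytes per element, assuming bfloat16
    block_size = 32           # matches BLOCK_SIZE in the test scripts
    num_kv_heads = 8          # assumed model config
    head_size = 128           # assumed model config

    slot_size_bytes = kv_elem_size * (num_kv_heads * 2) * head_size  # 4096 bytes per token slot
    block_len = slot_size_bytes * block_size                         # 131072 bytes per block region
    assert block_len == kv_elem_size * block_size * (num_kv_heads * 2) * head_size
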
- for cache_or_caches in kv_caches.values(): + for cache_or_caches in xfer_buffers.values(): # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \ - else cache_or_caches + cache_list = [cache_or_caches] if use_mla \ + or self._use_pallas_v1 or self._use_flashinfer \ + else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - caches_data.append( - (base_addr, region_len, cache.device.index, "")) + caches_data.append((base_addr, region_len, self.tp_rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) - self.num_layers = len(self.kv_caches.keys()) + self.num_layers = len(xfer_buffers.keys()) # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) @@ -574,7 +737,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_window_per_layer) assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + descs = self.nixl_wrapper.get_reg_descs(caches_data, + self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) logger.debug("Done registering descs") @@ -596,7 +760,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) @@ -683,6 +848,9 @@ def add_remote_agent(self, "Local TP size must be divisible by remote TP size.") tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + assert self._use_pallas_v1 and tp_ratio == 1, \ + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." + if self.use_mla: # With MLA the only difference is in the number of blocks. remote_block_size = nixl_agent_meta.block_len // ( @@ -737,11 +905,53 @@ def add_remote_agent(self, self.tp_rank) # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) self.dst_xfer_side_handles[ engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_tp_rank], descs) + def sync_recved_kv_to_device(self, req_id: str): + """copy recved kv from host buffer to device.""" + if not self.use_host_buffer: + return + assert self.h2d_copy_blocks is not None + + if req_id in self._recving_metadata and \ + req_id not in self._recving_transfers: + meta = self._recving_metadata[req_id] + # local decode only + if not meta.do_remote_prefill: + return + local_block_ids = meta.local_block_ids + self.h2d_copy_blocks(self.host_xfer_buffers, self.device_kv_caches, + local_block_ids, local_block_ids, self.device) + logger.debug( + "synced recved kv of request[%s] to device kv buffer," + "local_block_ids: %s. 
", req_id, + ",".join(map(str, meta.local_block_ids))) + return + + def save_kv_to_host(self, metadata: NixlConnectorMetadata): + """copy kv from device to host buffer.""" + if not self.use_host_buffer: + return + assert self.d2h_copy_blocks is not None + + for req_id, meta in metadata.requests.items(): + # local prefill requests only + if not meta.do_remote_decode: + continue + # blocking + logger.debug( + "save_load_kv for request[%s] to host xfer buffer." + "local_block_ids: %s. ", req_id, + ",".join(map(str, meta.local_block_ids))) + self.d2h_copy_blocks(self.host_xfer_buffers, self.device_kv_caches, + meta.local_block_ids, meta.local_block_ids, + self.device) + return + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving. @@ -762,6 +972,13 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) + for req_id in done_recving: + assert req_id in self._recving_metadata, ( + f"{req_id} not found in recving_metadata list") + if self.use_host_buffer and self.h2d_copy_blocks is not None: + self.sync_recved_kv_to_device(req_id) + self._recving_metadata.pop(req_id) + if self.world_size == 1: return done_sending, done_recving @@ -855,6 +1072,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.requests.items(): + if not meta.do_remote_prefill: + continue logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, @@ -868,6 +1087,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): remote_host=meta.remote_host, remote_port=meta.remote_port, ) + self._recving_metadata[req_id] = copy.deepcopy(meta) def _read_blocks( self, diff --git a/vllm/envs.py b/vllm/envs.py index 921052821ee3..709627c2dec8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -128,6 +128,7 @@ VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 VLLM_SLEEP_WHEN_IDLE: bool = False VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S: int = 300 def get_default_cache_root(): @@ -879,6 +880,10 @@ def get_vllm_port() -> Optional[int]: # processes via zmq. 
"VLLM_MQ_MAX_CHUNK_BYTES_MB": lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")), + + # Timeout for calling execute_model() in multiproc_executor + "VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S": + lambda: int(os.getenv("VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S", "300")), } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 2148680d5f56..a72a39974f95 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -40,7 +40,7 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = 300 +EXECUTE_MODEL_TIMEOUT_S = envs.VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S class MultiprocExecutor(Executor): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 987a24496d75..c51aa76461e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy import gc import time import weakref @@ -24,12 +23,10 @@ get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, - set_forward_context) +from vllm.forward_context import DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader @@ -59,6 +56,8 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, @@ -75,7 +74,7 @@ logger = init_logger(__name__) -class GPUModelRunner(LoRAModelRunnerMixin): +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, @@ -1181,7 +1180,8 @@ def execute_model( # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output) + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) # Prepare the decoder inputs. attn_metadata, logits_indices, spec_decode_metadata = ( @@ -1508,52 +1508,6 @@ def execute_model( finished_recving=finished_recving, ) - def kv_connector_no_forward( - self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: - # KV send/recv even if no work to do. 
- with set_forward_context(None, self.vllm_config): - self.maybe_setup_kv_connector(scheduler_output) - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving - return output - - @staticmethod - def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) - - # Background KV cache transfers happen here. - # These transfers are designed to be async and the requests - # involved may be disjoint from the running requests. - # Do this here to save a collective_rpc. - kv_connector.start_load_kv(get_forward_context()) - - @staticmethod - def maybe_wait_for_kv_save() -> None: - if has_kv_transfer_group(): - get_kv_transfer_group().wait_for_save() - - @staticmethod - def get_finished_kv_transfers( - scheduler_output: "SchedulerOutput", - ) -> tuple[Optional[set[str]], Optional[set[str]]]: - if has_kv_transfer_group(): - return get_kv_transfer_group().get_finished( - scheduler_output.finished_req_ids) - return None, None - def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py new file mode 100644 index 000000000000..9da1356953ba --- /dev/null +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Define KV connector functionality mixin for model runners. +""" +import copy +from typing import TYPE_CHECKING, Optional + +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.logger import init_logger +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) +class KVConnectorModelRunnerMixin: + + def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput", + vllm_config: VllmConfig) -> ModelRunnerOutput: + # KV send/recv even if no work to do. + with set_forward_context(None, vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). 
+ if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() + + @staticmethod + def get_finished_kv_transfers( + scheduler_output: "SchedulerOutput", + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids) + return None, None diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d5f40e4d3103..43c8e533356a 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional, Union, cast from unittest.mock import patch import numpy as np @@ -19,6 +19,8 @@ from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA @@ -43,6 +45,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import (initialize_kv_cache_for_kv_sharing, @@ -96,7 +100,7 @@ # The dummy_run should be comprehensive, ensuring all potential input shapes and # branch predictions are included as subgraph inputs to facilitate # pre-compilation. -class TPUModelRunner(LoRAModelRunnerMixin): +class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, @@ -836,8 +840,12 @@ def execute_model( # Update cached state self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + + return self.kv_connector_no_forward(scheduler_output, + self.vllm_config) if self.is_multimodal_model: # Run the multimodal encoder if any. 
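The mixin is consumed the same way by both runners; a condensed sketch of the pattern follows (the runner class name, run_forward helper, and vllm_config attribute are placeholders here, not code from this patch):

    from vllm.v1.worker.kv_connector_model_runner_mixin import (
        KVConnectorModelRunnerMixin)

    class MyModelRunner(KVConnectorModelRunnerMixin):

        def execute_model(self, scheduler_output):
            if not scheduler_output.total_num_scheduled_tokens:
                # No compute scheduled, but KV transfers may still need to progress.
                return self.kv_connector_no_forward(scheduler_output,
                                                    self.vllm_config)
            self.maybe_setup_kv_connector(scheduler_output)  # bind metadata, start loads
            output = self.run_forward(scheduler_output)      # placeholder forward pass
            self.maybe_wait_for_kv_save()                    # e.g. device -> host copies
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))
            return output, finished_sending, finished_recving
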
@@ -858,11 +866,17 @@ def execute_model( attn_metadata, self.vllm_config, num_tokens=scheduler_output.total_num_scheduled_tokens): + self.maybe_setup_kv_connector(scheduler_output) + hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, inputs_embeds=inputs_embeds, ) + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + hidden_states = self.select_hidden_states(hidden_states, logits_indices) logits = self.compute_logits(hidden_states) @@ -958,6 +972,8 @@ def execute_model( spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + finished_sending=finished_sending, + finished_recving=finished_recving, ) # Check there are no new graphs compiled - all the graphs should be @@ -1430,6 +1446,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: for cache in self.kv_caches: xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + get_kv_transfer_group().set_host_xfer_buffer_ops( + d2h_copy_blocks, h2d_copy_blocks) + def reset_dynamo_cache(self): if self.is_multimodal_model: compiled_model = self.model.get_language_model().model @@ -1644,6 +1665,93 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _make_src_and_dst_indices( + src_block_ids: list[int], + dst_block_ids: list[int], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> tuple[torch.Tensor, torch.Tensor]: + src_indices = torch.tensor(src_block_ids, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +@torch.compile(backend="openxla") +def _insert_blocks_to_tpu( + src_cache: torch.Tensor, + tpu_cache: torch.Tensor, + tpu_block_indices: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) + tpu_cache[tpu_block_indices] = src_cache + + +@torch.compile(backend="openxla") +def _swap_out_tpu_blocks( + tpu_cache: torch.Tensor, + cpu_cache: torch.Tensor, + tpu_block_indices: torch.Tensor, + cpu_block_indices: torch.Tensor, +) -> None: + """ tpu blocks to cpu blocks""" + torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) + _tpu_cache = tpu_cache[tpu_block_indices] + cpu_cache[cpu_block_indices] = _tpu_cache.cpu() + + +def h2d_copy_blocks( + cpu_kv_caches: dict[str, torch.Tensor], + tpu_kv_caches: dict[str, torch.Tensor], + cpu_block_ids: list[int], + tpu_block_ids: list[int], + tpu_device: str, +) -> None: + """Copy kv blocks from host xfer buffer to device.""" + if not cpu_block_ids or not tpu_block_ids or len(cpu_block_ids) != len( + tpu_block_ids): + return + host_indices, device_indices = _make_src_and_dst_indices( + src_block_ids=cpu_block_ids, + dst_block_ids=tpu_block_ids, + src_device="cpu", + dst_device=tpu_device) + for layer_name in cpu_kv_caches: + host_tensor = cpu_kv_caches[layer_name] + device_tensor = tpu_kv_caches[layer_name] + sliced_device_tensor = host_tensor[host_indices].to(tpu_device) + _insert_blocks_to_tpu(sliced_device_tensor, device_tensor, + device_indices) + + +def d2h_copy_blocks( + cpu_kv_caches: dict[str, torch.Tensor], + tpu_kv_caches: dict[str, torch.Tensor], + cpu_block_ids: list[int], + tpu_block_ids: list[int], + tpu_device: str, +) -> None: + """Copy kv blocks from device to host xfer buffer.""" + if not cpu_block_ids or not tpu_block_ids or 
len(cpu_block_ids) != len( + tpu_block_ids): + return + device_indices, host_indices = _make_src_and_dst_indices( + src_block_ids=tpu_block_ids, + dst_block_ids=cpu_block_ids, + src_device=tpu_device, + dst_device="cpu") + for layer_name in cpu_kv_caches: + host_tensor = cpu_kv_caches[layer_name] + device_tensor = tpu_kv_caches[layer_name] + _swap_out_tpu_blocks(tpu_cache=device_tensor, + cpu_cache=host_tensor, + tpu_block_indices=device_indices, + cpu_block_indices=host_indices) + + def replace_set_lora(model): def _tpu_set_lora( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 5da481baeeea..cd07e3afb182 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -12,9 +12,10 @@ import torch_xla.runtime as xr import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -110,7 +111,7 @@ def init_device(self): # Initialize the distributed environment. self._init_tpu_worker_distributed_environment( - self.parallel_config, self.rank, self.distributed_init_method, + self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Device initialization should happen after initializing @@ -267,7 +268,7 @@ def check_health(self) -> None: def _init_tpu_worker_distributed_environment( self, - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, @@ -279,6 +280,7 @@ def _init_tpu_worker_distributed_environment( # the input objects on CPU. The all-reduce and all-gather ops on TPU # are invoked by `xm.all_reduce` and `xm.all_gather` which use their # own context. 
+ parallel_config = vllm_config.parallel_config init_distributed_environment( world_size=parallel_config.world_size, rank=rank, @@ -290,6 +292,8 @@ def _init_tpu_worker_distributed_environment( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + try: from tpu_commons.worker import TPUWorker as TPUCommonsWorker From d8041e657053741c6b67e4004e65a9e54bf1cc07 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 19 Jun 2025 05:41:48 +0000 Subject: [PATCH 02/19] fix device_index Signed-off-by: Juncheng Gu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9be7d4dd0d94..c4299ad60595 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -685,9 +685,11 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.dst_num_blocks[self.engine_id] = self.num_blocks self.device_kv_caches = kv_caches self.device = first_kv_cache.device - # if CPU device has no index - self.device_index = 0 if not hasattr(self.device, "index") else \ - self.device.index + # Note: non-CUDA devices may have a fixed device.index (0), + # use its tp_rank instead + self.device_index = (self.tp_rank if self.use_host_buffer or + self.device_type != "cuda" else self.device.index) + assert self.device assert self.device_index >= 0, \ f"cache device {self.device} index is invalid" @@ -711,7 +713,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - caches_data.append((base_addr, region_len, self.tp_rank, "")) + # TODO: does device_id matter to DRAM? + caches_data.append( + (base_addr, region_len, self.device_index, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) @@ -756,7 +760,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_offset = block_id * self.block_len addr = base_addr + block_offset # (addr, len, device id) - blocks_data.append((addr, self.block_len, self.tp_rank)) + # TODO: does device_id matter to DRAM? + blocks_data.append((addr, self.block_len, self.device_index)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) From 62b4460e0139fcc0cb162d17c94a50a939851aa9 Mon Sep 17 00:00:00 2001 From: Richard Liu Date: Mon, 23 Jun 2025 17:02:47 +0000 Subject: [PATCH 03/19] fix error Signed-off-by: Richard Liu --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 84e4ae9a80f6..fab53d6e9bf9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -853,12 +853,9 @@ def add_remote_agent(self, tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert self._use_pallas_v1 and tp_ratio == 1, \ "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." 
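The device-index choice introduced by the "fix device_index" commit above can be read as a small helper (an illustrative restatement, not code from the patch):

    def pick_nixl_device_id(device_type: str, use_host_buffer: bool,
                            tp_rank: int, cuda_device_index: int) -> int:
        # CUDA VRAM regions are registered with the real CUDA device index; host
        # (DRAM) buffers and non-CUDA devices may report a fixed index of 0, so
        # the TP rank is used as a stable per-worker id instead (the patch leaves
        # a TODO on whether the id matters for DRAM at all).
        if device_type == "cuda" and not use_host_buffer:
            return cuda_device_index
        return tp_rank
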
- if self.use_mla: - # Handle tp_size>num_kv_heads: replicate KV cache. total_num_kv_heads = self.model_config.get_total_num_kv_heads() is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 From b1ec96247165f6751323ef053eaec1529cec8651 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 24 Jun 2025 20:00:08 +0000 Subject: [PATCH 04/19] fix comments Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 116 ++++++++++-------- vllm/envs.py | 5 - vllm/v1/executor/multiproc_executor.py | 2 - vllm/v1/worker/tpu_model_runner.py | 91 ++++++-------- 4 files changed, 105 insertions(+), 109 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 51c7c25346dd..dfa0128b6e55 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib -import copy import math import threading import time @@ -39,6 +38,12 @@ Transfer = tuple[int, float] # (xfer_handle, start_time) EngineId = str ReqId = str + +# s_tensor_list, d_tensor_list, s_indices, d_indices, direction +CopyBlocksOp = Callable[[ + dict[str, torch.Tensor], dict[str, torch.Tensor], list[int], list[int], str +], None] + GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -110,8 +115,9 @@ def add_new_req( request_id: str, local_block_ids: list[int], kv_transfer_params: dict[str, Any], + customize_kv_transfer_params: Optional[dict[str, Any]] = None, ): - self.requests[request_id] = ReqMeta( + _req_meta = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], @@ -120,6 +126,12 @@ def add_new_req( do_remote_prefill=kv_transfer_params["do_remote_prefill"], do_remote_decode=kv_transfer_params["do_remote_decode"], ) + if customize_kv_transfer_params: + for param, value in customize_kv_transfer_params.items(): + if hasattr(_req_meta, param): + setattr(_req_meta, param, value) + + self.requests[request_id] = _req_meta class NixlConnector(KVConnectorBase_V1): @@ -178,11 +190,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, - h2d_copy_blocks: Callable): + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None - self.connector_worker.set_host_xfer_buffer_ops(d2h_copy_blocks, - h2d_copy_blocks) + self.connector_worker.set_host_xfer_buffer_ops(copy_operation) def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: @@ -224,13 +234,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_rank_local * vllm_config.parallel_config.tensor_parallel_size) + self.use_host_buffer = \ + vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. 
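In patch 04, add_new_req accepts an optional customize_kv_transfer_params dict and copies any matching keys onto the per-request metadata, instead of deep-copying and mutating kv_transfer_params. A small sketch of that override pattern on a stand-in dataclass (field names are illustrative only):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ReqMetaSketch:          # stand-in for the connector's ReqMeta
    local_block_ids: list[int]
    do_remote_prefill: bool = False
    do_remote_decode: bool = False

def make_req_meta(local_block_ids: list[int],
                  overrides: Optional[dict[str, Any]] = None) -> ReqMetaSketch:
    meta = ReqMetaSketch(local_block_ids=local_block_ids)
    for name, value in (overrides or {}).items():
        if hasattr(meta, name):      # silently skip unknown keys
            setattr(meta, name, value)
    return meta

print(make_req_meta([0, 1, 2], {"do_remote_prefill": True, "unknown": 1}))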
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_send: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -291,7 +303,7 @@ def update_state_after_alloc(self, request: "Request", # block is not overwritten; and it will be safe to skip saving them # to host xfer buffer. if full_block_ids: - self._reqs_need_send[request.request_id] = \ + self._reqs_need_save[request.request_id] = \ (request, full_block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): @@ -322,29 +334,33 @@ def build_connector_meta( meta = NixlConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. + _customize_params = { + "do_remote_prefill": True, + } if self.use_host_buffer else {} for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None - _kv_transfer_params = copy.deepcopy(req.kv_transfer_params) - _kv_transfer_params["do_remote_prefill"] = True meta.add_new_req( request_id=req_id, local_block_ids=block_ids, - kv_transfer_params=_kv_transfer_params, + kv_transfer_params=req.kv_transfer_params, + customize_kv_transfer_params=_customize_params, ) - for req_id, (req, block_ids) in self._reqs_need_send.items(): + _customize_params = { + "do_remote_decode": True, + } if self.use_host_buffer else {} + for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - _kv_transfer_params = copy.deepcopy(req.kv_transfer_params) - _kv_transfer_params["do_remote_decode"] = True meta.add_new_req( request_id=req_id, local_block_ids=block_ids, - kv_transfer_params=_kv_transfer_params, + kv_transfer_params=req.kv_transfer_params, + customize_kv_transfer_params=_customize_params, ) # Clear the list once workers start the transfers self._reqs_need_recv.clear() - self._reqs_need_send.clear() + self._reqs_need_save.clear() return meta @@ -427,7 +443,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # KV Caches and nixl tracking data. self.kv_buffer_device: str = \ - vllm_config.kv_transfer_config.kv_buffer_device.strip().lower() + vllm_config.kv_transfer_config.kv_buffer_device if not _NIXL_SUPPORTED_XPU.is_supported_kv_buffer( device_type=self.device_type, kv_buffer_type=self.kv_buffer_device): @@ -452,8 +468,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): "is not supported.") # Note: host xfer buffer ops when use_host_buffer is True - self.d2h_copy_blocks: Optional[Callable] = None - self.h2d_copy_blocks: Optional[Callable] = None + self.copy_blocks: Optional[CopyBlocksOp] = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. 
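On the scheduler side, requests are now tracked in two buckets: those that must pull KV from a remote engine (_reqs_need_recv) and those whose freshly prefilled blocks must be staged to the host buffer (_reqs_need_save). A simplified sketch of that bookkeeping; the request type and metadata layout are stand-ins, not the vLLM classes.

class SchedulerSideSketch:
    def __init__(self, use_host_buffer: bool):
        self.use_host_buffer = use_host_buffer
        self._reqs_need_recv: dict[str, list[int]] = {}
        self._reqs_need_save: dict[str, list[int]] = {}

    def on_alloc(self, req_id: str, block_ids: list[int],
                 do_remote_prefill: bool, do_remote_decode: bool) -> None:
        if do_remote_decode and self.use_host_buffer:
            self._reqs_need_save[req_id] = block_ids
        elif do_remote_prefill:
            self._reqs_need_recv[req_id] = block_ids

    def build_meta(self) -> dict[str, dict]:
        meta = {}
        for req_id, blocks in self._reqs_need_recv.items():
            meta[req_id] = {"blocks": blocks, "load_remote": True}
        for req_id, blocks in self._reqs_need_save.items():
            meta[req_id] = {"blocks": blocks, "save_to_host": True}
        # Clear once the worker picks up the transfers.
        self._reqs_need_recv.clear()
        self._reqs_need_save.clear()
        return meta

s = SchedulerSideSketch(use_host_buffer=True)
s.on_alloc("req-0", [0, 1], do_remote_prefill=False, do_remote_decode=True)
print(s.build_meta())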
@@ -590,13 +605,16 @@ def handshake(path: str, rank: int) -> NixlAgentMetadata: def initialize_host_xfer_buffer( self, kv_caches: dict[str, torch.Tensor]) -> None: - """Initialize transfer buffer in CPU mem for xPUs (e.g., tpu)""" + """ + Initialize transfer buffer in CPU mem for accelerators + NOT directly supported by NIXL (e.g., tpu) + """ xfer_buffers: dict[str, torch.Tensor] = {} try: for layer_name, kv_cache in kv_caches.items(): kv_shape = kv_cache.shape kv_dtype = kv_cache.dtype - xfer_buffers[layer_name] = torch.zeros(kv_shape, + xfer_buffers[layer_name] = torch.empty(kv_shape, dtype=kv_dtype, device="cpu") except MemoryError as e: @@ -605,12 +623,10 @@ def initialize_host_xfer_buffer( self.host_xfer_buffers = xfer_buffers - def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, - h2d_copy_blocks: Callable): + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """Assign copy (d2h, h2d) operations when host buffer is used.""" assert self.use_host_buffer - self.d2h_copy_blocks = d2h_copy_blocks - self.h2d_copy_blocks = h2d_copy_blocks + self.copy_blocks = copy_operation def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -678,14 +694,12 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # hybrid attn, etc # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) - - logger.debug( + logger.info( "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " - "use_host_buffer: %s, shape %s", use_mla, self.kv_buffer_device, - self.use_host_buffer, first_kv_cache.shape) - logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks, - block_shape) - logger.debug("Per layer kv cache size: %s", first_kv_cache.shape) + "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " + "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, + self.use_host_buffer, self.num_blocks, block_shape, + first_kv_cache.shape) self.dst_num_blocks[self.engine_id] = self.num_blocks self.device_kv_caches = kv_caches self.device = first_kv_cache.device @@ -856,7 +870,7 @@ def add_remote_agent(self, tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - assert self._use_pallas_v1 and tp_ratio == 1, \ + assert not self._use_pallas_v1 or tp_ratio == 1, \ "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." # Handle tp_size>num_kv_heads: replicate KV cache. 
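initialize_host_xfer_buffer allocates CPU tensors mirroring each device KV cache so NIXL can register host memory for accelerators it cannot address directly; torch.empty suffices because every block is written (by a d2h copy or a NIXL read) before it is consumed. A sketch under assumed shapes:

import torch

def make_host_buffers(device_kv_caches: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # One CPU tensor per layer, same shape and dtype as the device cache.
    return {
        name: torch.empty(cache.shape, dtype=cache.dtype, device="cpu")
        for name, cache in device_kv_caches.items()
    }

device_caches = {"layer0": torch.zeros(8, 4, 2, 16), "layer1": torch.zeros(8, 4, 2, 16)}
host_buffers = make_host_buffers(device_caches)
assert host_buffers["layer0"].shape == device_caches["layer0"].shape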
@@ -923,21 +937,23 @@ def add_remote_agent(self, engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_tp_rank], descs) - def sync_recved_kv_to_device(self, req_id: str): + def sync_recved_kv_to_device(self, + req_id: str, + meta: Optional[ReqMeta] = None): """copy recved kv from host buffer to device.""" if not self.use_host_buffer: return - assert self.h2d_copy_blocks is not None + assert self.copy_blocks is not None - if req_id in self._recving_metadata and \ - req_id not in self._recving_transfers: - meta = self._recving_metadata[req_id] + if meta is None: + meta = self._recving_metadata.get(req_id) + if meta and req_id not in self._recving_transfers: # local decode only if not meta.do_remote_prefill: return local_block_ids = meta.local_block_ids - self.h2d_copy_blocks(self.host_xfer_buffers, self.device_kv_caches, - local_block_ids, local_block_ids, self.device) + self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, + local_block_ids, local_block_ids, "h2d") logger.debug( "synced recved kv of request[%s] to device kv buffer," "local_block_ids: %s. ", req_id, @@ -948,7 +964,7 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): """copy kv from device to host buffer.""" if not self.use_host_buffer: return - assert self.d2h_copy_blocks is not None + assert self.copy_blocks is not None for req_id, meta in metadata.requests.items(): # local prefill requests only @@ -959,9 +975,8 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): "save_load_kv for request[%s] to host xfer buffer." "local_block_ids: %s. ", req_id, ",".join(map(str, meta.local_block_ids))) - self.d2h_copy_blocks(self.host_xfer_buffers, self.device_kv_caches, - meta.local_block_ids, meta.local_block_ids, - self.device) + self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, + meta.local_block_ids, meta.local_block_ids, "d2h") return def get_finished(self) -> tuple[set[str], set[str]]: @@ -984,12 +999,11 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) - for req_id in done_recving: - assert req_id in self._recving_metadata, ( - f"{req_id} not found in recving_metadata list") - if self.use_host_buffer and self.h2d_copy_blocks is not None: - self.sync_recved_kv_to_device(req_id) - self._recving_metadata.pop(req_id) + if self.use_host_buffer and self.copy_blocks is not None: + for req_id in done_recving: + meta = self._recving_metadata.pop(req_id) + assert meta, (f"{req_id} not found in recving_metadata list") + self.sync_recved_kv_to_device(req_id, meta) if self.world_size == 1: return done_sending, done_recving @@ -1102,7 +1116,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): remote_host=meta.remote_host, remote_port=meta.remote_port, ) - self._recving_metadata[req_id] = copy.deepcopy(meta) + self._recving_metadata[req_id] = meta def _read_blocks( self, diff --git a/vllm/envs.py b/vllm/envs.py index 3d31d82ea2f1..01d8d8a2d2e0 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -132,7 +132,6 @@ VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16 VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300 VLLM_KV_CACHE_LAYOUT: Optional[str] = None - VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S: int = 300 VLLM_COMPUTE_NANS_IN_LOGITS: bool = False @@ -914,10 +913,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_KV_CACHE_LAYOUT": lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None), - # Timeout for calling execute_model() in multiproc_executor - "VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S": - lambda: 
int(os.getenv("VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S", "300")), - # Enable checking whether the generated logits contain NaNs, # indicating corrupted output. Useful for debugging low level bugs # or bad hardware but it may add compute overhead. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0070eeea0c88..9da900521961 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -40,8 +40,6 @@ POLLING_TIMEOUT_MS = 5000 POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 -EXECUTE_MODEL_TIMEOUT_S = envs.VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S - class MultiprocExecutor(Executor): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0cf500953e7e..85f09b81191b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1450,8 +1450,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) - get_kv_transfer_group().set_host_xfer_buffer_ops( - d2h_copy_blocks, h2d_copy_blocks) + get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) def reset_dynamo_cache(self): if self.is_multimodal_model: @@ -1684,12 +1683,14 @@ def _make_src_and_dst_indices( @torch.compile(backend="openxla") def _insert_blocks_to_tpu( - src_cache: torch.Tensor, + cpu_cache: torch.Tensor, tpu_cache: torch.Tensor, + cpu_block_indices: torch.Tensor, tpu_block_indices: torch.Tensor, ) -> None: torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - tpu_cache[tpu_block_indices] = src_cache + tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( + tpu_cache.device) @torch.compile(backend="openxla") @@ -1701,57 +1702,45 @@ def _swap_out_tpu_blocks( ) -> None: """ tpu blocks to cpu blocks""" torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) - _tpu_cache = tpu_cache[tpu_block_indices] - cpu_cache[cpu_block_indices] = _tpu_cache.cpu() + cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() -def h2d_copy_blocks( - cpu_kv_caches: dict[str, torch.Tensor], - tpu_kv_caches: dict[str, torch.Tensor], - cpu_block_ids: list[int], - tpu_block_ids: list[int], - tpu_device: str, -) -> None: - """Copy kv blocks from host xfer buffer to device.""" - if not cpu_block_ids or not tpu_block_ids or len(cpu_block_ids) != len( - tpu_block_ids): - return - host_indices, device_indices = _make_src_and_dst_indices( - src_block_ids=cpu_block_ids, - dst_block_ids=tpu_block_ids, - src_device="cpu", - dst_device=tpu_device) - for layer_name in cpu_kv_caches: - host_tensor = cpu_kv_caches[layer_name] - device_tensor = tpu_kv_caches[layer_name] - sliced_device_tensor = host_tensor[host_indices].to(tpu_device) - _insert_blocks_to_tpu(sliced_device_tensor, device_tensor, - device_indices) - - -def d2h_copy_blocks( - cpu_kv_caches: dict[str, torch.Tensor], - tpu_kv_caches: dict[str, torch.Tensor], - cpu_block_ids: list[int], - tpu_block_ids: list[int], - tpu_device: str, +def copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: str, ) -> None: - """Copy kv blocks from device to host xfer buffer.""" - if not cpu_block_ids or not tpu_block_ids or len(cpu_block_ids) != len( - tpu_block_ids): + """Copy kv blocks between different buffers.""" + direction = direction.strip().lower() + assert direction in ("h2d", "d2h",), \ + (f"Invalid direction: {direction}") + + if not 
src_kv_caches or not dst_kv_caches or \ + not src_block_ids or not dst_block_ids or \ + len(src_block_ids) != len(dst_block_ids): return - device_indices, host_indices = _make_src_and_dst_indices( - src_block_ids=tpu_block_ids, - dst_block_ids=cpu_block_ids, - src_device=tpu_device, - dst_device="cpu") - for layer_name in cpu_kv_caches: - host_tensor = cpu_kv_caches[layer_name] - device_tensor = tpu_kv_caches[layer_name] - _swap_out_tpu_blocks(tpu_cache=device_tensor, - cpu_cache=host_tensor, - tpu_block_indices=device_indices, - cpu_block_indices=host_indices) + + _, src_kv_cache = next(iter(src_kv_caches.items())) + src_device = src_kv_cache.device + _, dst_kv_cache = next(iter(dst_kv_caches.items())) + dst_device = dst_kv_cache.device + + src_indices, dst_indices = _make_src_and_dst_indices( + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device) + + _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \ + _swap_out_tpu_blocks + for layer_name in src_kv_caches: + src_tensor = src_kv_caches[layer_name] + dst_tensor = dst_kv_caches[layer_name] + _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) + + return def replace_set_lora(model): From 0995bbda452a32d140fb9321fab091ce71b37081 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 24 Jun 2025 23:59:02 +0000 Subject: [PATCH 05/19] fix recving_meta at decode Signed-off-by: Juncheng Gu --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 8 ++++++-- .../kv_transfer/kv_connector/v1/nixl_connector.py | 11 +++-------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index dc005f33536a..874c292b33cb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -46,6 +46,11 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +# s_tensor_list, d_tensor_list, s_indices, d_indices, direction +CopyBlocksOp = Callable[[ + dict[str, torch.Tensor], dict[str, torch.Tensor], list[int], list[int], str +], None] + logger = init_logger(__name__) @@ -124,8 +129,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return - def set_host_xfer_buffer_ops(self, d2h_copy_blocks: Callable, - h2d_copy_blocks: Callable): + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """ Set the xPU-specific ops for copying KV between host and device. 
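With the two copy helpers unified, the model runner hands the connector a single callable matching the CopyBlocksOp signature. Below is a sketch of the alias together with a conforming CPU-only implementation; the real TPU copy is compiled with the openxla backend and donates the device buffer, which is omitted here.

from typing import Callable
import torch

# (src_caches, dst_caches, src_block_ids, dst_block_ids, direction) -> None
CopyBlocksOp = Callable[
    [dict[str, torch.Tensor], dict[str, torch.Tensor], list[int], list[int], str],
    None,
]

def cpu_copy_blocks(src: dict[str, torch.Tensor], dst: dict[str, torch.Tensor],
                    src_ids: list[int], dst_ids: list[int], direction: str) -> None:
    assert direction in ("h2d", "d2h")
    if not src_ids or len(src_ids) != len(dst_ids):
        return
    s = torch.tensor(src_ids, dtype=torch.int64)
    d = torch.tensor(dst_ids, dtype=torch.int64)
    for name, src_t in src.items():
        dst[name][d] = src_t[s].to(dst[name].device)

copy_op: CopyBlocksOp = cpu_copy_blocks   # what a model runner would hand over
host = {"l0": torch.randn(4, 2)}
dev = {"l0": torch.zeros(4, 2)}
copy_op(host, dev, [0, 2], [0, 2], "h2d")
assert torch.equal(dev["l0"][2], host["l0"][2])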
Needed when host buffer is used for kv transfer (e.g., in NixlConnector) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9b3f81c2e15f..f423ffd76f96 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -10,7 +10,7 @@ from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Optional import msgspec import torch @@ -20,7 +20,7 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -41,11 +41,6 @@ EngineId = str ReqId = str -# s_tensor_list, d_tensor_list, s_indices, d_indices, direction -CopyBlocksOp = Callable[[ - dict[str, torch.Tensor], dict[str, torch.Tensor], list[int], list[int], str -], None] - GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -1132,6 +1127,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) + self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Being optimistic to assume engine is usually ready, apply # lock only when the optimistic check fails. @@ -1165,7 +1161,6 @@ def request_ready(_f: Future[Any], fut.add_done_callback(request_ready) continue self._read_blocks_for_req(req_id, meta) - self._recving_metadata[req_id] = meta # Start transfers for requests whose handshakes have now finished. while not self._ready_requests.empty(): From 71cf9530befdbbdd2ba91a1bb3a6d9f12850c3df Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 25 Jun 2025 04:28:43 +0000 Subject: [PATCH 06/19] tweaks Signed-off-by: Juncheng Gu --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f423ffd76f96..8ea25d48afd6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -290,7 +290,10 @@ def update_state_after_alloc(self, request: "Request", if not params: return if params.get("do_remote_decode"): - # NOTE: figure out full computed blocks to send / save + # NOTE: when kv_buffer_device (e.) is not supported by Nixl, + # prefilled blocks need to be saved to host memory before transfer. 
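Only fully computed blocks are worth staging to the host buffer: as the code immediately after this note shows, the trailing block is dropped whenever the prompt does not end exactly on a block boundary. A sketch with assumed block_size and token counts:

def full_block_ids(block_ids: list[int], num_tokens: int, block_size: int) -> list[int]:
    # Keep all blocks when the token count is block-aligned, otherwise drop
    # the last (partially filled) block.
    all_full = num_tokens % block_size == 0
    return block_ids if all_full else block_ids[:-1]

print(full_block_ids([7, 8, 9], num_tokens=96, block_size=32))  # [7, 8, 9]
print(full_block_ids([7, 8, 9], num_tokens=90, block_size=32))  # [7, 8]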
+ + # figure out full computed blocks to save block_ids = blocks.get_block_ids()[0] all_full = request.num_tokens % self.block_size == 0 full_block_ids = (block_ids if all_full else block_ids[:-1]) @@ -679,6 +682,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_size, n_kv_heads_x_2, head_dim = block_shape self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim elif self.device_type == "cuda": + assert use_mla == self.use_mla # TODO (NickLucche) not compatible with hybrid allocator. # Enforce check once it goes live, as a single kv layout # is expected for xfers. From 0ab79dc1bb6374a0336be583271e1533c488c3e9 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 25 Jun 2025 07:18:37 +0000 Subject: [PATCH 07/19] tweak Signed-off-by: Juncheng Gu --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8ea25d48afd6..3c4b2b60d019 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -708,7 +708,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim assert block_size == self.block_size else: - raise RuntimeError(f"{self.device_type} is not supported yet.") + raise RuntimeError( + f"{self.device_type} ({self.backend_name}) is not supported.") # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc From 4961a326b3304d53d152ebd6555e1de33fcbbe8c Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 26 Jun 2025 04:40:26 +0000 Subject: [PATCH 08/19] fix comments Signed-off-by: Juncheng Gu --- .../kv_transfer/kv_connector/v1/base.py | 5 +++-- .../kv_connector/v1/nixl_connector.py | 21 +++++++++++-------- vllm/v1/executor/multiproc_executor.py | 3 --- vllm/v1/worker/tpu_model_runner.py | 16 ++++---------- 4 files changed, 19 insertions(+), 26 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 874c292b33cb..9c17d75cfd06 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -32,7 +32,7 @@ import enum from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -48,7 +48,8 @@ # s_tensor_list, d_tensor_list, s_indices, d_indices, direction CopyBlocksOp = Callable[[ - dict[str, torch.Tensor], dict[str, torch.Tensor], list[int], list[int], str + dict[str, torch.Tensor], dict[str, + torch.Tensor], list[int], list[int], Literal ], None] logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3c4b2b60d019..aa5aaa5522f6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import logging import math import queue import threading @@ -977,10 +978,11 @@ def sync_recved_kv_to_device(self, local_block_ids = meta.local_block_ids 
self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, local_block_ids, local_block_ids, "h2d") - logger.debug( - "synced recved kv of request[%s] to device kv buffer," - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "synced recved kv of request[%s] to device kv buffer," + "local_block_ids: %s. ", req_id, + ",".join(map(str, meta.local_block_ids))) return def save_kv_to_host(self, metadata: NixlConnectorMetadata): @@ -993,11 +995,12 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): # local prefill requests only if not meta.do_remote_decode: continue + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "save_load_kv for request[%s] to host xfer buffer." + "local_block_ids: %s. ", req_id, + ",".join(map(str, meta.local_block_ids))) # blocking - logger.debug( - "save_load_kv for request[%s] to host xfer buffer." - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, meta.local_block_ids, meta.local_block_ids, "d2h") return @@ -1022,7 +1025,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) - if self.use_host_buffer and self.copy_blocks is not None: + if self.use_host_buffer: for req_id in done_recving: meta = self._recving_metadata.pop(req_id) assert meta, (f"{req_id} not found in recving_metadata list") diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 9da900521961..b06b7cc804d5 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -37,9 +37,6 @@ logger = init_logger(__name__) -POLLING_TIMEOUT_MS = 5000 -POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 - class MultiprocExecutor(Executor): diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 85f09b81191b..309cdefb1859 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Literal, Optional, Union, cast from unittest.mock import patch import numpy as np @@ -1710,22 +1710,16 @@ def copy_kv_blocks( dst_kv_caches: dict[str, torch.Tensor], src_block_ids: list[int], dst_block_ids: list[int], - direction: str, + direction: Literal["h2d", "d2h"], ) -> None: """Copy kv blocks between different buffers.""" - direction = direction.strip().lower() - assert direction in ("h2d", "d2h",), \ - (f"Invalid direction: {direction}") - if not src_kv_caches or not dst_kv_caches or \ not src_block_ids or not dst_block_ids or \ len(src_block_ids) != len(dst_block_ids): return - _, src_kv_cache = next(iter(src_kv_caches.items())) - src_device = src_kv_cache.device - _, dst_kv_cache = next(iter(dst_kv_caches.items())) - dst_device = dst_kv_cache.device + src_device = next(iter(src_kv_caches.values())).device + dst_device = next(iter(dst_kv_caches.values())).device src_indices, dst_indices = _make_src_and_dst_indices( src_block_ids=src_block_ids, @@ -1740,8 +1734,6 @@ def copy_kv_blocks( dst_tensor = dst_kv_caches[layer_name] _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) - return - def replace_set_lora(model): From e6c61deb78a64508e543653018739bc76f35b240 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 26 Jun 2025 17:10:24 +0000 Subject: [PATCH 09/19] fix 
for multi-forwards in a single input batch Signed-off-by: Juncheng Gu --- vllm/v1/worker/tpu_model_runner.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 44199c828b16..0ef6358964ac 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -913,9 +913,12 @@ def execute_model( start_index = 0 combined_selected_tokens: list[torch.Tensor] = [] combined_logprobs: list[LogprobsLists] = [] - # for kv transfer - finished_sending: set[str] = set() - finished_recving: set[str] = set() + + # NOTE: setup current batch's metadata for kv connector. + # Currently, only verified with NixlConnector + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + while start_index < self.input_batch.num_reqs: attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ end_index = self._prepare_inputs(scheduler_output, start_index) @@ -927,19 +930,11 @@ def execute_model( attn_metadata, self.vllm_config, num_tokens=scheduler_output.total_num_scheduled_tokens): - self.maybe_setup_kv_connector(scheduler_output) hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, inputs_embeds=inputs_embeds, ) - self.maybe_wait_for_kv_save() - _finished_sending, _finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - if _finished_sending: - finished_sending.update(_finished_sending) - if _finished_recving: - finished_recving.update(_finished_recving) hidden_states = self.select_hidden_states(hidden_states, logits_indices) logits = self.compute_logits(hidden_states) @@ -970,6 +965,14 @@ def execute_model( start_index = end_index + # NOTE: current kv load and save get h2d/d2h copies involved. + # Those copies are blocking. Once they become async., kv_save + # should be called right after each single forward pass, + # instead of the forwards of the entire input batch. + self.maybe_wait_for_kv_save() + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: From 73d4ff3a6dbfb442f6f6cf5d9bda8ed7c343fd5a Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 30 Jun 2025 18:24:27 +0000 Subject: [PATCH 10/19] fix comments Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 3da44a4131b3..f362e5435eac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -101,6 +101,9 @@ class ReqMeta: tp_size: int do_remote_prefill: bool = False do_remote_decode: bool = False + # NOTE: needed when use_host_buffer is true. 
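The multi-forward fix above binds the connector metadata once per scheduler step and defers the blocking KV save and finished-transfer query until every sub-forward of the padded input batch has run. A schematic sketch of that control flow; the hook names are stand-ins rather than the runner's real methods.

def execute_batch(chunks, setup_kv_connector, run_forward,
                  wait_for_kv_save, get_finished_transfers):
    setup_kv_connector()                 # once per scheduler step, not per chunk
    outputs = [run_forward(chunk) for chunk in chunks]
    wait_for_kv_save()                   # blocking d2h copies happen here
    finished_sending, finished_recving = get_finished_transfers()
    return outputs, finished_sending, finished_recving

outs, sent, recvd = execute_batch(
    chunks=[list(range(3)), list(range(2))],
    setup_kv_connector=lambda: None,
    run_forward=lambda chunk: sum(chunk),
    wait_for_kv_save=lambda: None,
    get_finished_transfers=lambda: (set(), set()),
)
print(outs, sent, recvd)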
+ do_save_to_host: bool = False + do_load_to_device: bool = False class NixlConnectorMetadata(KVConnectorMetadata): @@ -113,9 +116,10 @@ def add_new_req( request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - customize_kv_transfer_params: Optional[dict[str, Any]] = None, + do_save_to_host: bool = False, + do_load_to_device: bool = False, ): - _req_meta = ReqMeta( + self.requests[request_id] = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], @@ -125,13 +129,9 @@ def add_new_req( tp_size=kv_transfer_params.get("tp_size", 1), do_remote_prefill=kv_transfer_params["do_remote_prefill"], do_remote_decode=kv_transfer_params["do_remote_decode"], + do_save_to_host=do_save_to_host, + do_load_to_device=do_load_to_device, ) - if customize_kv_transfer_params: - for param, value in customize_kv_transfer_params.items(): - if hasattr(_req_meta, param): - setattr(_req_meta, param, value) - - self.requests[request_id] = _req_meta class NixlConnector(KVConnectorBase_V1): @@ -292,7 +292,7 @@ def update_state_after_alloc(self, request: "Request", if not params: return - if params.get("do_remote_decode"): + if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when kv_buffer_device (e.) is not supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. @@ -337,28 +337,22 @@ def build_connector_meta( meta = NixlConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. - _customize_params = { - "do_remote_prefill": True, - } if self.use_host_buffer else {} for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - customize_kv_transfer_params=_customize_params, + do_load_to_device=self.use_host_buffer, ) - _customize_params = { - "do_remote_decode": True, - } if self.use_host_buffer else {} for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - customize_kv_transfer_params=_customize_params, + do_save_to_host=self.use_host_buffer, ) # Clear the list once workers start the transfers @@ -979,7 +973,7 @@ def sync_recved_kv_to_device(self, meta = self._recving_metadata.get(req_id) if meta and req_id not in self._recving_transfers: # local decode only - if not meta.do_remote_prefill: + if not meta.do_load_to_device: return local_block_ids = meta.local_block_ids self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, @@ -999,7 +993,7 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): for req_id, meta in metadata.requests.items(): # local prefill requests only - if not meta.do_remote_decode: + if not meta.do_save_to_host: continue if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -1133,7 +1127,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.requests.items(): - if not meta.do_remote_prefill: + # NOTE: when host xfer buffer is used, only load kv + # for requests with do_load_to_device = True. 
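Patch 10 replaces the generic override dict with two explicit booleans on the request metadata: one marks prefill-side requests whose blocks must be staged to the host buffer, the other marks decode-side requests whose received blocks must be copied back onto the device. A stand-in sketch of those flags:

from dataclasses import dataclass

@dataclass
class HostBufferReqMeta:                 # simplified stand-in for ReqMeta
    local_block_ids: list[int]
    do_save_to_host: bool = False        # prefill side: device -> host before send
    do_load_to_device: bool = False      # decode side: host -> device after recv

prefill_req = HostBufferReqMeta([4, 5, 6], do_save_to_host=True)
decode_req = HostBufferReqMeta([1, 2], do_load_to_device=True)
print(prefill_req, decode_req)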
+ if self.use_host_buffer and not meta.do_load_to_device: continue remote_engine_id = meta.remote_engine_id logger.debug( From 787e9adbbe734f25e06acf423b4ca7c9610c3abf Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 2 Jul 2025 19:30:19 +0000 Subject: [PATCH 11/19] revise h2d/d2h attributes Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 49 +++++++------------ 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index dea7ab6f133e..50be88feb3e6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -99,9 +99,7 @@ class ReqMeta: remote_port: int remote_engine_id: str tp_size: int - do_remote_prefill: bool = False - do_remote_decode: bool = False - # NOTE: needed when use_host_buffer is true. + # NOTE: needed when host_xfer_buffer is used. do_save_to_host: bool = False do_load_to_device: bool = False @@ -127,8 +125,6 @@ def add_new_req( remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), - do_remote_prefill=kv_transfer_params["do_remote_prefill"], - do_remote_decode=kv_transfer_params["do_remote_decode"], do_save_to_host=do_save_to_host, do_load_to_device=do_load_to_device, ) @@ -346,13 +342,14 @@ def build_connector_meta( do_load_to_device=self.use_host_buffer, ) + # NOTE: only needed when use_host_buffer is true. for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - do_save_to_host=self.use_host_buffer, + do_save_to_host=True, ) # Clear the list once workers start the transfers @@ -988,34 +985,26 @@ def add_remote_agent(self, return remote_agent_name - def sync_recved_kv_to_device(self, - req_id: str, - meta: Optional[ReqMeta] = None): + def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" - if not self.use_host_buffer: - return + assert self.use_host_buffer assert self.copy_blocks is not None + assert req_id not in self._recving_transfers + if not meta.do_load_to_device: + return - if meta is None: - meta = self._recving_metadata.get(req_id) - if meta and req_id not in self._recving_transfers: - # local decode only - if not meta.do_load_to_device: - return - local_block_ids = meta.local_block_ids - self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, - local_block_ids, local_block_ids, "h2d") - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "synced recved kv of request[%s] to device kv buffer," - "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids))) - return + local_block_ids = meta.local_block_ids + self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, + local_block_ids, local_block_ids, "h2d") + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "synced recved kv of request[%s] to device kv buffer," + "local_block_ids: %s. 
", req_id, + ",".join(map(str, meta.local_block_ids))) def save_kv_to_host(self, metadata: NixlConnectorMetadata): """copy kv from device to host buffer.""" - if not self.use_host_buffer: - return + assert self.use_host_buffer assert self.copy_blocks is not None for req_id, meta in metadata.requests.items(): @@ -1030,7 +1019,6 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): # blocking self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, meta.local_block_ids, meta.local_block_ids, "d2h") - return def get_finished(self) -> tuple[set[str], set[str]]: """ @@ -1164,7 +1152,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - self._recving_metadata[req_id] = meta + if self.use_host_buffer: + self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: From 6fa07b709ea4b513eb9832032bb8fbb4d11a98e6 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 3 Jul 2025 05:02:55 +0000 Subject: [PATCH 12/19] rename ReqMeta attributes for use_host_buffer Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 50be88feb3e6..98f97756726f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -99,9 +99,11 @@ class ReqMeta: remote_port: int remote_engine_id: str tp_size: int - # NOTE: needed when host_xfer_buffer is used. - do_save_to_host: bool = False - do_load_to_device: bool = False + # load kv cache from remote engine / agent + load_remote_cache: bool = True + # NOTE: save kv cache from accelerator to local + # host buffer; needed when kv_cache_device is cpu. + save_to_host: bool = False class NixlConnectorMetadata(KVConnectorMetadata): @@ -114,8 +116,8 @@ def add_new_req( request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - do_save_to_host: bool = False, - do_load_to_device: bool = False, + load_remote_cache: bool = True, + save_to_host: bool = False, ): self.requests[request_id] = ReqMeta( local_block_ids=local_block_ids, @@ -125,8 +127,8 @@ def add_new_req( remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), - do_save_to_host=do_save_to_host, - do_load_to_device=do_load_to_device, + load_remote_cache=load_remote_cache, + save_to_host=save_to_host, ) @@ -289,7 +291,7 @@ def update_state_after_alloc(self, request: "Request", if not params: return if self.use_host_buffer and params.get("do_remote_decode"): - # NOTE: when kv_buffer_device (e.) is not supported by Nixl, + # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. # figure out full computed blocks to save @@ -339,17 +341,16 @@ def build_connector_meta( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - do_load_to_device=self.use_host_buffer, ) - # NOTE: only needed when use_host_buffer is true. 
for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None meta.add_new_req( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, - do_save_to_host=True, + load_remote_cache=False, + save_to_host=True, ) # Clear the list once workers start the transfers @@ -989,8 +990,7 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" assert self.use_host_buffer assert self.copy_blocks is not None - assert req_id not in self._recving_transfers - if not meta.do_load_to_device: + if not meta.load_remote_cache: return local_block_ids = meta.local_block_ids @@ -1008,8 +1008,7 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): assert self.copy_blocks is not None for req_id, meta in metadata.requests.items(): - # local prefill requests only - if not meta.do_save_to_host: + if not meta.save_to_host: continue if logger.isEnabledFor(logging.DEBUG): logger.debug( @@ -1142,9 +1141,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.requests.items(): - # NOTE: when host xfer buffer is used, only load kv - # for requests with do_load_to_device = True. - if self.use_host_buffer and not meta.do_load_to_device: + if not meta.load_remote_cache: continue remote_engine_id = meta.remote_engine_id logger.debug( From 9877752f232d63052c40fdc77f5dab90126ecd3f Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 7 Jul 2025 17:41:16 +0000 Subject: [PATCH 13/19] revise _NIXL_SUPPORTED_XPUS Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 46 ++++++------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 98f97756726f..def60734cd39 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -54,28 +54,12 @@ logger.warning("NIXL is not available") NixlWrapper = None - -class _NIXL_SUPPORTED_XPU: - """ - xPUs and the corresponding types of kv transfer buffer - supported by NIXLConnector - """ - # {xPU: tuple of supported kv buffer types} - # TODO: "cpu" xfer buffer for cuda - _support_dict = { - "cuda": ("cuda", ), - "tpu": ("cpu", ), - } - - @classmethod - def is_supported_xpu(cls, device_type: str) -> bool: - return device_type in cls._support_dict - - @classmethod - def is_supported_kv_buffer(cls, device_type: str, - kv_buffer_type: str) -> bool: - return (device_type in cls._support_dict - and kv_buffer_type in cls._support_dict[device_type]) +# Supported xPUs and types of kv transfer buffer. +# {xPU: tuple of supported kv buffer types} +_NIXL_SUPPORTED_XPUS = { + "cuda": ("cuda", ), + "tpu": ("cpu", ), +} class NixlAgentMetadata( @@ -101,8 +85,8 @@ class ReqMeta: tp_size: int # load kv cache from remote engine / agent load_remote_cache: bool = True - # NOTE: save kv cache from accelerator to local - # host buffer; needed when kv_cache_device is cpu. + # NOTE: save kv cache from device to local host buffer. + # needed when kv_cache_device is cpu. 
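The support matrix is now a plain dict mapping each accelerator to the KV buffer types NIXL can use with it, checked once at connector-worker construction (shown in the following hunk). A sketch of that validation; the table contents mirror the patch, the function name is illustrative.

SUPPORTED = {
    "cuda": ("cuda",),
    "tpu": ("cpu",),
}

def check_supported(device_type: str, kv_buffer_device: str) -> None:
    if device_type not in SUPPORTED:
        raise RuntimeError(f"{device_type} is not supported.")
    if kv_buffer_device not in SUPPORTED[device_type]:
        raise RuntimeError(
            f"{device_type} with {kv_buffer_device} kv_buffer is not supported.")

check_supported("tpu", "cpu")          # ok
try:
    check_supported("tpu", "cuda")
except RuntimeError as e:
    print(e)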
save_to_host: bool = False @@ -418,12 +402,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): logger.info("Initializing NIXL wrapper") logger.info("Initializing NIXL worker %s", engine_id) - self.device_type = current_platform.device_type - if not _NIXL_SUPPORTED_XPU.is_supported_xpu( - device_type=self.device_type): - logger.error("%s is not supported.", self.device_type) - raise RuntimeError(f"{self.device_type} is not supported.") - # Config. self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -450,11 +428,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.num_blocks = 0 # KV Caches and nixl tracking data. + self.device_type = current_platform.device_type self.kv_buffer_device: str = \ vllm_config.kv_transfer_config.kv_buffer_device - if not _NIXL_SUPPORTED_XPU.is_supported_kv_buffer( - device_type=self.device_type, - kv_buffer_type=self.kv_buffer_device): + if self.device_type not in _NIXL_SUPPORTED_XPUS: + raise RuntimeError(f"{self.device_type} is not supported.") + elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[ + self.device_type]: raise RuntimeError( f"{self.device_type} with {self.kv_buffer_device} kv_buffer " "is not supported.") From e21615176d338a8385c402a85133e2c279467650 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 7 Jul 2025 18:10:48 +0000 Subject: [PATCH 14/19] SPDX license Signed-off-by: Juncheng Gu --- tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py | 1 + vllm/v1/worker/kv_connector_model_runner_mixin.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py index fe30d0fbaaec..00e62f351ce3 100644 --- a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json import os diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 9da1356953ba..fc2a9e2ce44a 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Define KV connector functionality mixin for model runners. 
""" From 0637bdfcd0c3315259e5a96e3291c886764f0efc Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 7 Jul 2025 21:11:37 +0000 Subject: [PATCH 15/19] use tp_rank for device_id in nixl data block Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index def60734cd39..f8f95c44ee8d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -439,8 +439,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): f"{self.device_type} with {self.kv_buffer_device} kv_buffer " "is not supported.") self.device_kv_caches: dict[str, torch.Tensor] = {} - self.device: torch.device = None - self.device_index: int = -1 # cpu kv buffer for xfer # used when xPU memory can not be registered under nixl @@ -731,16 +729,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): first_kv_cache.shape) self.dst_num_blocks[self.engine_id] = self.num_blocks self.device_kv_caches = kv_caches - self.device = first_kv_cache.device - # Note: non-CUDA devices may have a fixed device.index (0), - # use its tp_rank instead - self.device_index = (self.tp_rank if self.use_host_buffer or - self.device_type != "cuda" else self.device.index) - - assert self.device - assert self.device_index >= 0, \ - f"cache device {self.device} index is invalid" - kv_caches_base_addr = [] caches_data = [] @@ -760,9 +748,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - # TODO: does device_id matter to DRAM? - caches_data.append( - (base_addr, region_len, self.device_index, "")) + caches_data.append((base_addr, region_len, self.tp_rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) @@ -808,7 +794,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): addr = base_addr + block_offset # (addr, len, device id) # TODO: does device_id matter to DRAM? - blocks_data.append((addr, self.block_len, self.device_index)) + blocks_data.append((addr, self.block_len, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) From 0ac9a6f98a95428a573ba6712b4e9daf4c2afee7 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 7 Jul 2025 21:16:19 +0000 Subject: [PATCH 16/19] use tp_rank for device_id in nixl data block Signed-off-by: Juncheng Gu --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f8f95c44ee8d..7a930e605247 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -748,6 +748,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len + # NOTE: use tp_rank for device_id since multi-node TP + # is rarely used. 
caches_data.append((base_addr, region_len, self.tp_rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr From 2ba3daa7547f198d2aea50a3ce2022e9a8808c33 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 9 Jul 2025 16:56:21 +0000 Subject: [PATCH 17/19] rm redundant code Signed-off-by: Juncheng Gu --- .../kv_connector/v1/nixl_connector.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index f07abf637302..1e9deb209f42 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -83,11 +83,6 @@ class ReqMeta: remote_port: int remote_engine_id: str tp_size: int - # load kv cache from remote engine / agent - load_remote_cache: bool = True - # NOTE: save kv cache from device to local host buffer. - # needed when kv_cache_device is cpu. - save_to_host: bool = False class NixlConnectorMetadata(KVConnectorMetadata): @@ -115,12 +110,10 @@ def add_new_req( remote_port=kv_transfer_params["remote_port"], # P workers don't need to receive tp_size from proxy here. tp_size=kv_transfer_params.get("tp_size", 1), - load_remote_cache=load_remote_cache, - save_to_host=save_to_host, ) if save_to_host: self.reqs_to_save[request_id] = _req - else: + if load_remote_cache: self.reqs_to_recv[request_id] = _req @@ -209,7 +202,6 @@ def wait_for_save(self): assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) self.connector_worker.save_kv_to_host(self._connector_metadata) - return class NixlConnectorScheduler: @@ -978,8 +970,6 @@ def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" assert self.use_host_buffer assert self.copy_blocks is not None - if not meta.load_remote_cache: - return local_block_ids = meta.local_block_ids self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, @@ -996,8 +986,6 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): - if not meta.save_to_host: - continue if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." @@ -1030,7 +1018,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: if self.use_host_buffer: for req_id in done_recving: meta = self._recving_metadata.pop(req_id) - assert meta, (f"{req_id} not found in recving_metadata list") + assert meta, f"{req_id} not found in recving_metadata list" self.sync_recved_kv_to_device(req_id, meta) # Handle timeout to avoid stranding blocks on remote. @@ -1140,8 +1128,6 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.reqs_to_recv.items(): - if not meta.load_remote_cache: - continue remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. 
" From 4e72df9696aca423d2adc35fb27cd70438565bfb Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 9 Jul 2025 19:30:17 +0000 Subject: [PATCH 18/19] fix assertation Signed-off-by: Juncheng Gu --- .../distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1e9deb209f42..cf4e0fb590d9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -201,7 +201,9 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, def wait_for_save(self): assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) - self.connector_worker.save_kv_to_host(self._connector_metadata) + if self.connector_worker.use_host_buffer and \ + self.connector_worker.copy_blocks: + self.connector_worker.save_kv_to_host(self._connector_metadata) class NixlConnectorScheduler: From aca42465bf0414180701696b3593db19b0404339 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 11 Jul 2025 04:47:43 +0000 Subject: [PATCH 19/19] update tpu worker & model_runner Signed-off-by: Juncheng Gu --- vllm/v1/worker/tpu_model_runner.py | 14 +++++--------- vllm/v1/worker/tpu_worker.py | 25 +++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 8e58f620effd..3a6e3a608864 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -952,12 +952,12 @@ def execute_model( # Update cached state self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT + if has_kv_transfer_group(): + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) + # Return empty ModelRunnerOutput if there's no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT if self.is_multimodal_model: # Run the multimodal encoder if any. @@ -1028,8 +1028,6 @@ def execute_model( # should be called right after each single forward pass, # instead of the forwards of the entire input batch. 
self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: @@ -1125,8 +1123,6 @@ def concat_lists(input_lists): logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], - finished_sending=finished_sending, - finished_recving=finished_recving, ) # Check there are no new graphs compiled - all the graphs should be diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 6d8e1afcc971..0f923965801e 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" +import copy import os from typing import Optional @@ -15,7 +16,9 @@ from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -25,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import bind_kv_cache, report_usage_stats from vllm.v1.worker.tpu_model_runner import TPUModelRunner @@ -242,6 +245,24 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) + assert isinstance(output, ModelRunnerOutput) + if has_kv_transfer_group(): + finished_sending, finished_recving = ( + get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids)) + if finished_sending or finished_recving: + if output is EMPTY_MODEL_RUNNER_OUTPUT: + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + + # Clear KVConnector state for this step. + get_kv_transfer_group().clear_connector_metadata() + + # with a connector, the scheduler expects output from all workers + return output + + # return output only from the driver worker return output if self.is_driver_worker else None def profile(self, is_start: bool = True):
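The final patch moves the finished-transfer query into TPUWorker.execute_model: after the runner returns, the worker asks the connector which requests finished sending and receiving, copies the shared empty-output sentinel before mutating it, and returns output from every rank because the scheduler aggregates across workers when a connector is active. A sketch of that post-processing step with simplified stand-ins for the vLLM output types:

import copy
from dataclasses import dataclass, field

@dataclass
class RunnerOutput:                       # stand-in for ModelRunnerOutput
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)

EMPTY_OUTPUT = RunnerOutput()             # shared sentinel; never mutate in place

def attach_finished(output: RunnerOutput, finished_sending: set[str],
                     finished_recving: set[str]) -> RunnerOutput:
    if finished_sending or finished_recving:
        if output is EMPTY_OUTPUT:
            output = copy.copy(EMPTY_OUTPUT)
        output.finished_sending = finished_sending
        output.finished_recving = finished_recving
    return output

out = attach_finished(EMPTY_OUTPUT, {"req-1"}, set())
assert out is not EMPTY_OUTPUT and EMPTY_OUTPUT.finished_sending == set()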