diff --git a/requirements/tpu.txt b/requirements/tpu.txt index db58b37c2b1..98d74cdeb39 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -10,6 +10,7 @@ jinja2>=3.1.6 ray[default] ray[data] setuptools==78.1.0 +nixl==0.3.0 # Install torch_xla --pre diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh new file mode 100644 index 00000000000..45779d16914 --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh @@ -0,0 +1,162 @@ +#!/bin/bash +set -xe + +# Hosts / ports +PREFILL_HOST=${PREFILL_HOST:-"localhost"} +PREFILL_PORT=${PREFILL_PORT:-8100} +PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} +DECODE_HOST=${DECODE_HOST:-"localhost"} +DECODE_PORT=${DECODE_PORT:-8200} +PROXY_HOST=${PROXY_HOST:-"localhost"} +PROXY_PORT=${PROXY_PORT:-8192} +BASELINE_HOST=${BASELINE_HOST:-"localhost"} +BASELINE_PORT=${BASELINE_PORT:-9290} + + +# Model to run. +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} +BLOCK_SIZE=${BLOCK_SIZE:-32} + + +# execution env +GIT_ROOT=$(git rev-parse --show-toplevel) +EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" +CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"} +CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"} + +OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + + +# Waits for vLLM server to start. +wait_for_server() { + local host=$1 + local port=$2 + timeout 1200 bash -c " + until curl -s ${host}:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 || true + # pkill -f python || true + echo "Cleanup complete. Exiting." 
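+  # NOTE: this kills every python process owned by the current user; it
+  # assumes a dedicated TPU test host. Narrow the pgrep pattern before
+  # running this script on a shared machine.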
+} + +launch_baseline() { + BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${BASELINE_HOST} \ + --port ${BASELINE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --enforce-eager" + echo ${BASELINE_BASE_CMD} + ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" & +} + +launch_pd() { + PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ + VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${PREFILL_HOST} \ + --port ${PREFILL_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + + DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${DECODE_HOST} \ + --port ${DECODE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + echo ${PREFILL_BASE_CMD} + echo ${DECODE_BASE_CMD} + sleep 2 + + # execute on hosts + ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" & + ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" & + sleep 1 + wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} + sleep 1 + wait_for_server ${DECODE_HOST} ${DECODE_PORT} + sleep 1 +} + +launch_pd_proxy(){ + PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + python3 ${EXP_ROOT}/toy_proxy_server.py \ + --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ + --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ + --host=${PROXY_HOST} --port ${PROXY_PORT}" + echo ${PROXY_BASE_CMD} + ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" & +} + +run_tests(){ + local service_url=$1 + local mode=$2 + python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE} +} + + +# run non-disagg. baseline & save outputs +launch_baseline +sleep 2 +wait_for_server ${BASELINE_HOST} ${BASELINE_PORT} +run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline" +cleanup +sleep 10 + + +# run disagg. 
& do exact-match with the outputs from baseline +launch_pd +launch_pd_proxy +sleep 10 +run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg" +echo "-----P/D success----" + +rm ${OUTPUT_FILE} +cleanup + +exit 0 \ No newline at end of file diff --git a/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh new file mode 100644 index 00000000000..c37c92fdf5d --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh @@ -0,0 +1,128 @@ +#!/bin/bash +set -xe + +# Hosts / ports +PREFILL_HOST=${PREFILL_HOST:-"localhost"} +PREFILL_PORT=${PREFILL_PORT:-8100} +PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577} +DECODE_HOST=${DECODE_HOST:-"localhost"} +DECODE_PORT=${DECODE_PORT:-8200} +PROXY_HOST=${PROXY_HOST:-"localhost"} +PROXY_PORT=${PROXY_PORT:-8192} +BASELINE_HOST=${BASELINE_HOST:-"localhost"} +BASELINE_PORT=${BASELINE_PORT:-9290} + + +# Model to run. +MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"} +MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024} +BLOCK_SIZE=${BLOCK_SIZE:-32} + + +# execution env +GIT_ROOT=$(git rev-parse --show-toplevel) +EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration" +CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"} +CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"} + +OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT + +# Waits for vLLM server to start. +wait_for_server() { + local host=$1 + local port=$2 + timeout 1200 bash -c " + until curl -s ${host}:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 || true + # pkill -f python || true + echo "Cleanup complete. Exiting." 
+} + + +launch_pd() { + PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \ + VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${PREFILL_HOST} \ + --port ${PREFILL_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + + DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + UCX_TLS=tcp \ + VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \ + VLLM_LOGGING_LEVEL=DEBUG \ + VLLM_USE_V1=1 \ + PJRT_DEVICE=TPU \ + VLLM_WORKER_MULTIPROC_METHOD=spawn \ + VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \ + --host ${DECODE_HOST} \ + --port ${DECODE_PORT} \ + --max-model-len ${MAX_MODEL_LEN}\ + --seed 42 \ + --block-size ${BLOCK_SIZE} \ + --enforce-eager \ + --gpu-memory-utilization 0.5 \ + --disable-log-requests \ + --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'" + + echo ${PREFILL_BASE_CMD} + echo ${DECODE_BASE_CMD} + sleep 2 + + # execute on hosts + ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" & + ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" & + sleep 1 + wait_for_server ${PREFILL_HOST} ${PREFILL_PORT} + sleep 1 + wait_for_server ${DECODE_HOST} ${DECODE_PORT} + sleep 1 +} + +launch_pd_proxy(){ + PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME}; + python3 ${EXP_ROOT}/toy_proxy_server.py \ + --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \ + --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \ + --host=${PROXY_HOST} --port ${PROXY_PORT}" + echo ${PROXY_BASE_CMD} + ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" & +} + + +# run disagg. & do exact-match with the outputs from baseline +launch_pd +launch_pd_proxy +sleep 10 + +PREFILL_HOST=${PREFILL_HOST} \ +PREFILL_PORT=${PREFILL_PORT} \ +DECODE_HOST=${DECODE_HOST} \ +DECODE_PORT=${DECODE_PORT} \ +PROXY_HOST=${PROXY_HOST} \ +PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py \ No newline at end of file diff --git a/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py new file mode 100644 index 00000000000..00e62f351ce --- /dev/null +++ b/tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import json +import os +import time + +import openai +import requests + +MAX_OUTPUT_LEN = 30 + +SAMPLE_PROMPTS = ( + "Red Hat is the best company in the world to work for because it works on " + "open source software, which means that all the contributions are " + "delivered to the community. 
As a result, when working on projects like "
+    "vLLM we are able to meet many amazing people from various organizations "
+    "like AMD, Google, NVIDIA, ",
+    "We hold these truths to be self-evident, that all men are created equal, "
+    "that they are endowed by their Creator with certain unalienable Rights, "
+    "that among these are Life, Liberty and the pursuit of Happiness.--That "
+    "to secure these rights, Governments are instituted among Men, deriving "
+    "their just powers from the consent of the governed, ",
+)
+
+
+def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
+    """
+    Checks if the vLLM server is ready by sending a GET request to the
+    /health endpoint.
+
+    Args:
+        url (str): The base URL of the vLLM server.
+        timeout (int): Timeout in seconds for the request.
+        retries (int): Number of retries if the server is not ready.
+
+    Returns:
+        bool: True if the server is ready, False otherwise.
+    """
+    for attempt in range(retries):
+        try:
+            response = requests.get(url, timeout=timeout)
+            if response.status_code == 200:
+                return True
+            else:
+                print(f"Attempt {attempt + 1}: Server returned status code "
+                      f"{response.status_code}")
+        except requests.exceptions.RequestException as e:
+            print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
+        time.sleep(1)  # Wait before retrying
+    return False
+
+
+def run_simple_prompt(base_url: str, model_name: str,
+                      input_prompt: str) -> str:
+    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
+    completion = client.completions.create(model=model_name,
+                                           prompt=input_prompt,
+                                           max_tokens=MAX_OUTPUT_LEN,
+                                           temperature=0.0,
+                                           seed=42)
+
+    # print("-" * 50)
+    # print(f"Completion results for {model_name}:")
+    # print(completion)
+    # print("-" * 50)
+    return completion.choices[0].text
+
+
+def main():
+    """
+    Query a vLLM endpoint (baseline server or disagg proxy) with the sample
+    prompts. In "baseline" mode the outputs are written to --file_name; in
+    "disagg" mode they are compared against the saved baseline outputs.
+    """
+    parser = argparse.ArgumentParser(description="vLLM client script")
+
+    parser.add_argument(
+        "--service_url",
+        type=str,
+        required=True,
+        help="The vLLM service URL.")
+
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        required=True,
+        help="Name of the model to query.",
+    )
+
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="baseline",
+        help="mode: baseline==non-disagg, or disagg",
+    )
+
+    parser.add_argument(
+        "--file_name",
+        type=str,
+        default=".vllm_output.txt",
+        help="File that stores the baseline outputs.",
+    )
+
+    args = parser.parse_args()
+
+    for arg in vars(args):
+        print(f"{arg}: {getattr(args, arg)}")
+
+    if args.mode == "baseline":
+        # non-disagg
+        health_check_url = f"{args.service_url}/health"
+    else:
+        # disagg proxy
+        health_check_url = f"{args.service_url}/healthcheck"
+        if not os.path.exists(args.file_name):
+            raise ValueError(
+                f"In disagg mode, the output file {args.file_name} from "
+                "non-disagg. 
baseline does not exist.") + + service_url = f"{args.service_url}/v1" + + if not check_vllm_server(health_check_url): + raise RuntimeError( + f"vllm server: {args.service_url} is not ready yet!") + + output_strs = dict() + for prompt in SAMPLE_PROMPTS: + output_str = run_simple_prompt(base_url=service_url, + model_name=args.model_name, + input_prompt=prompt) + print(f"Prompt: {prompt}, output: {output_str}") + output_strs[prompt] = output_str + + if args.mode == "baseline": + # baseline: save outputs + try: + with open(args.file_name, 'w') as json_file: + json.dump(output_strs, json_file, indent=4) + except OSError as e: + print(f"Error writing to file: {e}") + raise + else: + # disagg. verify outputs + baseline_outputs = None + try: + with open(args.file_name) as json_file: + baseline_outputs = json.load(json_file) + except OSError as e: + print(f"Error writing to file: {e}") + raise + assert isinstance(baseline_outputs, dict) + assert len(baseline_outputs) == len(output_strs) + for prompt, output in baseline_outputs.items(): + assert prompt in output_strs, f"{prompt} not included" + assert output == output_strs[prompt], ( + f"baseline_output: {output} != PD output: {output_strs[prompt]}" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py index 95465a25fc9..8439e30be15 100644 --- a/tests/v1/kv_connector/nixl_integration/test_edge_cases.py +++ b/tests/v1/kv_connector/nixl_integration/test_edge_cases.py @@ -4,8 +4,11 @@ import openai +PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost") PREFILL_PORT = os.getenv("PREFILL_PORT", None) +DECODE_HOST = os.getenv("DECODE_HOST", "localhost") DECODE_PORT = os.getenv("DECODE_PORT", None) +PROXY_HOST = os.getenv("PROXY_HOST", "localhost") PROXY_PORT = os.getenv("PROXY_PORT", None) if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None: @@ -21,15 +24,15 @@ def test_edge_cases(): # Set the OpenAI API key and base URL decode_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{DECODE_PORT}/v1", + base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1", ) prefill_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{PREFILL_PORT}/v1", + base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1", ) proxy_client = openai.OpenAI( api_key="MY_KEY", - base_url=f"http://localhost:{PROXY_PORT}/v1", + base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1", ) # Get the list of models diff --git a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py index c58cb0286f1..66e237da0f8 100644 --- a/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py @@ -3,6 +3,7 @@ import argparse import itertools +import logging import os import uuid from contextlib import asynccontextmanager @@ -11,9 +12,8 @@ from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse -from vllm.logger import init_logger - -logger = init_logger(__name__) +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) @asynccontextmanager diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 9459ab27aba..a9bea1246e5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -32,7 +32,7 @@ import enum from abc import ABC, abstractmethod 
-from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import torch @@ -46,6 +46,12 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request +# s_tensor_list, d_tensor_list, s_indices, d_indices, direction +CopyBlocksOp = Callable[[ + dict[str, torch.Tensor], dict[ + str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"] +], None] + logger = init_logger(__name__) @@ -127,6 +133,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): + """ + Set the xPU-specific ops for copying KV between host and device. + Needed when host buffer is used for kv transfer (e.g., in NixlConnector) + """ + return + @abstractmethod def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0c5986bfafa..c06cda356f5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import logging import math import queue import threading @@ -20,14 +21,14 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.platforms import _Backend +from vllm.platforms import _Backend, current_platform from vllm.utils import make_zmq_path, make_zmq_socket, round_down from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -40,6 +41,7 @@ Transfer = tuple[int, float] # (xfer_handle, start_time) EngineId = str ReqId = str + GET_META_MSG = b"get_meta_msg" logger = init_logger(__name__) @@ -52,6 +54,13 @@ logger.warning("NIXL is not available") NixlWrapper = None +# Supported xPUs and types of kv transfer buffer. +# {xPU: tuple of supported kv buffer types} +_NIXL_SUPPORTED_XPUS = { + "cuda": ("cuda", ), + "tpu": ("cpu", ), +} + class NixlAgentMetadata( msgspec.Struct, @@ -80,6 +89,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_send: dict[ReqId, float] = {} def add_new_req( @@ -87,8 +97,12 @@ def add_new_req( request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], + load_remote_cache: bool = True, + save_to_host: bool = False, ): - self.reqs_to_recv[request_id] = ReqMeta( + # save and load are mutually exclusive + assert load_remote_cache ^ save_to_host + _req = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], @@ -97,6 +111,10 @@ def add_new_req( # P workers don't need to receive tp_size from proxy here. 
tp_size=kv_transfer_params.get("tp_size", 1), ) + if save_to_host: + self.reqs_to_save[request_id] = _req + if load_remote_cache: + self.reqs_to_recv[request_id] = _req class NixlConnector(KVConnectorBase_V1): @@ -155,6 +173,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): + assert self.connector_worker is not None + self.connector_worker.set_host_xfer_buffer_ops(copy_operation) + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" @@ -177,8 +199,11 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, pass def wait_for_save(self): - """NixlConnector does not save explicitly.""" - pass + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, NixlConnectorMetadata) + if self.connector_worker.use_host_buffer and \ + self.connector_worker.copy_blocks: + self.connector_worker.save_kv_to_host(self._connector_metadata) class NixlConnectorScheduler: @@ -193,12 +218,15 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) + self.use_host_buffer = \ + vllm_config.kv_transfer_config.kv_buffer_device == "cpu" logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} @@ -248,7 +276,25 @@ def update_state_after_alloc(self, request: "Request", "num_external_tokens=%s, kv_transfer_params=%s", num_external_tokens, params) - if params is not None and params.get("do_remote_prefill"): + if not params: + return + if self.use_host_buffer and params.get("do_remote_decode"): + # NOTE: when accelerator is not directly supported by Nixl, + # prefilled blocks need to be saved to host memory before transfer. + + # figure out full computed blocks to save + block_ids = blocks.get_block_ids()[0] + all_full = request.num_tokens % self.block_size == 0 + full_block_ids = (block_ids if all_full else block_ids[:-1]) + # TODO: skip the blocks that are already in the host xfer buffer. + # Currently, the host xfer buffer block is 1-to-1 mapped to device + # kv blocks, so host blocks won't be flushed as long as its device + # block is not overwritten; and it will be safe to skip saving them + # to host xfer buffer. + if full_block_ids: + self._reqs_need_save[request.request_id] = \ + (request, full_block_ids) + elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): @@ -260,6 +306,7 @@ def update_state_after_alloc(self, request: "Request", # Get unhashed blocks to pull from remote. self._reqs_need_recv[request.request_id] = ( request, local_block_ids) + else: logger.warning( "Got invalid KVTransferParams: %s. 
This " @@ -284,10 +331,21 @@ def build_connector_meta( kv_transfer_params=req.kv_transfer_params, ) - # Clear the list once workers start the transfers - self._reqs_need_recv.clear() + for req_id, (req, block_ids) in self._reqs_need_save.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + load_remote_cache=False, + save_to_host=True, + ) meta.reqs_to_send = self._reqs_need_send + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + self._reqs_need_save.clear() self._reqs_need_send = {} return meta @@ -379,9 +437,36 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.tp_rank = get_tensor_model_parallel_rank() self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() + self.num_blocks = 0 # KV Caches and nixl tracking data. - self.kv_caches: dict[str, torch.Tensor] = {} + self.device_type = current_platform.device_type + self.kv_buffer_device: str = \ + vllm_config.kv_transfer_config.kv_buffer_device + if self.device_type not in _NIXL_SUPPORTED_XPUS: + raise RuntimeError(f"{self.device_type} is not supported.") + elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[ + self.device_type]: + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported.") + self.device_kv_caches: dict[str, torch.Tensor] = {} + + # cpu kv buffer for xfer + # used when xPU memory can not be registered under nixl + self.host_xfer_buffers: dict[str, torch.Tensor] = {} + self.use_host_buffer = self.kv_buffer_device == "cpu" + if self.kv_buffer_device == "cuda": + self.nixl_memory_type = "VRAM" + elif self.kv_buffer_device == "cpu": + self.nixl_memory_type = "DRAM" + else: + raise RuntimeError( + f"{self.device_type} with {self.kv_buffer_device} kv_buffer " + "is not supported.") + + # Note: host xfer buffer ops when use_host_buffer is True + self.copy_blocks: Optional[CopyBlocksOp] = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local # rank will still only pull from a single remote TP worker. @@ -404,6 +489,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. # [req_id -> list[handle]] + self._recving_metadata: dict[ReqId, ReqMeta] = {} self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} @@ -440,6 +526,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.backend_name = backend.get_name() attn_backend = backend_name_to_enum(self.backend_name) self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1 + self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1 logger.debug("Detected attention backend %s", self.backend_name) self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} @@ -529,6 +616,31 @@ def _nixl_handshake( # Remote rank -> agent name. 
return {p_remote_rank: remote_agent_name} + def initialize_host_xfer_buffer( + self, kv_caches: dict[str, torch.Tensor]) -> None: + """ + Initialize transfer buffer in CPU mem for accelerators + NOT directly supported by NIXL (e.g., tpu) + """ + xfer_buffers: dict[str, torch.Tensor] = {} + try: + for layer_name, kv_cache in kv_caches.items(): + kv_shape = kv_cache.shape + kv_dtype = kv_cache.dtype + xfer_buffers[layer_name] = torch.empty(kv_shape, + dtype=kv_dtype, + device="cpu") + except MemoryError as e: + logger.error("NIXLConnectorWorker gets %s.", e) + raise + + self.host_xfer_buffers = xfer_buffers + + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): + """Assign copy (d2h, h2d) operations when host buffer is used.""" + assert self.use_host_buffer + self.copy_blocks = copy_operation + def _background_nixl_handshake(self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta): # Do NIXL handshake in background and add to _ready_requests when done. @@ -562,47 +674,76 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): _, first_kv_cache = next(iter(kv_caches.items())) kv_elem_size = first_kv_cache.element_size() + if self.use_host_buffer: + self.initialize_host_xfer_buffer(kv_caches=kv_caches) + assert len(self.host_xfer_buffers) == len(kv_caches), ( + f"host_buffer: {len(self.host_xfer_buffers)}, " + f"kv_caches: {len(kv_caches)}") + xfer_buffers = self.host_xfer_buffers + else: + xfer_buffers = kv_caches + assert not self.host_xfer_buffers, ( + "host_xfer_buffer should not be initialized when " + f"kv_buffer_device is {self.kv_buffer_device}") + # TODO(tms): Find a more robust way to detect and handle MLA # NOTE (NickLucche) To move blocks efficiently with NIXL, the expected # KV memory layout is HND, as opposed to the default NHD. Note that it # will only affects the strides. For MLA instead, we make require no # such thing and resort to the standard layout. use_mla = len(first_kv_cache.shape) == 3 - assert use_mla == self.use_mla - - # TODO (NickLucche) not compatible with hybrid allocator. Enforce check - # once it goes live, as a single kv layout is expected for xfers. - if use_mla: - # MLA case. + if self.device_type == "tpu": + assert not use_mla, f"{self.kv_buffer_device} does not support MLA." + assert self._use_pallas_v1, f"attn backend: {self.backend_name}" + # tpu (v1) kv shape per layer: + # (num_blocks, block_size, num_kv_heads * 2, head_size) self.num_blocks = first_kv_cache.shape[0] - block_rank = 2 # [block_size, latent_dim] + block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] - block_size, kv_latent_dim = block_shape - self.slot_size_bytes = kv_elem_size * kv_latent_dim - else: - # [2 (k and v), num_blocks, ...] - if self._use_flashinfer: - # FlashInfer swaps 2<->num_blocks dimensions. + block_size, n_kv_heads_x_2, head_dim = block_shape + self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim + elif self.device_type == "cuda": + assert use_mla == self.use_mla + # TODO (NickLucche) not compatible with hybrid allocator. + # Enforce check once it goes live, as a single kv layout + # is expected for xfers. + if use_mla: + # MLA case. 
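+                # With MLA the cache is a single latent tensor per layer laid
+                # out as [num_blocks, block_size, latent_dim]; there is no
+                # separate K/V split, so the slot size is just the latent dim.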
self.num_blocks = first_kv_cache.shape[0] - block_rank = 4 # [2, block_size, kv_heads, head_dim] + block_rank = 2 # [block_size, latent_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, kv_latent_dim = block_shape + self.slot_size_bytes = kv_elem_size * kv_latent_dim else: - self.num_blocks = first_kv_cache.shape[1] - block_rank = 3 # [block_size, kv_heads, head_dim] - block_shape = first_kv_cache.shape[-block_rank:] - block_size, n_kv_heads, head_dim = block_shape[-3:] - # head size in bytes. - self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim - assert block_size == self.block_size + # [2 (k and v), num_blocks, ...] + if self._use_flashinfer: + # FlashInfer swaps 2<->num_blocks dimensions. + self.num_blocks = first_kv_cache.shape[0] + block_rank = 4 # [2, block_size, kv_heads, head_dim] + else: + self.num_blocks = first_kv_cache.shape[1] + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + block_size, n_kv_heads, head_dim = block_shape[-3:] + # head size in bytes. + self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim + assert block_size == self.block_size + else: + raise RuntimeError( + f"{self.device_type} ({self.backend_name}) is not supported.") + # TODO(tms): self.block_len needs to be per-layer for sliding window, # hybrid attn, etc # block size in bytes self.block_len = kv_elem_size * math.prod(block_shape) logger.info( - "Registering KV_Caches: use_mla: %s, num_blocks: %s, " - "block_shape: %s, per_layer_kv_cache_shape: %s", use_mla, - self.num_blocks, block_shape, first_kv_cache.shape) + "Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, " + "use_host_buffer: %s, num_blocks: %s, block_shape: %s, " + "per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device, + self.use_host_buffer, self.num_blocks, block_shape, + first_kv_cache.shape) self.dst_num_blocks[self.engine_id] = self.num_blocks - self.kv_caches = kv_caches + self.device_kv_caches = kv_caches kv_caches_base_addr = [] caches_data = [] @@ -614,19 +755,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # (roughly 8KB vs 5KB). # Conversely for FlashInfer, K and V are transferred in the same tensor # to better exploit the memory layout (ie num_blocks is the first dim). - for cache_or_caches in kv_caches.values(): + for cache_or_caches in xfer_buffers.values(): # Normalize to always be a list of caches - cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \ - else cache_or_caches + cache_list = [cache_or_caches] if use_mla \ + or self._use_pallas_v1 or self._use_flashinfer \ + else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len - caches_data.append( - (base_addr, region_len, cache.device.index, "")) + # NOTE: use tp_rank for device_id since multi-node TP + # is rarely used. 
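+                # Each entry is (base_addr, region_len_in_bytes, device_id, "");
+                # these tuples feed nixl_wrapper.get_reg_descs() below, e.g.
+                # get_reg_descs(caches_data, "DRAM") when kv_buffer_device="cpu".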
+ caches_data.append((base_addr, region_len, self.tp_rank, "")) kv_caches_base_addr.append(base_addr) self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr self.num_regions = len(caches_data) - self.num_layers = len(self.kv_caches.keys()) + self.num_layers = len(xfer_buffers.keys()) # TODO(mgoin): remove this once we have hybrid memory allocator # Optimization for models with local attention (Llama 4) @@ -648,7 +791,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.block_window_per_layer) assert len(self.block_window_per_layer) == self.num_layers - descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + descs = self.nixl_wrapper.get_reg_descs(caches_data, + self.nixl_memory_type) logger.debug("Registering descs: %s", caches_data) self.nixl_wrapper.register_memory(descs) logger.debug("Done registering descs") @@ -666,11 +810,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_offset = block_id * self.block_len addr = base_addr + block_offset # (addr, len, device id) + # TODO: does device_id matter to DRAM? blocks_data.append((addr, self.block_len, self.tp_rank)) logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.tp_rank) - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) # NIXL_INIT_AGENT to be used for preparations of local descs. self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) @@ -755,6 +901,8 @@ def add_remote_agent(self, tp_ratio = divide(self._tp_size[self.engine_id], self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" + assert not self._use_pallas_v1 or tp_ratio == 1, \ + "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." # Handle tp_size>num_kv_heads: replicate KV cache. total_num_kv_heads = self.model_config.get_total_num_kv_heads() @@ -813,13 +961,43 @@ def add_remote_agent(self, self.tp_rank) # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") + descs = self.nixl_wrapper.get_xfer_descs(blocks_data, + self.nixl_memory_type) self.dst_xfer_side_handles[ engine_id] = self.nixl_wrapper.prep_xfer_dlist( remote_agent_name, descs) return remote_agent_name + def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): + """copy recved kv from host buffer to device.""" + assert self.use_host_buffer + assert self.copy_blocks is not None + + local_block_ids = meta.local_block_ids + self.copy_blocks(self.host_xfer_buffers, self.device_kv_caches, + local_block_ids, local_block_ids, "h2d") + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "synced recved kv of request[%s] to device kv buffer," + "local_block_ids: %s. ", req_id, + ",".join(map(str, meta.local_block_ids))) + + def save_kv_to_host(self, metadata: NixlConnectorMetadata): + """copy kv from device to host buffer.""" + assert self.use_host_buffer + assert self.copy_blocks is not None + + for req_id, meta in metadata.reqs_to_save.items(): + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "save_load_kv for request[%s] to host xfer buffer." + "local_block_ids: %s. ", req_id, + ",".join(map(str, meta.local_block_ids))) + # blocking + self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, + meta.local_block_ids, meta.local_block_ids, "d2h") + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. 
@@ -834,6 +1012,12 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) + if self.use_host_buffer: + for req_id in done_recving: + meta = self._recving_metadata.pop(req_id) + assert meta, f"{req_id} not found in recving_metadata list" + self.sync_recved_kv_to_device(req_id, meta) + # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() while self._reqs_to_send: @@ -904,6 +1088,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) + if self.use_host_buffer: + self._recving_metadata[req_id] = meta if remote_engine_id not in self._remote_agents: # Initiate handshake with remote engine to exchange metadata. with self._handshake_lock: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c90..fb3f3625492 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -23,12 +23,10 @@ from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, - set_forward_context) +from vllm.forward_context import DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -64,6 +62,8 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from ..sample.logits_processor import LogitsProcessorManager @@ -86,7 +86,7 @@ logger = init_logger(__name__) -class GPUModelRunner(LoRAModelRunnerMixin): +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, @@ -1673,27 +1673,6 @@ def propose_draft_token_ids( spec_token_ids = draft_token_ids.tolist() return spec_token_ids - @staticmethod - def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): - # Update KVConnector with the KVConnector metadata forward(). - if has_kv_transfer_group(): - kv_connector = get_kv_transfer_group() - assert isinstance(kv_connector, KVConnectorBase_V1) - assert scheduler_output.kv_connector_metadata is not None - kv_connector.bind_connector_metadata( - scheduler_output.kv_connector_metadata) - - # Background KV cache transfers happen here. - # These transfers are designed to be async and the requests - # involved may be disjoint from the running requests. - # Do this here to save a collective_rpc. 
- kv_connector.start_load_kv(get_forward_context()) - - @staticmethod - def maybe_wait_for_kv_save() -> None: - if has_kv_transfer_group(): - get_kv_transfer_group().wait_for_save() - def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py new file mode 100644 index 00000000000..850c8a97d72 --- /dev/null +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Define KV connector functionality mixin for model runners. +""" +from typing import TYPE_CHECKING + +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + +logger = init_logger(__name__) + + +# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) +class KVConnectorModelRunnerMixin: + + @staticmethod + def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): + # Update KVConnector with the KVConnector metadata forward(). + if has_kv_transfer_group(): + kv_connector = get_kv_transfer_group() + assert isinstance(kv_connector, KVConnectorBase_V1) + assert scheduler_output.kv_connector_metadata is not None + kv_connector.bind_connector_metadata( + scheduler_output.kv_connector_metadata) + + # Background KV cache transfers happen here. + # These transfers are designed to be async and the requests + # involved may be disjoint from the running requests. + # Do this here to save a collective_rpc. + kv_connector.start_load_kv(get_forward_context()) + + @staticmethod + def maybe_wait_for_kv_save() -> None: + if has_kv_transfer_group(): + get_kv_transfer_group().wait_for_save() \ No newline at end of file diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index ad62d204381..84775658139 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Any, Optional, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from unittest.mock import patch import numpy as np @@ -20,6 +20,8 @@ from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (ParallelConfig, VllmConfig, get_layers_from_vllm_config, update_config) +from vllm.distributed.kv_transfer import (get_kv_transfer_group, + has_kv_transfer_group) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA @@ -43,6 +45,8 @@ LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler +from vllm.v1.worker.kv_connector_model_runner_mixin import ( + KVConnectorModelRunnerMixin) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch @@ -94,7 +98,7 @@ # The dummy_run should be comprehensive, ensuring all potential input shapes and # branch predictions are included as subgraph inputs to facilitate # pre-compilation. 
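+# NOTE: KVConnectorModelRunnerMixin (shared with GPUModelRunner) provides
+# maybe_setup_kv_connector() / maybe_wait_for_kv_save(), which execute_model()
+# below uses to bind per-step connector metadata and flush KV saves.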
-class TPUModelRunner(LoRAModelRunnerMixin): +class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, @@ -959,6 +963,10 @@ def execute_model( # Update cached state self._update_states(scheduler_output) if not scheduler_output.total_num_scheduled_tokens: + if has_kv_transfer_group(): + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + # Return empty ModelRunnerOutput if there's no work to do. return EMPTY_MODEL_RUNNER_OUTPUT @@ -974,6 +982,12 @@ def execute_model( start_index = 0 combined_selected_tokens: list[torch.Tensor] = [] combined_logprobs: list[LogprobsLists] = [] + + # NOTE: setup current batch's metadata for kv connector. + # Currently, only verified with NixlConnector + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + while start_index < self.input_batch.num_reqs: attn_metadata, logits_indices, padded_num_reqs, num_reqs,\ end_index = self._prepare_inputs(scheduler_output, start_index) @@ -1020,6 +1034,12 @@ def execute_model( start_index = end_index + # NOTE: current kv load and save get h2d/d2h copies involved. + # Those copies are blocking. Once they become async., kv_save + # should be called right after each single forward pass, + # instead of the forwards of the entire input batch. + self.maybe_wait_for_kv_save() + selected_token_ids = torch.cat(combined_selected_tokens, dim=0) if tpu_sampling_metadata.logprobs: @@ -1624,6 +1644,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: for cache in self.kv_caches: xs.mark_sharding(cache, self.mesh, (None, 'x', None, None)) + if has_kv_transfer_group(): + get_kv_transfer_group().register_kv_caches(kv_caches) + get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) + def reset_dynamo_cache(self): if self.is_multimodal_model: compiled_model = self.model.get_language_model().model @@ -1838,6 +1862,75 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _make_src_and_dst_indices( + src_block_ids: list[int], + dst_block_ids: list[int], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> tuple[torch.Tensor, torch.Tensor]: + src_indices = torch.tensor(src_block_ids, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_block_ids, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +@torch.compile(backend="openxla") +def _insert_blocks_to_tpu( + cpu_cache: torch.Tensor, + tpu_cache: torch.Tensor, + cpu_block_indices: torch.Tensor, + tpu_block_indices: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) + tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( + tpu_cache.device) + + +@torch.compile(backend="openxla") +def _swap_out_tpu_blocks( + tpu_cache: torch.Tensor, + cpu_cache: torch.Tensor, + tpu_block_indices: torch.Tensor, + cpu_block_indices: torch.Tensor, +) -> None: + """ tpu blocks to cpu blocks""" + torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) + cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() + + +def copy_kv_blocks( + src_kv_caches: dict[str, torch.Tensor], + dst_kv_caches: dict[str, torch.Tensor], + src_block_ids: list[int], + dst_block_ids: list[int], + direction: Literal["h2d", "d2h"], +) -> None: + """Copy kv blocks between different buffers.""" + if not src_kv_caches or not dst_kv_caches or \ + not src_block_ids or not dst_block_ids or \ + len(src_block_ids) 
!= len(dst_block_ids): + return + + src_device = next(iter(src_kv_caches.values())).device + dst_device = next(iter(dst_kv_caches.values())).device + + src_indices, dst_indices = _make_src_and_dst_indices( + src_block_ids=src_block_ids, + dst_block_ids=dst_block_ids, + src_device=src_device, + dst_device=dst_device) + + _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \ + _swap_out_tpu_blocks + for layer_name in src_kv_caches: + src_tensor = src_kv_caches[layer_name] + dst_tensor = dst_kv_caches[layer_name] + _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) + + def _get_padded_num_kv_cache_update_slices( num_tokens: int, max_num_reqs: int, page_size: int, num_slices_per_kv_cache_update_block: int) -> int: diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index c4bf40d6654..6f6c7773273 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" +import copy import os from typing import Any, Optional @@ -12,9 +13,12 @@ import torch_xla.runtime as xr import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -24,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.utils import bind_kv_cache @@ -118,7 +122,7 @@ def init_device(self): # Initialize the distributed environment. self._init_tpu_worker_distributed_environment( - self.parallel_config, self.rank, self.distributed_init_method, + self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Device initialization should happen after initializing @@ -242,6 +246,24 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> Optional[ModelRunnerOutput]: output = self.model_runner.execute_model(scheduler_output) + assert isinstance(output, ModelRunnerOutput) + if has_kv_transfer_group(): + finished_sending, finished_recving = ( + get_kv_transfer_group().get_finished( + scheduler_output.finished_req_ids)) + if finished_sending or finished_recving: + if output is EMPTY_MODEL_RUNNER_OUTPUT: + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + + # Clear KVConnector state for this step. 
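+            # The metadata was bound in the model runner via
+            # maybe_setup_kv_connector(); clearing it here keeps stale
+            # per-request state from leaking into the next step.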
+ get_kv_transfer_group().clear_connector_metadata() + + # with a connector, the scheduler expects output from all workers + return output + + # return output only from the driver worker return output if self.is_driver_worker else None def profile(self, is_start: bool = True): @@ -288,7 +310,7 @@ def check_health(self) -> None: def _init_tpu_worker_distributed_environment( self, - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, @@ -300,6 +322,7 @@ def _init_tpu_worker_distributed_environment( # the input objects on CPU. The all-reduce and all-gather ops on TPU # are invoked by `xm.all_reduce` and `xm.all_gather` which use their # own context. + parallel_config = vllm_config.parallel_config init_distributed_environment( world_size=parallel_config.world_size, rank=rank, @@ -311,6 +334,8 @@ def _init_tpu_worker_distributed_environment( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + try: from tpu_commons.worker import TPUWorker as TPUCommonsWorker