From efbd79129ec4e52ce7038f4448395ee09649df04 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 2 Jun 2025 13:56:44 -0400 Subject: [PATCH 01/33] initial noodling Signed-off-by: Will Eaton --- .gitignore | 1 + vllm/config.py | 10 +- .../kv_transfer/kv_connector/v1/base.py | 75 ++++++- .../kv_connector/v1/nixl_connector.py | 203 ++++++++---------- vllm/engine/llm_engine.py | 1 + vllm/entrypoints/openai/api_server.py | 5 + vllm/executor/uniproc_executor.py | 4 + vllm/utils.py | 86 +++++++- vllm/v1/core/sched/scheduler.py | 2 + vllm/v1/engine/core.py | 42 +++- vllm/v1/executor/abstract.py | 6 +- vllm/v1/utils.py | 10 + vllm/v1/worker/gpu_worker.py | 38 +++- 13 files changed, 342 insertions(+), 141 deletions(-) diff --git a/.gitignore b/.gitignore index e49d1d6ba61..45ad4584f4b 100644 --- a/.gitignore +++ b/.gitignore @@ -202,3 +202,4 @@ shellcheck*/ # Ingore moe/marlin_moe gen code csrc/moe/marlin_moe_wna16/kernel_* +uv.lock diff --git a/vllm/config.py b/vllm/config.py index 32ef83a1866..6536afa007d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,7 @@ import textwrap import uuid import warnings -from collections import Counter +from collections import Counter, defaultdict from contextlib import contextmanager from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) @@ -33,6 +33,8 @@ import vllm.envs as envs from vllm import version from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata) from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, QuantizationMethods, @@ -1511,6 +1513,12 @@ class CacheConfig: num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" + transfer_handshake_metadata: dict[int, dict[int, + KVConnectorHandshakeMetadata]] = field( + default_factory=lambda: defaultdict(dict), + init=False) + """Metadata for the KV connector handshake.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f80b5eba235..7b9b2ca1036 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -32,9 +32,11 @@ import enum from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional import torch +import msgspec +from pydantic_core import core_schema from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -62,18 +64,58 @@ class KVConnectorMetadata: Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. """ - pass + + def __init__(self): + pass + +class KVConnectorHandshakeMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + """ + Metadata optionally used for out of band connector handshake between P/D workers. 
+ """ + connector_type: str = "base" + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: Callable[[Any], core_schema.CoreSchema] + ) -> core_schema.CoreSchema: + """bridge msgspec.Struct with pydantic for schema generation""" + return core_schema.no_info_after_validator_function( + cls, + core_schema.dict_schema() + ) + +class KVConnectorTransferMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + dict=True): + """ + Wrapper for transfer handshake metadata sent between engine and utils. + """ + tensor_parallel_rank: int + data_parallel_rank: int + content: Optional[dict] + class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__(self, + vllm_config: "VllmConfig", + role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design.") self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role + self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None + @property def role(self) -> KVConnectorRole: @@ -104,7 +146,7 @@ def clear_connector_metadata(self) -> None: """ self._connector_metadata = KVConnectorMetadata() - def _get_connector_metadata(self) -> KVConnectorMetadata: + def get_connector_metadata(self) -> KVConnectorMetadata: """Get the connector metadata. This function should only be called inside the connector. @@ -201,6 +243,31 @@ def get_finished( """ return None, None + def set_handshake_metadata( + self, handshake_metadata: KVConnectorHandshakeMetadata) -> None: + """ + Set the handshake metadata for the connector. + + This metadata is used for out-of-band connector handshake + between P/D workers. + + Args: + handshake_metadata (KVConnectorHandshakeMetadata): the handshake + metadata. + """ + self._handshake_metadata = handshake_metadata + + + def get_handshake_metadata( + self) -> Optional[KVConnectorHandshakeMetadata]: + """ + Get the handshake metadata for the connector. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. 
+ """ + return self._handshake_metadata + # ============================== # Scheduler-side methods # ============================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7552fc889f2..c1211ab03ca 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,30 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib +import asyncio import math import threading import time import uuid from collections import defaultdict -from collections.abc import Iterator from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional - +import json +import base64 +import aiohttp import msgspec import torch -import zmq from vllm import envs from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata, + KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.logger import init_logger from vllm.platforms import _Backend -from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.utils import build_uri, round_down + from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -48,11 +50,8 @@ NixlWrapper = None -class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. 
- dict=True): +class NixlAgentMetadata(KVConnectorHandshakeMetadata, kw_only=True): + connector_type: str = "nixl" engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] @@ -94,6 +93,8 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config, role) + assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id @@ -146,6 +147,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) + # Set handshake metadata using the base class method + if hasattr(self.connector_worker, 'xfer_metadata'): + self.set_handshake_metadata(self.connector_worker.xfer_metadata) + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" @@ -171,6 +176,12 @@ def wait_for_save(self): """NixlConnector does not save explicitly.""" pass + def set_handshake_metadata(self, handshake_metadata): + logger.debug("Setting handshake metadata for NIXL connector: %s", + handshake_metadata) + assert self.connector_worker is not None + return super().set_handshake_metadata(handshake_metadata) + class NixlConnectorScheduler: """Implementation of Scheduler side methods""" @@ -179,11 +190,14 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size self.engine_id = engine_id + + logger.debug("Block size for NIXL connector: %s", self.block_size) + + # FIXME: This is a temporary fix to get the side channel host and port. self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - self.side_channel_port = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank_local * - vllm_config.parallel_config.tensor_parallel_size) + # This needs to be the same port that the VLLM webserver is running on. + self.side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + logger.info("Initializing NIXL Scheduler %s", engine_id) # Requests that need to start recv. @@ -333,15 +347,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict) - # NIXL handshake port. - # NOTE(rob): Within a DP group, each DP rank gets its own - # base port (which is sent in the KVTransferParams). - # Each TP rank listens/queries on the base_port + tp_rank. - self.side_channel_port = ( - envs.VLLM_NIXL_SIDE_CHANNEL_PORT + - vllm_config.parallel_config.data_parallel_rank_local * - vllm_config.parallel_config.tensor_parallel_size) - # Metadata. self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() @@ -412,75 +417,63 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[str, int](int) - @staticmethod - def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, base_port: int, - tp_rank: int): - """Background thread for getting new NIXL handshakes.""" - # NOTE(rob): this is a simple implementation. We will move - # to a better approach via HTTP endpoint soon. 
- - encoder = msgspec.msgpack.Encoder() - encoded_data = encoder.encode(metadata) - size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", - str(size_in_bytes)) - - # Listen for new requests for metadata. - host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - path = make_zmq_path("tcp", host, base_port + tp_rank) - logger.debug("Starting listening on path: %s", path) - with zmq_ctx(zmq.ROUTER, path) as sock: - ready_event.set() - while True: - identity, _, msg = sock.recv_multipart() - if msg != GET_META_MSG: - logger.warning( - "Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data)) - - def _nixl_handshake(self, host: str, port: int): + async def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() - # NOTE(rob): we need each rank to have a unique port. This is - # a hack to keep us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - - def handshake(path: str, rank: int) -> NixlAgentMetadata: - # Send query for the request. - with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. - self.add_remote_agent(metadata, rank) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) - return metadata - - # Handshake with remote agent-rank0 first to get the tp_size of remote - path = make_zmq_path("tcp", host, port) - logger.debug("Querying master rank metadata on path: %s", path) - metadata = handshake(path, 0) + url = build_uri(host, port, path="get_kv_connector_metadata") + logger.debug("Querying metadata on path: %s", url) + + timeout = aiohttp.ClientTimeout(total=30.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as response: + res = await response.json() + logger.debug("NIXL handshake response: %s", res) + + + remote_tp_size = len(res.keys()) + # Default case is that the remote TP size is 1, so we can + # directly access the metadata. + tp_data = res.get(str(self.tp_rank), {}).get("0", {}) + metadata_bytes = tp_data.get("agent_metadata", None) # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. 
- tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio if p_remote_rank > 0: - path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug("Querying metadata on path: %s at remote rank %s", - path, p_remote_rank) - _ = handshake(path, p_remote_rank) + metadata_bytes = res.get(str(p_remote_rank), {}).get("0", {}).get( + "agent_metadata", None) + + if metadata_bytes is not None: + # Reconstruct NixlAgentMetadata from JSON response + # agent_metadata is base64-encoded binary data, not msgpack + metadata = NixlAgentMetadata( + engine_id=tp_data["engine_id"], + agent_metadata=base64.b64decode(metadata_bytes), + kv_caches_base_addr=tp_data["kv_caches_base_addr"], + num_blocks=tp_data["num_blocks"], + tp_size=tp_data["tp_size"], + block_len=tp_data["block_len"], + attn_backend_name=tp_data["attn_backend_name"], + ) + + # Register Remote agent. + self.add_remote_agent(metadata, p_remote_rank) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + time.perf_counter() - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - (time.perf_counter() - start_time)) + else: + # If metadata_bytes is None, it means the remote agent + # is not using NIXL, so we can skip the handshake. + logger.warning( + "Received None metadata from %s:%s, skipping NIXL handshake", + host, port) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -558,6 +551,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Optimization for models with local attention (Llama 4) if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, Llama4TextConfig) llama4_config = self.vllm_config.model_config.hf_text_config @@ -601,23 +595,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) - # After KV Caches registered, listen for new connections. - metadata = NixlAgentMetadata( + # Store metadata on worker instance for main connector access + self.xfer_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, tp_size=self.world_size, block_len=self.block_len, - attn_backend_name=self.backend_name) - ready_event = threading.Event() - self._nixl_handshake_listener_t = threading.Thread( - target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank), - daemon=True, - name="nixl_handshake_listener") - self._nixl_handshake_listener_t.start() - ready_event.wait() + attn_backend_name=self.backend_name + ) + + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, @@ -756,6 +745,9 @@ def get_finished(self) -> tuple[set[str], set[str]]: to Rank 0 once their transaction is done + Rank 0 returns finished sets to Scheduler only once all ranks are done. """ + # Process any completed handshakes first + self._process_completed_handshakes() + done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) if len(done_sending) > 0 or len(done_recving) > 0: @@ -882,7 +874,7 @@ def _read_blocks( ): # NOTE(rob): this takes ~2s. 
We need to get this off the hotpath. if dst_engine_id not in self._remote_agents: - self._nixl_handshake(remote_host, remote_port) + asyncio.run(self._nixl_handshake(remote_host, remote_port)) # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). @@ -1009,22 +1001,3 @@ def _get_block_descs_ids(self, for block_id in block_ids: descs_ids.append(reg_id * num_blocks + block_id) return descs_ids - - -@contextlib.contextmanager -def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: - """Context manager for a ZMQ socket""" - - if socket_type not in (zmq.ROUTER, zmq.REQ): - raise ValueError(f"Unexpected socket type: {socket_type}") - - ctx: Optional[zmq.Context] = None - try: - ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) - finally: - if ctx is not None: - ctx.destroy(linger=0) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8fccf9bd2aa..dced747fae9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,6 +16,7 @@ import torch from typing_extensions import TypeVar +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorHandshakeMetadata import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 62f1c6a7c12..7a20eb68468 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -868,6 +868,11 @@ async def show_server_info(raw_request: Request): server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} return JSONResponse(content=server_info) + @router.get("/get_kv_connector_metadata") + async def get_kv_connector_metadata(raw_request: Request): + kv_connector_metadata = raw_request.app.state.vllm_config.cache_config.transfer_handshake_metadata + return JSONResponse(content=kv_connector_metadata) + @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): """ diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7ebeb4a2255..94d1f597bb2 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -57,6 +57,10 @@ def collective_rpc(self, answer = run_method(self.driver_worker, method, args, kwargs) return [answer] + def get_kv_connector_handshake_metadata(self) -> List[Optional[Dict]]: + """Get KV connector handshake metadata from all workers.""" + return self.collective_rpc("get_kv_connector_handshake_metadata") + def check_health(self) -> None: # UniProcExecutor will always be healthy as long as # it's running. 
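The new `/get_kv_connector_metadata` route added to api_server.py above simply serves the per-rank handshake metadata stashed in `CacheConfig.transfer_handshake_metadata`. As a minimal sketch (not part of the patch, helper name is illustrative), a remote worker could fetch it over plain HTTP and get back the `{tp_rank: {dp_rank: metadata}}` layout produced by the worker-side changes later in this series:

    import json
    from urllib.request import urlopen

    def fetch_kv_connector_metadata(host: str, port: int, timeout: float = 5.0) -> dict:
        # Path matches the route registered in api_server.py above.
        url = f"http://{host}:{port}/get_kv_connector_metadata"
        with urlopen(url, timeout=timeout) as resp:
            # JSON body is {tp_rank: {dp_rank: handshake_metadata}}.
            return json.loads(resp.read().decode("utf-8"))

The NIXL connector's `_nixl_handshake` in the later hunks performs essentially this request before reconstructing `NixlAgentMetadata` from the response.
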
diff --git a/vllm/utils.py b/vllm/utils.py index 342241d0dd8..1474bfc3a43 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -33,21 +33,48 @@ import uuid import warnings import weakref -from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, - ArgumentTypeError, RawDescriptionHelpFormatter, - _ArgumentGroup) +from argparse import ( + Action, + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + RawDescriptionHelpFormatter, + _ArgumentGroup, +) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, - Hashable, Iterable, Iterator, KeysView, Mapping) +from collections.abc import ( + AsyncGenerator, + Awaitable, + Collection, + Generator, + Hashable, + Iterable, + Iterator, + KeysView, + Mapping, +) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from types import MappingProxyType -from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, Sequence, Tuple, Type, TypeVar, Union, cast, - overload) -from urllib.parse import urlparse +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from urllib.parse import urlparse, urlunparse from uuid import uuid4 import cachetools @@ -67,6 +94,7 @@ from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs + # NOTE: import triton_utils to make TritonPlaceholderModule work # if triton is unavailable import vllm.triton_utils # noqa: F401 @@ -2917,3 +2945,43 @@ def is_torch_equal_or_newer(target: str) -> bool: except Exception: # Fallback to PKG-INFO to load the package info, needed by the doc gen. return Version(importlib.metadata.version('torch')) >= Version(target) + + +def build_uri( + scheme: str, + host: str, + port: Optional[int] = None, + path: str = "", + params: str = "", + query: str = "", + fragment: str = "" +) -> str: + """ + Robustly build a URI that properly handles IPv6 addresses. 
+ + Args: + scheme: URI scheme (e.g., 'http', 'https') + host: hostname or IP address + port: port number (optional) + path: path component + params: parameters component + query: query string + fragment: fragment identifier + + Returns: + Complete URI string + """ + # Handle IPv6 addresses + if host: + try: + # Check if it's an IPv6 address + ip = ipaddress.ip_address(host) + # Ensure IPv6 addresses are bracketed + if (isinstance(ip, ipaddress.IPv6Address) and + not (host.startswith('[') and host.endswith(']'))): + host = f'[{host}]' + except ValueError: + pass + + netloc = f"{host}:{port}" if port else host + return urlunparse((scheme, netloc, path, params, query, fragment)) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3d7bbe7e0e3..4611c7bfe17 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -14,6 +14,8 @@ KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f36a491a197..5610d1fed2b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -79,12 +79,15 @@ def __init__(self, executor_fail_callback) # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ - self._initialize_kv_caches(vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config, \ + transfer_handshake_metadata = self._initialize_kv_caches(vllm_config) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + # Store KV connector metadata for handshake + self.transfer_handshake_metadata = transfer_handshake_metadata + self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. @@ -130,7 +133,8 @@ def __init__(self, self.batch_queue = queue.Queue(self.batch_queue_size) def _initialize_kv_caches( - self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + self, vllm_config: VllmConfig + ) -> tuple[int, int, KVCacheConfig, Optional[list[Optional[dict]]]]: start = time.time() # Get all kv cache needed by the model @@ -167,10 +171,15 @@ def _initialize_kv_caches( # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) + # Collect KV connector xfer metadata from workers (after KV cache registration) + transfer_handshake_metadata = ( + self.model_executor.get_kv_connector_handshake_metadata()) + elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) - return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + return (num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config, + transfer_handshake_metadata) def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" @@ -432,12 +441,25 @@ def _perform_handshake( # Send ready message. 
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - handshake_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": on_head_node, - "num_gpu_blocks": num_gpu_blocks, - })) + handshake_message = { + "status": "READY", + "local": on_head_node, + "num_gpu_blocks": num_gpu_blocks, + } + + # Include KV connector metadata if available + if hasattr(self, + 'transfer_handshake_metadata') and self.transfer_handshake_metadata: + # self.transfer_handshake_metadata is a list of dicts from workers + # Each dict already has structure {tp_rank: {dp_rank: metadata}} + # Merge all worker dicts into a single dict + content = {} + for worker_dict in self.transfer_handshake_metadata: + if worker_dict is not None: + content.update(worker_dict) + handshake_message["transfer_handshake_metadata"] = content + + handshake_socket.send(msgspec.msgpack.encode(handshake_message)) @staticmethod def startup_handshake( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49e..8df5dbec580 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -80,6 +80,10 @@ def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: output = self.collective_rpc("get_kv_cache_spec") return output + def get_kv_connector_handshake_metadata(self) -> list[Optional[dict[int, dict[int, dict]]]]: + output = self.collective_rpc("get_kv_connector_handshake_metadata") + return output + def execute_model( self, scheduler_output, diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 192c9067740..705d62404bf 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -536,6 +536,16 @@ def wait_for_engine_startup( num_gpu_blocks = cache_config.num_gpu_blocks or 0 num_gpu_blocks += msg["num_gpu_blocks"] cache_config.num_gpu_blocks = num_gpu_blocks + # stash KV connector metadata in vllm_config if passed in. 
+ if "transfer_handshake_metadata" in msg and msg["transfer_handshake_metadata"]: + logger.debug( + "Received transfer handshake metadata from engine %s: %s", + eng_index, msg["transfer_handshake_metadata"]) + # Merge the received metadata with existing cache config + for tp_rank, dp_dict in msg["transfer_handshake_metadata"].items(): + for dp_rank, metadata in dp_dict.items(): + cache_config.transfer_handshake_metadata[tp_rank][ + dp_rank] = metadata start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index b7d244f2704..805d4bf1483 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,17 +5,22 @@ import os from typing import TYPE_CHECKING, Optional +import msgspec import torch import torch.distributed import torch.nn as nn +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 import vllm.envs as envs from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + is_v1_kv_transfer_group, + get_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -230,6 +235,37 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) + def get_kv_connector_handshake_metadata(self) -> Optional[dict]: + """Get KV connector metadata from this worker if available.""" + + connector = get_kv_transfer_group() + if not is_v1_kv_transfer_group(connector): + logger.warning( + "The KV connector is not a v1 connector. " + "This method is only supported for v1 connectors.") + return None + + # Only return metadata if this is a worker role + if connector.role == KVConnectorRole.WORKER: + metadata = connector.get_handshake_metadata() + if metadata is None: + logger.warning( + "KV connector metadata is not available. " + "This may happen if the KV connector is not initialized " + "or the worker is not part of a disaggregated KV cache setup." 
+ ) + return None + + tp_rank = get_tp_group().rank_in_group + dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + return { + tp_rank: { + dp_rank: msgspec.to_builtins(metadata) + } + } + + return None + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From 9dd2c1c0f4d06ace6fa762905ba1132e885b8175 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 10 Jun 2025 16:21:02 -0400 Subject: [PATCH 02/33] attempt to background agent registration Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 188 +++++++++++++++--- vllm/v1/core/sched/scheduler.py | 3 - 2 files changed, 161 insertions(+), 30 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c1211ab03ca..08554f03777 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 +<<<<<<< HEAD # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +======= +import base64 +import json +>>>>>>> a1eaf5a5e (attempt to background agent registration) import math import threading import time @@ -8,10 +13,9 @@ from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional -import json -import base64 -import aiohttp -import msgspec +from urllib.request import Request, urlopen +from urllib.error import URLError, HTTPError +from urllib.parse import urljoin import torch from vllm import envs @@ -379,6 +383,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # [req_id -> list[handle]] self._recving_transfers = defaultdict[str, list[Transfer]](list) + + # Pending requests waiting for handshake completion + # [engine_id -> list[(req_id, meta)]] + self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {} + + # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. # [req_id -> count] @@ -387,8 +397,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._done_sending_count: defaultdict[str, int] = defaultdict(lambda: 0) - # Background thread for establishing new connections. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None + # Background handshake threads for remote engines + self._handshake_threads: dict[str, threading.Thread] = {} + + # Thread results for handshake completion tracking + self._handshake_results: dict[str, bool] = {} self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -417,19 +430,45 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. 
self.consumer_notification_counts_by_req = defaultdict[str, int](int) - async def _nixl_handshake(self, host: str, port: int): + def _run_handshake_in_thread(self, engine_id: str, host: str, port: int): + """Run handshake in background thread.""" + + def handshake_worker(): + logger.debug("Starting handshake worker for engine %s", engine_id) + try: + self._nixl_handshake(host, port) + self._handshake_results[engine_id] = True + logger.debug("Handshake succeeded for engine %s", engine_id) + except Exception as e: + self._handshake_results[engine_id] = False + logger.warning("Handshake failed for engine %s: %s", engine_id, e) + finally: + logger.debug("Handshake worker finished for engine %s", engine_id) + + thread = threading.Thread(target=handshake_worker, daemon=True) + thread._start_time = time.time() # track when thread started + self._handshake_threads[engine_id] = thread + thread.start() + return thread + + def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() - url = build_uri(host, port, path="get_kv_connector_metadata") + # TODO: make the scheme dynamic, and/or implement https on both sides. + url = build_uri("http", host, port, path="get_kv_connector_metadata") logger.debug("Querying metadata on path: %s", url) - timeout = aiohttp.ClientTimeout(total=30.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url) as response: - res = await response.json() + try: + req = Request(url) + with urlopen(req, timeout=5.0) as response: + response_data = response.read().decode('utf-8') + res = json.loads(response_data) logger.debug("NIXL handshake response: %s", res) + except (URLError, HTTPError) as e: + logger.error("Failed to fetch metadata from %s: %s", url, e) + raise remote_tp_size = len(res.keys()) @@ -460,13 +499,18 @@ async def _nixl_handshake(self, host: str, port: int): ) # Register Remote agent. - self.add_remote_agent(metadata, p_remote_rank) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - time.perf_counter() - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - (time.perf_counter() - start_time)) + logger.debug("About to register remote agent for engine %s", + metadata.engine_id) + pre_register = time.perf_counter() + self.add_remote_agent(metadata, remote_tp_rank=p_remote_rank) + agent_time = time.perf_counter() + logger.debug("Finished registering remote agent for engine %s", + metadata.engine_id) + + logger.debug("NIXL handshake: get metadata took: %s", + pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", + agent_time - pre_register) else: # If metadata_bytes is None, it means the remote agent # is not using NIXL, so we can skip the handshake. 
@@ -755,6 +799,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: "Rank %s, get_finished: %s requests done sending " "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) + if self.world_size == 1: return done_sending, done_recving @@ -843,39 +888,128 @@ def _pop_done_transfers( xfer_state) return done_req_ids + def _process_completed_handshakes(self): + """Process completed handshakes and mark remote agents as ready.""" + + # debug: log current state + if self._handshake_threads: + logger.debug("Processing handshakes: %d active threads, %d pending", + len(self._handshake_threads), + sum(len(reqs) for reqs in self._pending_requests.values())) + + completed_engines = [] + for engine_id, thread in list(self._handshake_threads.items()): + logger.debug("Checking handshake thread for engine %s: alive=%s", + engine_id, thread.is_alive()) + + # check for timeout (threads running > 30 seconds) + thread_age = time.time() - getattr(thread, '_start_time', time.time()) + if thread.is_alive() and thread_age > 30.0: + logger.warning("Handshake thread for %s running %.1fs (hung?)", + engine_id, thread_age) + + if not thread.is_alive(): + logger.debug("Handshake completed for engine %s", engine_id) + completed_engines.append(engine_id) + + success = self._handshake_results.get(engine_id, False) + logger.debug("Handshake result for engine %s: success=%s", + engine_id, success) + if not success: + logger.warning("Handshake failed for engine %s", engine_id) + continue + + logger.debug("Handshake succeeded for engine %s", engine_id) + if engine_id in self._pending_requests: + pending_reqs = self._pending_requests[engine_id] + logger.debug( + "Handshake completed for %s, clearing %d pending requests " + "(will retry naturally on next start_load_kv)", + engine_id, len(pending_reqs)) + + # clear pending requests - they'll be retried naturally + # by the event loop on the next start_load_kv() call + del self._pending_requests[engine_id] + + for engine_id in completed_engines: + logger.debug("Cleaning up handshake thread for engine %s", + engine_id) + del self._handshake_threads[engine_id] + if engine_id in self._handshake_results: + del self._handshake_results[engine_id] + + def _is_request_pending_handshake(self, req_id: str) -> bool: + """Check if request is still pending handshake completion.""" + for engine_requests in self._pending_requests.values(): + for pending_req_id, _ in engine_requests: + if pending_req_id == req_id: + return True + return False + def start_load_kv(self, metadata: NixlConnectorMetadata): """ Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ + logger.debug("start_load_kv called with %d requests", len(metadata.requests)) for req_id, meta in metadata.requests.items(): + if (req_id in self._recving_transfers or + self._is_request_pending_handshake(req_id)): + logger.debug( + "Request %s already being processed, skipping", req_id) + continue + logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) + + + if meta.remote_engine_id not in self._remote_agents: + logger.debug( + "Remote engine %s not registered for request %s, " + "starting handshake and deferring transfer", + meta.remote_engine_id, req_id) + + if meta.remote_engine_id not in self._handshake_threads: + logger.debug( + "Starting handshake thread for remote engine %s", + meta.remote_engine_id) + self._run_handshake_in_thread( + meta.remote_engine_id, meta.remote_host, + meta.remote_port) + else: + logger.debug( + "Handshake thread already exists for remote engine %s", + meta.remote_engine_id) + + if meta.remote_engine_id not in self._pending_requests: + self._pending_requests[meta.remote_engine_id] = [] + self._pending_requests[meta.remote_engine_id].append( + (req_id, meta)) + + logger.debug( + "Request %s marked as pending handshake for engine %s", + req_id, meta.remote_engine_id) + continue + + logger.debug("Remote agent available for %s, calling _read_blocks", + meta.remote_engine_id) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, local_block_ids=meta.local_block_ids, remote_block_ids=meta.remote_block_ids, - remote_host=meta.remote_host, - remote_port=meta.remote_port, ) def _read_blocks( self, local_block_ids: list[int], remote_block_ids: list[int], - remote_host: str, - remote_port: int, dst_engine_id: str, request_id: str, ): - # NOTE(rob): this takes ~2s. We need to get this off the hotpath. - if dst_engine_id not in self._remote_agents: - asyncio.run(self._nixl_handshake(remote_host, remote_port)) - # NOTE(rob): having the staging blocks be on the READER side is # not going to work well (since we will have to call rearrange tensors). # after we detect the txn is complete (which means we cannot make the diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4611c7bfe17..f7e232fe03c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -331,9 +331,6 @@ def schedule(self) -> SchedulerOutput: if is_ready: request.status = RequestStatus.WAITING else: - logger.debug( - "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) self.waiting.popleft() skipped_waiting_requests.appendleft(request) continue From a0c03cd2da161be989697a2c7d2f6fd61989526b Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 11 Jun 2025 09:48:12 -0400 Subject: [PATCH 03/33] more simple immediate retry Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 08554f03777..00dbbb24b6f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -923,12 +923,22 @@ def _process_completed_handshakes(self): if engine_id in self._pending_requests: pending_reqs = self._pending_requests[engine_id] logger.debug( - "Handshake completed for %s, clearing %d pending requests " - "(will retry naturally on next start_load_kv)", + "Handshake completed for %s, immediately retrying %d pending requests", engine_id, len(pending_reqs)) - # clear pending requests - they'll be retried naturally - # by the event loop on the next start_load_kv() call + for req_id, meta in pending_reqs: + logger.debug("Immediately retrying request %s for engine %s", + req_id, engine_id) + 
try: + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + ) + except Exception as e: + logger.error("Failed to retry request %s: %s", req_id, e) + del self._pending_requests[engine_id] for engine_id in completed_engines: From cf616b51f279605e62eb0ec788410e51d902ab89 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 11 Jun 2025 14:56:52 -0400 Subject: [PATCH 04/33] fix bad merge Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 00dbbb24b6f..6ad4ad6dfe8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,11 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -<<<<<<< HEAD # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import asyncio -======= import base64 import json ->>>>>>> a1eaf5a5e (attempt to background agent registration) import math import threading import time @@ -15,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Optional from urllib.request import Request, urlopen from urllib.error import URLError, HTTPError -from urllib.parse import urljoin import torch from vllm import envs From 505f586ec96287160da64bef4e837faea3c21401 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 14:12:57 -0400 Subject: [PATCH 05/33] implement nicks suggestions Signed-off-by: Will Eaton --- vllm/config.py | 8 +- .../kv_transfer/kv_connector/v1/base.py | 72 ++------- .../kv_connector/v1/nixl_connector.py | 139 ++++++++---------- vllm/v1/engine/core.py | 18 ++- vllm/v1/utils.py | 6 +- vllm/v1/worker/gpu_worker.py | 42 ++---- 6 files changed, 110 insertions(+), 175 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 6536afa007d..30bc7fc734e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,7 @@ import textwrap import uuid import warnings -from collections import Counter, defaultdict +from collections import Counter from contextlib import contextmanager from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass, replace) @@ -1513,10 +1513,8 @@ class CacheConfig: num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" - transfer_handshake_metadata: dict[int, dict[int, - KVConnectorHandshakeMetadata]] = field( - default_factory=lambda: defaultdict(dict), - init=False) + transfer_handshake_metadata: Optional[dict[int, dict[ + int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False) """Metadata for the KV connector handshake.""" def compute_hash(self) -> str: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7b9b2ca1036..64ec33cda35 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State communication in vLLM v1 @@ -8,15 +7,9 @@ Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. 
get_num_new_matched_tokens() - get number of new tokens - that exist in the remote KV cache. Might be called multiple - times for a given request and should be side-effect free. + that exist in the remote KV cache update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. - request_finished() - called when a request is finished, with - the computed kv cache blocks for the request. - Returns whether KV cache should be freed now or will be - freed asynchronously and optionally returns KV transfer - params. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -25,17 +18,14 @@ save_kv_layer() - starts saving KV for layer i (maybe async) wait_for_save() - blocks until all saves are done - - get_finished() - called with ids of finished requests, returns - ids of requests that have completed async sending/recving. """ import enum from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Callable, Optional -import torch import msgspec +import torch from pydantic_core import core_schema from vllm.logger import init_logger @@ -64,7 +54,7 @@ class KVConnectorMetadata: Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. """ - + def __init__(self): pass @@ -75,21 +65,20 @@ class KVConnectorHandshakeMetadata( # required for @cached_property. dict=True): """ - Metadata optionally used for out of band connector handshake between P/D workers. + Metadata optionally used for out of band connector handshake between + P/D workers. """ connector_type: str = "base" @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: Callable[[Any], core_schema.CoreSchema] + cls, _source_type: Any, _handler: Callable[[Any], + core_schema.CoreSchema] ) -> core_schema.CoreSchema: """bridge msgspec.Struct with pydantic for schema generation""" return core_schema.no_info_after_validator_function( - cls, - core_schema.dict_schema() - ) + cls, core_schema.dict_schema()) + class KVConnectorTransferMetadata( msgspec.Struct, @@ -101,13 +90,11 @@ class KVConnectorTransferMetadata( tensor_parallel_rank: int data_parallel_rank: int content: Optional[dict] - + class KVConnectorBase_V1(ABC): - def __init__(self, - vllm_config: "VllmConfig", - role: KVConnectorRole): + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design.") @@ -116,7 +103,6 @@ def __init__(self, self._role = role self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None - @property def role(self) -> KVConnectorRole: return self._role @@ -235,31 +221,14 @@ def get_finished( finished generating tokens. Returns: - ids of requests that have finished asynchronous transfer - (requests that previously returned True from request_finished()), + ids of requests that have finished asynchronous transfer, tuple of (sending/saving ids, recving/loading ids). The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ return None, None - def set_handshake_metadata( - self, handshake_metadata: KVConnectorHandshakeMetadata) -> None: - """ - Set the handshake metadata for the connector. - - This metadata is used for out-of-band connector handshake - between P/D workers. - - Args: - handshake_metadata (KVConnectorHandshakeMetadata): the handshake - metadata. 
- """ - self._handshake_metadata = handshake_metadata - - - def get_handshake_metadata( - self) -> Optional[KVConnectorHandshakeMetadata]: + def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: """ Get the handshake metadata for the connector. @@ -292,8 +261,7 @@ def get_num_new_matched_tokens( - The number of tokens that can be loaded from the external KV cache beyond what is already computed. - `True` if external KV cache tokens will be loaded - asynchronously (between scheduler steps). Must be - 'False' if the first element is 0. + asynchronously (between scheduler steps). """ pass @@ -303,18 +271,6 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens: int): """ Update KVConnector state after block allocation. - - If get_num_new_matched_tokens previously returned True for a - request, this function may be called twice for that same request - - first when blocks are allocated for the connector tokens to be - asynchronously loaded into, and second when any additional blocks - are allocated, after the load/transfer is complete. - - Args: - request (Request): the request object. - blocks (KVCacheBlocks): the blocks allocated for the request. - num_external_tokens (int): the number of tokens that will be - loaded from the external KV cache. """ pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6ad4ad6dfe8..b3d9961c493 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -9,8 +9,9 @@ from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional +from urllib.error import HTTPError, URLError from urllib.request import Request, urlopen -from urllib.error import URLError, HTTPError + import torch from vllm import envs @@ -25,7 +26,6 @@ from vllm.logger import init_logger from vllm.platforms import _Backend from vllm.utils import build_uri, round_down - from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -378,12 +378,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # [req_id -> list[handle]] self._recving_transfers = defaultdict[str, list[Transfer]](list) - # Pending requests waiting for handshake completion # [engine_id -> list[(req_id, meta)]] self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {} - # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. 
# [req_id -> count] @@ -394,7 +392,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Background handshake threads for remote engines self._handshake_threads: dict[str, threading.Thread] = {} - + # Thread results for handshake completion tracking self._handshake_results: dict[str, bool] = {} @@ -427,7 +425,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): def _run_handshake_in_thread(self, engine_id: str, host: str, port: int): """Run handshake in background thread.""" - + def handshake_worker(): logger.debug("Starting handshake worker for engine %s", engine_id) try: @@ -436,10 +434,12 @@ def handshake_worker(): logger.debug("Handshake succeeded for engine %s", engine_id) except Exception as e: self._handshake_results[engine_id] = False - logger.warning("Handshake failed for engine %s: %s", engine_id, e) + logger.warning("Handshake failed for engine %s: %s", engine_id, + e) finally: - logger.debug("Handshake worker finished for engine %s", engine_id) - + logger.debug("Handshake worker finished for engine %s", + engine_id) + thread = threading.Thread(target=handshake_worker, daemon=True) thread._start_time = time.time() # track when thread started self._handshake_threads[engine_id] = thread @@ -465,7 +465,6 @@ def _nixl_handshake(self, host: str, port: int): logger.error("Failed to fetch metadata from %s: %s", url, e) raise - remote_tp_size = len(res.keys()) # Default case is that the remote TP size is 1, so we can # directly access the metadata. @@ -477,35 +476,29 @@ def _nixl_handshake(self, host: str, port: int): tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio if p_remote_rank > 0: - metadata_bytes = res.get(str(p_remote_rank), {}).get("0", {}).get( - "agent_metadata", None) + metadata_bytes = res.get(str(p_remote_rank), + {}).get("0", + {}).get("agent_metadata", None) if metadata_bytes is not None: # Reconstruct NixlAgentMetadata from JSON response # agent_metadata is base64-encoded binary data, not msgpack metadata = NixlAgentMetadata( - engine_id=tp_data["engine_id"], - agent_metadata=base64.b64decode(metadata_bytes), - kv_caches_base_addr=tp_data["kv_caches_base_addr"], - num_blocks=tp_data["num_blocks"], - tp_size=tp_data["tp_size"], - block_len=tp_data["block_len"], - attn_backend_name=tp_data["attn_backend_name"], - ) + agent_metadata=base64.b64decode(metadata_bytes), **tp_data) # Register Remote agent. - logger.debug("About to register remote agent for engine %s", - metadata.engine_id) + logger.debug("About to register remote agent for engine %s", + metadata.engine_id) pre_register = time.perf_counter() self.add_remote_agent(metadata, remote_tp_rank=p_remote_rank) agent_time = time.perf_counter() - logger.debug("Finished registering remote agent for engine %s", - metadata.engine_id) + logger.debug("Finished registering remote agent for engine %s", + metadata.engine_id) - logger.debug("NIXL handshake: get metadata took: %s", - pre_register - start_time) - logger.debug("NIXL handshake: add agent took: %s", - agent_time - pre_register) + logger.debug("NIXL handshake: get metadata took: %s", + pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", + agent_time - pre_register) else: # If metadata_bytes is None, it means the remote agent # is not using NIXL, so we can skip the handshake. 
@@ -513,7 +506,6 @@ def _nixl_handshake(self, host: str, port: int): "Received None metadata from %s:%s, skipping NIXL handshake", host, port) - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -642,10 +634,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): num_blocks=self.num_blocks, tp_size=self.world_size, block_len=self.block_len, - attn_backend_name=self.backend_name - ) - - + attn_backend_name=self.backend_name) def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, @@ -786,7 +775,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: """ # Process any completed handshakes first self._process_completed_handshakes() - + done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) if len(done_sending) > 0 or len(done_recving) > 0: @@ -794,7 +783,6 @@ def get_finished(self) -> tuple[set[str], set[str]]: "Rank %s, get_finished: %s requests done sending " "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) - if self.world_size == 1: return done_sending, done_recving @@ -885,45 +873,49 @@ def _pop_done_transfers( def _process_completed_handshakes(self): """Process completed handshakes and mark remote agents as ready.""" - + # debug: log current state if self._handshake_threads: - logger.debug("Processing handshakes: %d active threads, %d pending", - len(self._handshake_threads), - sum(len(reqs) for reqs in self._pending_requests.values())) - + logger.debug( + "Processing handshakes: %d active threads, %d pending", + len(self._handshake_threads), + sum(len(reqs) for reqs in self._pending_requests.values())) + completed_engines = [] for engine_id, thread in list(self._handshake_threads.items()): - logger.debug("Checking handshake thread for engine %s: alive=%s", - engine_id, thread.is_alive()) - + logger.debug("Checking handshake thread for engine %s: alive=%s", + engine_id, thread.is_alive()) + # check for timeout (threads running > 30 seconds) - thread_age = time.time() - getattr(thread, '_start_time', time.time()) + thread_age = time.time() - getattr(thread, '_start_time', + time.time()) if thread.is_alive() and thread_age > 30.0: - logger.warning("Handshake thread for %s running %.1fs (hung?)", - engine_id, thread_age) - + logger.warning("Handshake thread for %s running %.1fs (hung?)", + engine_id, thread_age) + if not thread.is_alive(): logger.debug("Handshake completed for engine %s", engine_id) completed_engines.append(engine_id) - + success = self._handshake_results.get(engine_id, False) - logger.debug("Handshake result for engine %s: success=%s", - engine_id, success) + logger.debug("Handshake result for engine %s: success=%s", + engine_id, success) if not success: logger.warning("Handshake failed for engine %s", engine_id) continue - + logger.debug("Handshake succeeded for engine %s", engine_id) if engine_id in self._pending_requests: pending_reqs = self._pending_requests[engine_id] logger.debug( - "Handshake completed for %s, immediately retrying %d pending requests", + "Handshake completed for %s, immediately retrying %d " \ + "pending requests", engine_id, len(pending_reqs)) - + for req_id, meta in pending_reqs: - logger.debug("Immediately retrying request %s for engine %s", - req_id, engine_id) + logger.debug( + "Immediately retrying request %s for engine %s", + req_id, engine_id) try: self._read_blocks( request_id=req_id, @@ -932,12 +924,13 @@ def _process_completed_handshakes(self): 
remote_block_ids=meta.remote_block_ids, ) except Exception as e: - logger.error("Failed to retry request %s: %s", req_id, e) - + logger.error("Failed to retry request %s: %s", + req_id, e) + del self._pending_requests[engine_id] - + for engine_id in completed_engines: - logger.debug("Cleaning up handshake thread for engine %s", + logger.debug("Cleaning up handshake thread for engine %s", engine_id) del self._handshake_threads[engine_id] if engine_id in self._handshake_results: @@ -956,51 +949,45 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ - logger.debug("start_load_kv called with %d requests", len(metadata.requests)) + logger.debug("start_load_kv called with %d requests", + len(metadata.requests)) for req_id, meta in metadata.requests.items(): - if (req_id in self._recving_transfers or - self._is_request_pending_handshake(req_id)): - logger.debug( - "Request %s already being processed, skipping", req_id) - continue - logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - if meta.remote_engine_id not in self._remote_agents: logger.debug( "Remote engine %s not registered for request %s, " "starting handshake and deferring transfer", meta.remote_engine_id, req_id) - + if meta.remote_engine_id not in self._handshake_threads: logger.debug( - "Starting handshake thread for remote engine %s", + "Starting handshake thread for remote engine %s", meta.remote_engine_id) - self._run_handshake_in_thread( - meta.remote_engine_id, meta.remote_host, - meta.remote_port) + self._run_handshake_in_thread(meta.remote_engine_id, + meta.remote_host, + meta.remote_port) else: logger.debug( - "Handshake thread already exists for remote engine %s", + "Handshake thread already exists for remote engine %s", meta.remote_engine_id) - + if meta.remote_engine_id not in self._pending_requests: self._pending_requests[meta.remote_engine_id] = [] self._pending_requests[meta.remote_engine_id].append( (req_id, meta)) - + logger.debug( "Request %s marked as pending handshake for engine %s", req_id, meta.remote_engine_id) continue logger.debug("Remote agent available for %s, calling _read_blocks", - meta.remote_engine_id) + meta.remote_engine_id) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5610d1fed2b..b93a1dc6861 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -80,7 +80,8 @@ def __init__(self, # Setup KV Caches and update CacheConfig after profiling. 
num_gpu_blocks, num_cpu_blocks, kv_cache_config, \ - transfer_handshake_metadata = self._initialize_kv_caches(vllm_config) + transfer_handshake_metadata = self._initialize_kv_caches( + vllm_config) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks @@ -171,9 +172,12 @@ def _initialize_kv_caches( # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) - # Collect KV connector xfer metadata from workers (after KV cache registration) + # Collect KV connector xfer metadata from workers + # (after KV cache registration) transfer_handshake_metadata = ( - self.model_executor.get_kv_connector_handshake_metadata()) + self.model_executor.get_kv_connector_handshake_metadata() + if self.vllm_config.cache_config.transfer_handshake_metadata else + None) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " @@ -448,9 +452,9 @@ def _perform_handshake( } # Include KV connector metadata if available - if hasattr(self, - 'transfer_handshake_metadata') and self.transfer_handshake_metadata: - # self.transfer_handshake_metadata is a list of dicts from workers + if hasattr(self, 'transfer_handshake_metadata' + ) and self.transfer_handshake_metadata: + # self.transfer_handshake_metadata is list of dicts from workers # Each dict already has structure {tp_rank: {dp_rank: metadata}} # Merge all worker dicts into a single dict content = {} @@ -458,7 +462,7 @@ def _perform_handshake( if worker_dict is not None: content.update(worker_dict) handshake_message["transfer_handshake_metadata"] = content - + handshake_socket.send(msgspec.msgpack.encode(handshake_message)) @staticmethod diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 705d62404bf..0b8d35b23b5 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -537,12 +537,12 @@ def wait_for_engine_startup( num_gpu_blocks += msg["num_gpu_blocks"] cache_config.num_gpu_blocks = num_gpu_blocks # stash KV connector metadata in vllm_config if passed in. 
- if "transfer_handshake_metadata" in msg and msg["transfer_handshake_metadata"]: + if txfer_metadata := msg.get("transfer_handshake_metadata"): logger.debug( "Received transfer handshake metadata from engine %s: %s", - eng_index, msg["transfer_handshake_metadata"]) + eng_index, txfer_metadata) # Merge the received metadata with existing cache config - for tp_rank, dp_dict in msg["transfer_handshake_metadata"].items(): + for tp_rank, dp_dict in txfer_metadata.items(): for dp_rank, metadata in dp_dict.items(): cache_config.transfer_handshake_metadata[tp_rank][ dp_rank] = metadata diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 805d4bf1483..4aa6234043f 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -10,8 +10,6 @@ import torch.distributed import torch.nn as nn -from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 import vllm.envs as envs from vllm.config import VllmConfig from vllm.device_allocator.cumem import CuMemAllocator @@ -19,8 +17,8 @@ init_distributed_environment, set_custom_all_reduce) from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, - is_v1_kv_transfer_group, - get_kv_transfer_group) + get_kv_transfer_group, + is_v1_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -240,30 +238,22 @@ def get_kv_connector_handshake_metadata(self) -> Optional[dict]: connector = get_kv_transfer_group() if not is_v1_kv_transfer_group(connector): + logger.warning("The KV connector is not a v1 connector. " + "This method is only supported for v1 connectors.") + return None + + metadata = connector.get_handshake_metadata() + if metadata is None: logger.warning( - "The KV connector is not a v1 connector. " - "This method is only supported for v1 connectors.") + "KV connector metadata is not available. " + "This may happen if the KV connector is not initialized " + "or the worker is not part of a disaggregated KV cache setup.") return None - - # Only return metadata if this is a worker role - if connector.role == KVConnectorRole.WORKER: - metadata = connector.get_handshake_metadata() - if metadata is None: - logger.warning( - "KV connector metadata is not available. " - "This may happen if the KV connector is not initialized " - "or the worker is not part of a disaggregated KV cache setup." 
- ) - return None - - tp_rank = get_tp_group().rank_in_group - dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local - return { - tp_rank: { - dp_rank: msgspec.to_builtins(metadata) - } - } - + + tp_rank = get_tp_group().rank_in_group + dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + return {tp_rank: {dp_rank: msgspec.to_builtins(metadata)}} + return None def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: From a1a09188a4a94aa6559de6cda3d9f6bedd734442 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 15:27:45 -0400 Subject: [PATCH 06/33] change retry logic and move to scheduler; simply background handshake processing Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/base.py | 68 ++++ .../kv_connector/v1/multi_connector.py | 9 + .../kv_connector/v1/nixl_connector.py | 338 +++++++++--------- vllm/v1/core/sched/scheduler.py | 23 +- vllm/v1/outputs.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 13 +- 6 files changed, 277 insertions(+), 178 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 64ec33cda35..be085ed484b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,6 +22,7 @@ import enum from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional import msgspec @@ -41,6 +42,64 @@ logger = init_logger(__name__) +@dataclass +class KVTransferFinishedResult: + """Result of KV transfer get_finished operation.""" + + finished_sending: set[str] + finished_recving: set[str] + pending_handshake: set[str] + + def has_any_finished(self) -> bool: + """Check if any requests finished or are pending.""" + return bool(self.finished_sending or self.finished_recving + or self.pending_handshake) + + def is_empty(self) -> bool: + """Check if all sets are empty.""" + return not self.has_any_finished() + + def get_all_finished_req_ids(self) -> set[str]: + """Get all request IDs that have finished (sending or receiving).""" + return self.finished_sending.union(self.finished_recving) + + def merge(self, + other: 'KVTransferFinishedResult') -> 'KVTransferFinishedResult': + """Merge with another result, combining all sets.""" + return KVTransferFinishedResult( + finished_sending=self.finished_sending.union( + other.finished_sending), + finished_recving=self.finished_recving.union( + other.finished_recving), + pending_handshake=self.pending_handshake.union( + other.pending_handshake)) + + @classmethod + def empty(cls) -> 'KVTransferFinishedResult': + """Create an empty result.""" + return cls(finished_sending=set(), + finished_recving=set(), + pending_handshake=set()) + + @classmethod + def from_tuple( + cls, result_tuple: tuple[set[str], set[str], set[str]] + ) -> 'KVTransferFinishedResult': + """Create from the old tuple format for backward compatibility.""" + finished_sending, finished_recving, pending_handshake = result_tuple + return cls(finished_sending=finished_sending, + finished_recving=finished_recving, + pending_handshake=pending_handshake) + + def to_tuple(self) -> tuple[set[str], set[str], set[str]]: + """Convert to the old tuple format for backward compatibility.""" + return ( + self.finished_sending, + self.finished_recving, + self.pending_handshake, + ) + + class KVConnectorRole(enum.Enum): # Connector running in the scheduler process SCHEDULER = 0 @@ -228,6 +287,15 @@ def get_finished( """ return None, 
None + def get_pending_handshake_req_ids(self) -> Optional[set[str]]: + """ + Get request IDs that are currently pending handshake completion. + + Returns: + Set of request IDs waiting for handshake, or None if not applicable. + """ + return None + def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: """ Get the handshake metadata for the connector. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index be3c2339941..58ba71306dc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -129,6 +129,15 @@ def get_finished( return finished_sending or None, finished_recving or None + def get_pending_handshake_req_ids(self) -> Optional[set[str]]: + """Get request IDs that are currently pending handshake completion.""" + pending_handshake: set[str] = set() + for c in self._connectors: + connector_pending = c.get_pending_handshake_req_ids() + if connector_pending: + pending_handshake.update(connector_pending) + return pending_handshake or None + # ============================== # Scheduler-side methods # ============================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b3d9961c493..bdb851752bd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -3,10 +3,12 @@ import base64 import json import math +import queue import threading import time import uuid from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from urllib.error import HTTPError, URLError @@ -19,7 +21,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata, - KVConnectorRole) + KVConnectorRole, KVTransferFinishedResult) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -154,7 +156,14 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None - return self.connector_worker.get_finished() + result = self.connector_worker.get_finished() + # Store pending handshake for the new method to retrieve + self._pending_handshake_req_ids = result.pending_handshake + return result.finished_sending, result.finished_recving + + def get_pending_handshake_req_ids(self) -> Optional[set[str]]: + """Get request IDs that are currently pending handshake completion.""" + return getattr(self, '_pending_handshake_req_ids', None) def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: @@ -378,10 +387,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # [req_id -> list[handle]] self._recving_transfers = defaultdict[str, list[Transfer]](list) - # Pending requests waiting for handshake completion - # [engine_id -> list[(req_id, meta)]] - self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {} - # Complete transfer tracker. Used by the rank 0 to track finished # transactions on ranks 1 to N-1. 
# [req_id -> count] @@ -391,10 +396,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): int] = defaultdict(lambda: 0) # Background handshake threads for remote engines - self._handshake_threads: dict[str, threading.Thread] = {} - + self._executor = ThreadPoolExecutor( + max_workers=4, thread_name_prefix="nixl-handshake") # Thread results for handshake completion tracking - self._handshake_results: dict[str, bool] = {} + self._handshake_futures: dict[str, Future] = {} + self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {} + self._ready_requests: queue.Queue[tuple[str, ReqMeta]] = queue.Queue() + self._lock = threading.Lock() self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size @@ -423,28 +431,49 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[str, int](int) - def _run_handshake_in_thread(self, engine_id: str, host: str, port: int): - """Run handshake in background thread.""" + def __del__(self): + """Cleanup ThreadPoolExecutor on destruction.""" + if hasattr(self, '_executor'): + self._executor.shutdown(wait=False) - def handshake_worker(): - logger.debug("Starting handshake worker for engine %s", engine_id) + def _start_handshake(self, engine_id: str, host: str, port: int): + """Start handshake using ThreadPoolExecutor.""" + if engine_id in self._handshake_futures: + return + + logger.debug("Starting handshake for engine %s", engine_id) + future = self._executor.submit(self._nixl_handshake, host, port) + self._handshake_futures[engine_id] = future + + # Set up callback to handle completion + def on_handshake_complete(fut: Future): try: - self._nixl_handshake(host, port) - self._handshake_results[engine_id] = True + fut.result() # This will raise if the handshake failed logger.debug("Handshake succeeded for engine %s", engine_id) + with self._lock: + # Remove from futures dict - requests will remain pending + # and be handled by scheduler retry logic + if engine_id in self._handshake_futures: + del self._handshake_futures[engine_id] + logger.debug("Handshake completed for engine %s. 
" + "Pending requests will be retried by" \ + "scheduler.", engine_id) except Exception as e: - self._handshake_results[engine_id] = False logger.warning("Handshake failed for engine %s: %s", engine_id, e) - finally: - logger.debug("Handshake worker finished for engine %s", - engine_id) - - thread = threading.Thread(target=handshake_worker, daemon=True) - thread._start_time = time.time() # track when thread started - self._handshake_threads[engine_id] = thread - thread.start() - return thread + with self._lock: + # clean up failed handshake - requests will remain pending + # and be reported to scheduler for retry + if engine_id in self._handshake_futures: + del self._handshake_futures[engine_id] + if engine_id in self._pending_requests: + failed_reqs = self._pending_requests[engine_id] + logger.warning( + "Handshake failed for engine %s, leaving" + "%d requests pending for scheduler retry", + engine_id, len(failed_reqs)) + + future.add_done_callback(on_handshake_complete) def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" @@ -761,9 +790,9 @@ def add_remote_agent(self, engine_id] = self.nixl_wrapper.prep_xfer_dlist( self._remote_agents[engine_id][remote_tp_rank], descs) - def get_finished(self) -> tuple[set[str], set[str]]: + def get_finished(self) -> KVTransferFinishedResult: """ - Get requests that are done sending or recving. + Get requests that are done sending, done recving, and pending handshake. In TP>1 setup, each rank exchanges KVs with its counterpart ranks independently. get_finished() runs in a worker creates @@ -773,61 +802,84 @@ def get_finished(self) -> tuple[set[str], set[str]]: to Rank 0 once their transaction is done + Rank 0 returns finished sets to Scheduler only once all ranks are done. """ - # Process any completed handshakes first - self._process_completed_handshakes() + # process requests that became ready after handshake completion + self._process_ready_requests() done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) - if len(done_sending) > 0 or len(done_recving) > 0: + + with self._lock: + pending_handshake = set() + for pending_reqs in self._pending_requests.values(): + for req_id, _ in pending_reqs: + pending_handshake.add(req_id) + + local_result = KVTransferFinishedResult( + finished_sending=done_sending, + finished_recving=done_recving, + pending_handshake=pending_handshake) + + if not local_result.is_empty(): logger.debug( - "Rank %s, get_finished: %s requests done sending " - "and %s requests done recving", self.tp_rank, - len(done_sending), len(done_recving)) + "Rank %s, get_finished: %s requests done sending, " + "%s requests done recving, %s pending handshake", self.tp_rank, + len(done_sending), len(done_recving), len(pending_handshake)) if self.world_size == 1: - return done_sending, done_recving + return local_result + + return self._coordinate_multi_rank_results(local_result) + + def _coordinate_multi_rank_results( + self, local_result: KVTransferFinishedResult + ) -> KVTransferFinishedResult: + """Coordinate results across multiple TP ranks.""" - # Rank 0: get finished from all other ranks. if self.tp_rank == 0: - for req_id in done_sending: + # Rank 0 collects results from all other ranks. + for req_id in local_result.finished_sending: self._done_sending_count[req_id] += 1 - for req_id in done_recving: + for req_id in local_result.finished_recving: self._done_recving_count[req_id] += 1 - # Keep track of how many other ranks have finished. 
- other_ranks_finished_ids: list[str] = [] + all_pending_handshake = local_result.pending_handshake.copy() for i in range(1, self.world_size): - other_ranks_finished_ids.extend( - self.tp_group.recv_object(src=i)) - for req_id in other_ranks_finished_ids: - if (req_id in self._done_recving_count - or req_id in self._recving_transfers): - self._done_recving_count[req_id] += 1 - else: - self._done_sending_count[req_id] += 1 - - # Return ids that finished on all ranks to the scheduler. - all_done_recving: set[str] = set() - for req_id in list(self._done_recving_count.keys()): - if self._done_recving_count[req_id] == self.world_size: - del self._done_recving_count[req_id] - all_done_recving.add(req_id) - - all_done_sending: set[str] = set() - for req_id in list(self._done_sending_count.keys()): - if self._done_sending_count[req_id] == self.world_size: - del self._done_sending_count[req_id] - all_done_sending.add(req_id) - - return all_done_sending, all_done_recving - - # Ranks 1 to N-1: send finished ids to Rank 0. + rank_data = self.tp_group.recv_object(src=i) + other_rank_result = KVTransferFinishedResult.from_tuple( + rank_data) + + for req_id in other_rank_result.get_all_finished_req_ids(): + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + all_pending_handshake.update( + other_rank_result.pending_handshake) + + all_done_recving = self._get_globally_finished_requests( + self._done_recving_count) + all_done_sending = self._get_globally_finished_requests( + self._done_sending_count) + + return KVTransferFinishedResult( + finished_sending=all_done_sending, + finished_recving=all_done_recving, + pending_handshake=all_pending_handshake) else: - finished_req_ids = list(done_recving.union(done_sending)) - self.tp_group.send_object(finished_req_ids, dst=0) - - # Unused as only Rank 0 results are sent to scheduler. 
- return done_sending, done_recving + self.tp_group.send_object(local_result.to_tuple(), dst=0) + return local_result + + def _get_globally_finished_requests( + self, counter_dict: dict[str, int]) -> set[str]: + """Get request IDs that have finished on all ranks.""" + finished_req_ids = set() + for req_id in list(counter_dict.keys()): + if counter_dict[req_id] == self.world_size: + del counter_dict[req_id] + finished_req_ids.add(req_id) + return finished_req_ids def _get_new_notifs(self) -> set[str]: """ @@ -871,78 +923,18 @@ def _pop_done_transfers( xfer_state) return done_req_ids - def _process_completed_handshakes(self): - """Process completed handshakes and mark remote agents as ready.""" - - # debug: log current state - if self._handshake_threads: - logger.debug( - "Processing handshakes: %d active threads, %d pending", - len(self._handshake_threads), - sum(len(reqs) for reqs in self._pending_requests.values())) - - completed_engines = [] - for engine_id, thread in list(self._handshake_threads.items()): - logger.debug("Checking handshake thread for engine %s: alive=%s", - engine_id, thread.is_alive()) - - # check for timeout (threads running > 30 seconds) - thread_age = time.time() - getattr(thread, '_start_time', - time.time()) - if thread.is_alive() and thread_age > 30.0: - logger.warning("Handshake thread for %s running %.1fs (hung?)", - engine_id, thread_age) - - if not thread.is_alive(): - logger.debug("Handshake completed for engine %s", engine_id) - completed_engines.append(engine_id) - - success = self._handshake_results.get(engine_id, False) - logger.debug("Handshake result for engine %s: success=%s", - engine_id, success) - if not success: - logger.warning("Handshake failed for engine %s", engine_id) - continue - - logger.debug("Handshake succeeded for engine %s", engine_id) - if engine_id in self._pending_requests: - pending_reqs = self._pending_requests[engine_id] - logger.debug( - "Handshake completed for %s, immediately retrying %d " \ - "pending requests", - engine_id, len(pending_reqs)) - - for req_id, meta in pending_reqs: - logger.debug( - "Immediately retrying request %s for engine %s", - req_id, engine_id) - try: - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - ) - except Exception as e: - logger.error("Failed to retry request %s: %s", - req_id, e) - - del self._pending_requests[engine_id] - - for engine_id in completed_engines: - logger.debug("Cleaning up handshake thread for engine %s", - engine_id) - del self._handshake_threads[engine_id] - if engine_id in self._handshake_results: - del self._handshake_results[engine_id] - - def _is_request_pending_handshake(self, req_id: str) -> bool: - """Check if request is still pending handshake completion.""" - for engine_requests in self._pending_requests.values(): - for pending_req_id, _ in engine_requests: - if pending_req_id == req_id: - return True - return False + def _process_ready_requests(self): + """Process requests that are ready after handshake completion. + + Note: With scheduler-based retry logic, this method is simplified + as automatic retries are handled by the scheduler. 
+ """ + # Clear any remaining items in the ready queue to prevent memory leaks + while True: + try: + self._ready_requests.get_nowait() + except queue.Empty: + break def start_load_kv(self, metadata: NixlConnectorMetadata): """ @@ -951,49 +943,47 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): """ logger.debug("start_load_kv called with %d requests", len(metadata.requests)) + for req_id, meta in metadata.requests.items(): logger.debug( "start_load_kv for request %s from remote engine %s. " - "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, + "Num local_block_ids: %s. Num remote_block_ids: %s.", req_id, meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - if meta.remote_engine_id not in self._remote_agents: - logger.debug( - "Remote engine %s not registered for request %s, " - "starting handshake and deferring transfer", - meta.remote_engine_id, req_id) - - if meta.remote_engine_id not in self._handshake_threads: + with self._lock: + if meta.remote_engine_id in self._remote_agents: logger.debug( - "Starting handshake thread for remote engine %s", + "Remote agent available for %s, calling _read_blocks", meta.remote_engine_id) - self._run_handshake_in_thread(meta.remote_engine_id, - meta.remote_host, - meta.remote_port) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + ) + elif meta.remote_engine_id in self._handshake_futures: + logger.debug( + "Handshake in progress for engine %s, adding" + " request %s to pending", meta.remote_engine_id, + req_id) + if meta.remote_engine_id not in self._pending_requests: + self._pending_requests[meta.remote_engine_id] = [] + self._pending_requests[meta.remote_engine_id].append( + (req_id, meta)) else: logger.debug( - "Handshake thread already exists for remote engine %s", - meta.remote_engine_id) - - if meta.remote_engine_id not in self._pending_requests: - self._pending_requests[meta.remote_engine_id] = [] - self._pending_requests[meta.remote_engine_id].append( - (req_id, meta)) - - logger.debug( - "Request %s marked as pending handshake for engine %s", - req_id, meta.remote_engine_id) - continue - - logger.debug("Remote agent available for %s, calling _read_blocks", - meta.remote_engine_id) - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, - remote_block_ids=meta.remote_block_ids, - ) + "Starting handshake for engine %s and adding" + " request %s to pending", meta.remote_engine_id, + req_id) + if meta.remote_engine_id not in self._pending_requests: + self._pending_requests[meta.remote_engine_id] = [] + self._pending_requests[meta.remote_engine_id].append( + (req_id, meta)) + self._start_handshake(meta.remote_engine_id, + meta.remote_host, meta.remote_port) + + self._process_ready_requests() def _read_blocks( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f7e232fe03c..9fe2557b673 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -14,8 +14,6 @@ KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorMetadata) from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, @@ -706,10 +704,14 @@ def 
update_from_output( logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + pending_handshake_req_ids = ( + model_runner_output.pending_handshake_req_ids) new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None + num_requests_to_reschedule = 0 + num_tokens_to_reschedule = 0 # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid @@ -722,6 +724,17 @@ def update_from_output( new_running.append(request) continue + # Check if this request is pending handshake and needs to reschedule + if (pending_handshake_req_ids + and req_id in pending_handshake_req_ids): + num_requests_to_reschedule += 1 + num_tokens_to_reschedule += request.num_computed_tokens + # Reset computed tokens to force rescheduling from beginning + request.num_computed_tokens = 0 + num_tokens_to_reschedule -= request.num_computed_tokens + new_running.append(request) + continue + req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] @@ -824,6 +837,12 @@ def update_from_output( if not stopped: new_running.append(request) + if num_requests_to_reschedule: + logger.info( + "Recovered from handshake failure: " + "%d request(s) rescheduled (%d tokens affected).", + num_requests_to_reschedule, num_tokens_to_reschedule) + # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 17a299d57cb..791ca730027 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -104,6 +104,7 @@ class ModelRunnerOutput: # [req_ids] finished_sending: Optional[set[str]] = None finished_recving: Optional[set[str]] = None + pending_handshake_req_ids: Optional[set[str]] = None EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], @@ -113,4 +114,5 @@ class ModelRunnerOutput: logprobs=None, prompt_logprobs_dict={}, finished_sending=None, - finished_recving=None) + finished_recving=None, + pending_handshake_req_ids=None) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b1bc727e1e8..bea7d17beaf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1267,6 +1267,7 @@ def execute_model( self.maybe_wait_for_kv_save() finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) + pending_handshake_req_ids = self.get_pending_handshake_req_ids() if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output @@ -1505,6 +1506,7 @@ def execute_model( prompt_logprobs_dict=prompt_logprobs_dict, finished_sending=finished_sending, finished_recving=finished_recving, + pending_handshake_req_ids=pending_handshake_req_ids, ) def kv_connector_no_forward( @@ -1514,13 +1516,16 @@ def kv_connector_no_forward( self.maybe_setup_kv_connector(scheduler_output) finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) + pending_handshake_req_ids = self.get_pending_handshake_req_ids() - if not finished_sending and not finished_recving: + if (not finished_sending and not finished_recving + and not pending_handshake_req_ids): return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.finished_sending = finished_sending output.finished_recving = 
finished_recving + output.pending_handshake_req_ids = pending_handshake_req_ids return output @staticmethod @@ -1553,6 +1558,12 @@ def get_finished_kv_transfers( scheduler_output.finished_req_ids) return None, None + @staticmethod + def get_pending_handshake_req_ids() -> Optional[set[str]]: + if has_kv_transfer_group(): + return get_kv_transfer_group().get_pending_handshake_req_ids() + return None + def generate_draft_token_ids( self, sampled_token_ids: list[list[int]], From 518a59f411b3e980d20115e3b54caf85c2111cfb Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 15:34:25 -0400 Subject: [PATCH 07/33] fixup by changing types at callsite Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/base.py | 11 ++++---- .../kv_connector/v1/multi_connector.py | 27 ++++++++++++------- .../kv_connector/v1/nixl_connector.py | 15 ++++++----- vllm/v1/worker/gpu_model_runner.py | 5 +++- 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index be085ed484b..f1c1bf7edd9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -272,20 +272,19 @@ def wait_for_save(self): """ pass - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, + finished_req_ids: set[str]) -> KVTransferFinishedResult: """ Notifies worker-side connector ids of requests that have finished generating tokens. Returns: - ids of requests that have finished asynchronous transfer, - tuple of (sending/saving ids, recving/loading ids). + KVTransferFinishedResult containing sets of finished sending, + finished receiving, and pending handshake request IDs. The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ - return None, None + return KVTransferFinishedResult.empty() def get_pending_handshake_req_ids(self) -> Optional[set[str]]: """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 58ba71306dc..5ce2a6e1523 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + KVTransferFinishedResult) from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput @@ -102,21 +103,27 @@ def wait_for_save(self): for c in self._connectors: c.wait_for_save() - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, + finished_req_ids: set[str]) -> KVTransferFinishedResult: finished_sending: set[str] = set() finished_recving: set[str] = set() + pending_handshake: set[str] = set() + for c in self._connectors: - sending, recving = c.get_finished(finished_req_ids) - if not recving and not sending: + result = c.get_finished(finished_req_ids) + if result.is_empty(): continue + # Aggregate finished recving request ids. 
- finished_recving.update(recving or ()) + finished_recving.update(result.finished_recving) + + # Aggregate pending handshake request ids. + pending_handshake.update(result.pending_handshake) + # Aggregate finished sending request ids - only include # once we've drained the "extra" count (for cases where # more than one connector is async-saving the same request). - for req_id in sending or (): + for req_id in result.finished_sending: extra_pending = self._extra_async_saves.get(req_id) if extra_pending is None: finished_sending.add(req_id) @@ -127,7 +134,9 @@ def get_finished( else: self._extra_async_saves[req_id] = extra_pending - 1 - return finished_sending or None, finished_recving or None + return KVTransferFinishedResult(finished_sending=finished_sending, + finished_recving=finished_recving, + pending_handshake=pending_handshake) def get_pending_handshake_req_ids(self) -> Optional[set[str]]: """Get request IDs that are currently pending handshake completion.""" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index bdb851752bd..0a3ea2a6112 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -153,17 +153,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.set_handshake_metadata(self.connector_worker.xfer_metadata) def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + finished_req_ids: set[str]) -> KVTransferFinishedResult: """Get the finished recving and sending requests.""" assert self.connector_worker is not None - result = self.connector_worker.get_finished() - # Store pending handshake for the new method to retrieve - self._pending_handshake_req_ids = result.pending_handshake - return result.finished_sending, result.finished_recving + return self.connector_worker.get_finished() def get_pending_handshake_req_ids(self) -> Optional[set[str]]: """Get request IDs that are currently pending handshake completion.""" - return getattr(self, '_pending_handshake_req_ids', None) + if self.connector_worker is not None: + result = self.connector_worker.get_finished() + return (result.pending_handshake + if result.pending_handshake else None) + return None def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: @@ -188,7 +189,7 @@ def set_handshake_metadata(self, handshake_metadata): logger.debug("Setting handshake metadata for NIXL connector: %s", handshake_metadata) assert self.connector_worker is not None - return super().set_handshake_metadata(handshake_metadata) + self._handshake_metadata = handshake_metadata class NixlConnectorScheduler: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bea7d17beaf..7713b388298 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1554,8 +1554,11 @@ def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): - return get_kv_transfer_group().get_finished( + result = get_kv_transfer_group().get_finished( scheduler_output.finished_req_ids) + return ( + result.finished_sending if result.finished_sending else None, + result.finished_recving if result.finished_recving else None) return None, None @staticmethod From 6426e3ffa71cf4a4eb91230d5b6e47ccbe332713 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 15:51:28 -0400 
Subject: [PATCH 08/33] fixup popping requests Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0a3ea2a6112..5b5cdd1ccce 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -452,13 +452,18 @@ def on_handshake_complete(fut: Future): fut.result() # This will raise if the handshake failed logger.debug("Handshake succeeded for engine %s", engine_id) with self._lock: - # Remove from futures dict - requests will remain pending - # and be handled by scheduler retry logic + # Remove from futures dict if engine_id in self._handshake_futures: del self._handshake_futures[engine_id] - logger.debug("Handshake completed for engine %s. " - "Pending requests will be retried by" \ - "scheduler.", engine_id) + # Clear pending requests - they are no longer pending + # handshake and will be processed normally by the scheduler + if engine_id in self._pending_requests: + completed_reqs = self._pending_requests[engine_id] + del self._pending_requests[engine_id] + logger.debug( + "Handshake completed for engine %s. " + "Cleared %d requests from pending state.", + engine_id, len(completed_reqs)) except Exception as e: logger.warning("Handshake failed for engine %s: %s", engine_id, e) From 29885af2d9dbe4df72c86db7e85be8bb3b552ba1 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 16:34:07 -0400 Subject: [PATCH 09/33] debug logging Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 5b5cdd1ccce..9cd6f01eb98 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -448,6 +448,8 @@ def _start_handshake(self, engine_id: str, host: str, port: int): # Set up callback to handle completion def on_handshake_complete(fut: Future): + logger.debug("Handshake callback triggered for engine %s", + engine_id) try: fut.result() # This will raise if the handshake failed logger.debug("Handshake succeeded for engine %s", engine_id) @@ -455,14 +457,16 @@ def on_handshake_complete(fut: Future): # Remove from futures dict if engine_id in self._handshake_futures: del self._handshake_futures[engine_id] - # Clear pending requests - they are no longer pending - # handshake and will be processed normally by the scheduler + # The scheduler will retry them on the next cycle and + # they'll be processed normally since the remote agent + # is now registered. if engine_id in self._pending_requests: completed_reqs = self._pending_requests[engine_id] del self._pending_requests[engine_id] logger.debug( "Handshake completed for engine %s. 
" - "Cleared %d requests from pending state.", + "Cleared %d requests from pending - " \ + "scheduler to retry", engine_id, len(completed_reqs)) except Exception as e: logger.warning("Handshake failed for engine %s: %s", engine_id, @@ -485,6 +489,7 @@ def _nixl_handshake(self, host: str, port: int): """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() + logger.debug("Starting NIXL handshake with %s:%s", host, port) # TODO: make the scheme dynamic, and/or implement https on both sides. url = build_uri("http", host, port, path="get_kv_connector_metadata") @@ -492,7 +497,9 @@ def _nixl_handshake(self, host: str, port: int): try: req = Request(url) + logger.debug("About to send HTTP request to %s", url) with urlopen(req, timeout=5.0) as response: + logger.debug("Received HTTP response from %s", url) response_data = response.read().decode('utf-8') res = json.loads(response_data) logger.debug("NIXL handshake response: %s", res) @@ -541,6 +548,8 @@ def _nixl_handshake(self, host: str, port: int): "Received None metadata from %s:%s, skipping NIXL handshake", host, port) + logger.debug("NIXL handshake method completed for %s:%s", host, port) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" From 504bba72111d428e20ae74ba8447e01ca89b1d8d Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 21:33:51 -0400 Subject: [PATCH 10/33] working checkpoint Signed-off-by: Will Eaton --- vllm/config.py | 5 +++ .../kv_connector/v1/nixl_connector.py | 45 ++++++++++--------- vllm/v1/engine/core.py | 4 +- vllm/v1/utils.py | 3 +- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 30bc7fc734e..3d389d4581f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4510,6 +4510,11 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True + + if (self.kv_transfer_config is not None + and self.kv_transfer_config.is_kv_transfer_instance): + from collections import defaultdict + self.cache_config.transfer_handshake_metadata = defaultdict(dict) def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 9cd6f01eb98..c092c9fbc6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -457,16 +457,14 @@ def on_handshake_complete(fut: Future): # Remove from futures dict if engine_id in self._handshake_futures: del self._handshake_futures[engine_id] - # The scheduler will retry them on the next cycle and - # they'll be processed normally since the remote agent - # is now registered. if engine_id in self._pending_requests: completed_reqs = self._pending_requests[engine_id] del self._pending_requests[engine_id] + for req_id, meta in completed_reqs: + self._ready_requests.put((req_id, meta)) logger.debug( "Handshake completed for engine %s. 
" - "Cleared %d requests from pending - " \ - "scheduler to retry", + "Moved %d requests to ready queue for processing", engine_id, len(completed_reqs)) except Exception as e: logger.warning("Handshake failed for engine %s: %s", engine_id, @@ -507,6 +505,10 @@ def _nixl_handshake(self, host: str, port: int): logger.error("Failed to fetch metadata from %s: %s", url, e) raise + if res is None: + logger.warning("Remote server returned None metadata, skipping handshake") + raise RuntimeError("Remote server returned None metadata") + remote_tp_size = len(res.keys()) # Default case is that the remote TP size is 1, so we can # directly access the metadata. @@ -525,6 +527,7 @@ def _nixl_handshake(self, host: str, port: int): if metadata_bytes is not None: # Reconstruct NixlAgentMetadata from JSON response # agent_metadata is base64-encoded binary data, not msgpack + tp_data.pop("agent_metadata", None) metadata = NixlAgentMetadata( agent_metadata=base64.b64decode(metadata_bytes), **tp_data) @@ -547,6 +550,7 @@ def _nixl_handshake(self, host: str, port: int): logger.warning( "Received None metadata from %s:%s, skipping NIXL handshake", host, port) + raise RuntimeError("Remote server does not support NIXL") logger.debug("NIXL handshake method completed for %s:%s", host, port) @@ -834,12 +838,6 @@ def get_finished(self) -> KVTransferFinishedResult: finished_recving=done_recving, pending_handshake=pending_handshake) - if not local_result.is_empty(): - logger.debug( - "Rank %s, get_finished: %s requests done sending, " - "%s requests done recving, %s pending handshake", self.tp_rank, - len(done_sending), len(done_recving), len(pending_handshake)) - if self.world_size == 1: return local_result @@ -939,26 +937,31 @@ def _pop_done_transfers( return done_req_ids def _process_ready_requests(self): - """Process requests that are ready after handshake completion. - - Note: With scheduler-based retry logic, this method is simplified - as automatic retries are handled by the scheduler. - """ - # Clear any remaining items in the ready queue to prevent memory leaks + """Process requests that are ready after handshake completion.""" + processed_count = 0 while True: try: - self._ready_requests.get_nowait() + req_id, meta = self._ready_requests.get_nowait() + logger.debug("Processing ready request %s for engine %s", + req_id, meta.remote_engine_id) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + ) + processed_count += 1 except queue.Empty: break + + if processed_count > 0: + logger.debug("Processed %d ready requests", processed_count) def start_load_kv(self, metadata: NixlConnectorMetadata): """ Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ - logger.debug("start_load_kv called with %d requests", - len(metadata.requests)) - for req_id, meta in metadata.requests.items(): logger.debug( "start_load_kv for request %s from remote engine %s. 
" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b93a1dc6861..ac697441354 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -175,9 +175,7 @@ def _initialize_kv_caches( # Collect KV connector xfer metadata from workers # (after KV cache registration) transfer_handshake_metadata = ( - self.model_executor.get_kv_connector_handshake_metadata() - if self.vllm_config.cache_config.transfer_handshake_metadata else - None) + self.model_executor.get_kv_connector_handshake_metadata()) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 0b8d35b23b5..540e056f0b0 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -541,7 +541,8 @@ def wait_for_engine_startup( logger.debug( "Received transfer handshake metadata from engine %s: %s", eng_index, txfer_metadata) - # Merge the received metadata with existing cache config + if cache_config.transfer_handshake_metadata is None: + cache_config.transfer_handshake_metadata = defaultdict(dict) for tp_rank, dp_dict in txfer_metadata.items(): for dp_rank, metadata in dp_dict.items(): cache_config.transfer_handshake_metadata[tp_rank][ From dd49c96ecc13fb2c936d31cb3edb59c0d572ec1a Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 16 Jun 2025 21:41:23 -0400 Subject: [PATCH 11/33] fix unreachable Signed-off-by: Will Eaton --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 4 +--- vllm/v1/worker/gpu_worker.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f1c1bf7edd9..9adbdc0de07 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -113,9 +113,7 @@ class KVConnectorMetadata: Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. 
""" - - def __init__(self): - pass + pass class KVConnectorHandshakeMetadata( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 4aa6234043f..bbe38edee0a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -254,8 +254,6 @@ def get_kv_connector_handshake_metadata(self) -> Optional[dict]: dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local return {tp_rank: {dp_rank: msgspec.to_builtins(metadata)}} - return None - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From f65c1b3d2a25308bcba5ca55be9ce663fef7f9ed Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 17 Jun 2025 10:09:41 -0400 Subject: [PATCH 12/33] remove uv.lock Signed-off-by: Will Eaton --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index 45ad4584f4b..e49d1d6ba61 100644 --- a/.gitignore +++ b/.gitignore @@ -202,4 +202,3 @@ shellcheck*/ # Ingore moe/marlin_moe gen code csrc/moe/marlin_moe_wna16/kernel_* -uv.lock From efd655b40e0d7d631ec6f8da51d3ad0ad9c826bd Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 17 Jun 2025 10:10:12 -0400 Subject: [PATCH 13/33] remove unused Signed-off-by: Will Eaton --- vllm/engine/llm_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dced747fae9..8fccf9bd2aa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -16,7 +16,6 @@ import torch from typing_extensions import TypeVar -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorHandshakeMetadata import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, SchedulerConfig, From dba38354f4b57512bada3366f2ab937fe3593665 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 17 Jun 2025 22:16:27 -0400 Subject: [PATCH 14/33] flip protocol; fix scheduling order bug Signed-off-by: Will Eaton --- vllm/config.py | 2 +- .../kv_connector/v1/nixl_connector.py | 43 ++++++++++++------- vllm/entrypoints/openai/api_server.py | 23 +++++++++- vllm/v1/core/sched/scheduler.py | 11 ++--- vllm/v1/engine/core.py | 6 ++- vllm/v1/utils.py | 8 ++-- vllm/v1/worker/gpu_worker.py | 2 +- 7 files changed, 66 insertions(+), 29 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 3d389d4581f..7e71cf6029b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1515,7 +1515,7 @@ class CacheConfig: transfer_handshake_metadata: Optional[dict[int, dict[ int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False) - """Metadata for the KV connector handshake.""" + """Metadata for the KV connector handshake. Structure: dp_rank -> tp_rank -> metadata""" def compute_hash(self) -> str: """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c092c9fbc6d..8e836f04c63 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -489,8 +489,9 @@ def _nixl_handshake(self, host: str, port: int): start_time = time.perf_counter() logger.debug("Starting NIXL handshake with %s:%s", host, port) - # TODO: make the scheme dynamic, and/or implement https on both sides. 
- url = build_uri("http", host, port, path="get_kv_connector_metadata") + # Use the new endpoint scheme to filter by dp_rank and tp_rank + # Default to dp_rank 0 and use current tp_rank for optimal filtering + url = build_uri("http", host, port, path=f"get_kv_connector_metadata/0/{self.tp_rank}") logger.debug("Querying metadata on path: %s", url) try: @@ -509,20 +510,29 @@ def _nixl_handshake(self, host: str, port: int): logger.warning("Remote server returned None metadata, skipping handshake") raise RuntimeError("Remote server returned None metadata") - remote_tp_size = len(res.keys()) - # Default case is that the remote TP size is 1, so we can - # directly access the metadata. - tp_data = res.get(str(self.tp_rank), {}).get("0", {}) - metadata_bytes = tp_data.get("agent_metadata", None) - - # Handshake only with the other TP remote the current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio - if p_remote_rank > 0: - metadata_bytes = res.get(str(p_remote_rank), - {}).get("0", - {}).get("agent_metadata", None) + # With filtered response from new endpoint, we get: {dp_rank: {tp_rank: metadata}} + # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, extract directly + if "0" in res and str(self.tp_rank) in res["0"]: + tp_data = res["0"][str(self.tp_rank)] + metadata_bytes = tp_data.get("agent_metadata", None) + p_remote_rank = self.tp_rank # Use current tp_rank for filtered response + else: + # Fallback to unfiltered endpoint for heterogeneous TP cases + url_fallback = build_uri("http", host, port, path="get_kv_connector_metadata") + logger.debug("Using fallback unfiltered endpoint: %s", url_fallback) + req = Request(url_fallback) + with urlopen(req, timeout=5.0) as response: + response_data = response.read().decode('utf-8') + res = json.loads(response_data) + + dp_data = res.get("0", {}) + remote_tp_size = len(dp_data.keys()) if dp_data else 1 + + # Handle heterogeneous TP mapping + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + tp_data = dp_data.get(str(p_remote_rank), {}) + metadata_bytes = tp_data.get("agent_metadata", None) if metadata_bytes is not None: # Reconstruct NixlAgentMetadata from JSON response @@ -962,6 +972,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ + for req_id, meta in metadata.requests.items(): logger.debug( "start_load_kv for request %s from remote engine %s. 
" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7a20eb68468..13c46a4538d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -869,8 +869,29 @@ async def show_server_info(raw_request: Request): return JSONResponse(content=server_info) @router.get("/get_kv_connector_metadata") - async def get_kv_connector_metadata(raw_request: Request): + @router.get("/get_kv_connector_metadata/{dp_rank}") + @router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") + async def get_kv_connector_metadata(raw_request: Request, dp_rank: int = None, tp_rank: int = None): kv_connector_metadata = raw_request.app.state.vllm_config.cache_config.transfer_handshake_metadata + + if kv_connector_metadata is None: + return JSONResponse(content=None) + + # Filter by dp_rank if specified + if dp_rank is not None: + if dp_rank not in kv_connector_metadata: + return JSONResponse(content={}) + dp_data = kv_connector_metadata[dp_rank] + + # Filter by tp_rank if also specified + if tp_rank is not None: + if tp_rank not in dp_data: + return JSONResponse(content={}) + return JSONResponse(content={dp_rank: {tp_rank: dp_data[tp_rank]}}) + else: + return JSONResponse(content={dp_rank: dp_data}) + + # Return all metadata if no filtering return JSONResponse(content=kv_connector_metadata) @router.post("/reset_prefix_cache") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9fe2557b673..2b53297b80e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -719,11 +719,7 @@ def update_from_output( for request in self.running: req_id = request.request_id num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: - # The request was not scheduled in this step. - new_running.append(request) - continue - + # Check if this request is pending handshake and needs to reschedule if (pending_handshake_req_ids and req_id in pending_handshake_req_ids): @@ -734,6 +730,11 @@ def update_from_output( num_tokens_to_reschedule -= request.num_computed_tokens new_running.append(request) continue + + if num_tokens_scheduled == 0: + # The request was not scheduled in this step. 
+ new_running.append(request) + continue req_index = model_runner_output.req_id_to_index[req_id] generated_token_ids = sampled_token_ids[req_index] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ac697441354..020d8491f07 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -458,7 +458,11 @@ def _perform_handshake( content = {} for worker_dict in self.transfer_handshake_metadata: if worker_dict is not None: - content.update(worker_dict) + # Deep merge the nested dictionaries instead of overwriting + for dp_rank, tp_dict in worker_dict.items(): + if dp_rank not in content: + content[dp_rank] = {} + content[dp_rank].update(tp_dict) handshake_message["transfer_handshake_metadata"] = content handshake_socket.send(msgspec.msgpack.encode(handshake_message)) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 540e056f0b0..d0803209bf6 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -543,10 +543,10 @@ def wait_for_engine_startup( eng_index, txfer_metadata) if cache_config.transfer_handshake_metadata is None: cache_config.transfer_handshake_metadata = defaultdict(dict) - for tp_rank, dp_dict in txfer_metadata.items(): - for dp_rank, metadata in dp_dict.items(): - cache_config.transfer_handshake_metadata[tp_rank][ - dp_rank] = metadata + for dp_rank, tp_dict in txfer_metadata.items(): + for tp_rank, metadata in tp_dict.items(): + cache_config.transfer_handshake_metadata[dp_rank][ + tp_rank] = metadata start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index bbe38edee0a..8b3b7a5f6b2 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -252,7 +252,7 @@ def get_kv_connector_handshake_metadata(self) -> Optional[dict]: tp_rank = get_tp_group().rank_in_group dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local - return {tp_rank: {dp_rank: msgspec.to_builtins(metadata)}} + return {dp_rank: {tp_rank: msgspec.to_builtins(metadata)}} def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() From 85855d17da88fb22344248f962f78bcbde6471d4 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 18 Jun 2025 13:54:07 -0400 Subject: [PATCH 15/33] fix bug in case of no kvconnectorgroup Signed-off-by: Will Eaton --- vllm/v1/worker/gpu_worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 8b3b7a5f6b2..26c61093765 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -18,6 +18,7 @@ set_custom_all_reduce) from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, get_kv_transfer_group, + has_kv_transfer_group, is_v1_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger @@ -236,6 +237,9 @@ def determine_available_memory(self) -> int: def get_kv_connector_handshake_metadata(self) -> Optional[dict]: """Get KV connector metadata from this worker if available.""" + if not has_kv_transfer_group(): + return None + connector = get_kv_transfer_group() if not is_v1_kv_transfer_group(connector): logger.warning("The KV connector is not a v1 connector. 
" From 1fc1af4909bda8532ea387721885d39c83f4a383 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 18 Jun 2025 15:49:54 -0400 Subject: [PATCH 16/33] actually use handshake timeout; simplify route Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 52 ++++++++++++------- vllm/entrypoints/openai/api_server.py | 37 ++++++------- vllm/envs.py | 6 +++ vllm/v1/engine/core.py | 6 +-- 4 files changed, 61 insertions(+), 40 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8e836f04c63..6a2c21d1cf0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -12,7 +12,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional from urllib.error import HTTPError, URLError -from urllib.request import Request, urlopen +from urllib.request import Request as URLRequest +from urllib.request import urlopen import torch @@ -398,7 +399,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Background handshake threads for remote engines self._executor = ThreadPoolExecutor( - max_workers=4, thread_name_prefix="nixl-handshake") + max_workers=1, thread_name_prefix="nixl-handshake") # Thread results for handshake completion tracking self._handshake_futures: dict[str, Future] = {} self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {} @@ -491,13 +492,17 @@ def _nixl_handshake(self, host: str, port: int): # Use the new endpoint scheme to filter by dp_rank and tp_rank # Default to dp_rank 0 and use current tp_rank for optimal filtering - url = build_uri("http", host, port, path=f"get_kv_connector_metadata/0/{self.tp_rank}") + url = build_uri("http", + host, + port, + path=f"get_kv_connector_metadata/0/{self.tp_rank}") logger.debug("Querying metadata on path: %s", url) try: - req = Request(url) + req = URLRequest(url) logger.debug("About to send HTTP request to %s", url) - with urlopen(req, timeout=5.0) as response: + with urlopen(req, + timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: logger.debug("Received HTTP response from %s", url) response_data = response.read().decode('utf-8') res = json.loads(response_data) @@ -507,27 +512,36 @@ def _nixl_handshake(self, host: str, port: int): raise if res is None: - logger.warning("Remote server returned None metadata, skipping handshake") + logger.warning( + "Remote server returned None metadata, skipping handshake") raise RuntimeError("Remote server returned None metadata") - # With filtered response from new endpoint, we get: {dp_rank: {tp_rank: metadata}} - # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, extract directly + # With filtered response from new endpoint, we get: + # {dp_rank: {tp_rank: metadata}} + # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, + # extract directly. 
if "0" in res and str(self.tp_rank) in res["0"]: tp_data = res["0"][str(self.tp_rank)] metadata_bytes = tp_data.get("agent_metadata", None) - p_remote_rank = self.tp_rank # Use current tp_rank for filtered response + # use current tp_rank for filtered response + p_remote_rank = self.tp_rank else: # Fallback to unfiltered endpoint for heterogeneous TP cases - url_fallback = build_uri("http", host, port, path="get_kv_connector_metadata") - logger.debug("Using fallback unfiltered endpoint: %s", url_fallback) - req = Request(url_fallback) - with urlopen(req, timeout=5.0) as response: + url_fallback = build_uri("http", + host, + port, + path="get_kv_connector_metadata") + logger.debug("Using fallback unfiltered endpoint: %s", + url_fallback) + req = URLRequest(url_fallback) + with urlopen(req, + timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: response_data = response.read().decode('utf-8') res = json.loads(response_data) - + dp_data = res.get("0", {}) remote_tp_size = len(dp_data.keys()) if dp_data else 1 - + # Handle heterogeneous TP mapping tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio @@ -952,8 +966,8 @@ def _process_ready_requests(self): while True: try: req_id, meta = self._ready_requests.get_nowait() - logger.debug("Processing ready request %s for engine %s", - req_id, meta.remote_engine_id) + logger.debug("Processing ready request %s for engine %s", + req_id, meta.remote_engine_id) self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, @@ -963,7 +977,7 @@ def _process_ready_requests(self): processed_count += 1 except queue.Empty: break - + if processed_count > 0: logger.debug("Processed %d ready requests", processed_count) @@ -972,7 +986,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ - + for req_id, meta in metadata.requests.items(): logger.debug( "start_load_kv for request %s from remote engine %s. " diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 13c46a4538d..6707aeb8dcf 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -871,28 +871,29 @@ async def show_server_info(raw_request: Request): @router.get("/get_kv_connector_metadata") @router.get("/get_kv_connector_metadata/{dp_rank}") @router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") - async def get_kv_connector_metadata(raw_request: Request, dp_rank: int = None, tp_rank: int = None): - kv_connector_metadata = raw_request.app.state.vllm_config.cache_config.transfer_handshake_metadata - - if kv_connector_metadata is None: - return JSONResponse(content=None) - - # Filter by dp_rank if specified + async def get_kv_connector_metadata(raw_request: Request, + dp_rank: Optional[int] = None, + tp_rank: Optional[int] = None): + kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = ( + raw_request.app.state.vllm_config.cache_config. 
+ transfer_handshake_metadata) + + if kv_meta is None: + return None + if dp_rank is not None: - if dp_rank not in kv_connector_metadata: - return JSONResponse(content={}) - dp_data = kv_connector_metadata[dp_rank] - - # Filter by tp_rank if also specified + if dp_rank not in kv_meta: + return {} + dp_data = kv_meta[dp_rank] + if tp_rank is not None: if tp_rank not in dp_data: - return JSONResponse(content={}) - return JSONResponse(content={dp_rank: {tp_rank: dp_data[tp_rank]}}) + return {} + return {dp_rank: {tp_rank: dp_data[tp_rank]}} else: - return JSONResponse(content={dp_rank: dp_data}) - - # Return all metadata if no filtering - return JSONResponse(content=kv_connector_metadata) + return {dp_rank: dp_data} + + return kv_meta @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): diff --git a/vllm/envs.py b/vllm/envs.py index 80c5f289bba..1bff67b8212 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -122,6 +122,7 @@ VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 + VLLM_NIXL_HANDSHAKE_TIMEOUT: float = 2.0 VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 @@ -840,6 +841,11 @@ def get_vllm_port() -> Optional[int]: "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + # Timeout in seconds for NIXL HTTP handshake requests. + # Default is 2 seconds + "VLLM_NIXL_HANDSHAKE_TIMEOUT": + lambda: float(os.getenv("VLLM_NIXL_HANDSHAKE_TIMEOUT", "2.0")), + # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using all-reduce diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 020d8491f07..9c8c2dfa8f1 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -453,12 +453,12 @@ def _perform_handshake( if hasattr(self, 'transfer_handshake_metadata' ) and self.transfer_handshake_metadata: # self.transfer_handshake_metadata is list of dicts from workers - # Each dict already has structure {tp_rank: {dp_rank: metadata}} + # Each dict already has structure {dp_rank: {tp_rank: metadata}} # Merge all worker dicts into a single dict - content = {} + content: dict[str, dict[str, dict[str, Any]]] = {} for worker_dict in self.transfer_handshake_metadata: if worker_dict is not None: - # Deep merge the nested dictionaries instead of overwriting + # Deep merge nested dictionaries instead of overwrite for dp_rank, tp_dict in worker_dict.items(): if dp_rank not in content: content[dp_rank] = {} From 9d8c15c2c241f9ca92960006272c4fda6ec6fc31 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Thu, 19 Jun 2025 08:49:24 -0400 Subject: [PATCH 17/33] revert back to working het TP logic Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 114 +++++++----------- 1 file changed, 45 insertions(+), 69 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6a2c21d1cf0..890ff8108c0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -490,23 +490,14 @@ def _nixl_handshake(self, host: str, port: int): start_time = time.perf_counter() logger.debug("Starting NIXL handshake with %s:%s", host, port) - # Use the new endpoint scheme to filter by dp_rank and tp_rank - # Default to dp_rank 0 and 
use current tp_rank for optimal filtering - url = build_uri("http", - host, - port, - path=f"get_kv_connector_metadata/0/{self.tp_rank}") - logger.debug("Querying metadata on path: %s", url) + url = build_uri("http", host, port, path="get_kv_connector_metadata") try: req = URLRequest(url) - logger.debug("About to send HTTP request to %s", url) with urlopen(req, timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: - logger.debug("Received HTTP response from %s", url) response_data = response.read().decode('utf-8') res = json.loads(response_data) - logger.debug("NIXL handshake response: %s", res) except (URLError, HTTPError) as e: logger.error("Failed to fetch metadata from %s: %s", url, e) raise @@ -516,65 +507,50 @@ def _nixl_handshake(self, host: str, port: int): "Remote server returned None metadata, skipping handshake") raise RuntimeError("Remote server returned None metadata") - # With filtered response from new endpoint, we get: - # {dp_rank: {tp_rank: metadata}} - # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, - # extract directly. - if "0" in res and str(self.tp_rank) in res["0"]: - tp_data = res["0"][str(self.tp_rank)] - metadata_bytes = tp_data.get("agent_metadata", None) - # use current tp_rank for filtered response - p_remote_rank = self.tp_rank - else: - # Fallback to unfiltered endpoint for heterogeneous TP cases - url_fallback = build_uri("http", - host, - port, - path="get_kv_connector_metadata") - logger.debug("Using fallback unfiltered endpoint: %s", - url_fallback) - req = URLRequest(url_fallback) - with urlopen(req, - timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: - response_data = response.read().decode('utf-8') - res = json.loads(response_data) - - dp_data = res.get("0", {}) - remote_tp_size = len(dp_data.keys()) if dp_data else 1 - - # Handle heterogeneous TP mapping - tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio - tp_data = dp_data.get(str(p_remote_rank), {}) - metadata_bytes = tp_data.get("agent_metadata", None) - - if metadata_bytes is not None: - # Reconstruct NixlAgentMetadata from JSON response - # agent_metadata is base64-encoded binary data, not msgpack - tp_data.pop("agent_metadata", None) - metadata = NixlAgentMetadata( - agent_metadata=base64.b64decode(metadata_bytes), **tp_data) - - # Register Remote agent. - logger.debug("About to register remote agent for engine %s", - metadata.engine_id) - pre_register = time.perf_counter() - self.add_remote_agent(metadata, remote_tp_rank=p_remote_rank) - agent_time = time.perf_counter() - logger.debug("Finished registering remote agent for engine %s", - metadata.engine_id) - - logger.debug("NIXL handshake: get metadata took: %s", - pre_register - start_time) - logger.debug("NIXL handshake: add agent took: %s", - agent_time - pre_register) - else: - # If metadata_bytes is None, it means the remote agent - # is not using NIXL, so we can skip the handshake. 
- logger.warning( - "Received None metadata from %s:%s, skipping NIXL handshake", - host, port) - raise RuntimeError("Remote server does not support NIXL") + # Get dp_rank 0 data (standard for disaggregated prefill-decode) + dp_data = res.get("0", {}) + if not dp_data: + raise RuntimeError("No metadata found for dp_rank 0") + + remote_tp_size = len(dp_data.keys()) + rank0_data = dp_data.get("0", {}) + if not rank0_data: + raise RuntimeError("No metadata found for remote rank 0") + + metadata_bytes = rank0_data.get("agent_metadata", None) + if metadata_bytes is None: + raise RuntimeError("No agent metadata found for remote rank 0") + + rank0_data_copy = rank0_data.copy() + rank0_data_copy.pop("agent_metadata", None) + rank0_metadata = NixlAgentMetadata( + agent_metadata=base64.b64decode(metadata_bytes), **rank0_data_copy) + + pre_register = time.perf_counter() + self.add_remote_agent(rank0_metadata, remote_tp_rank=0) + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + + if p_remote_rank > 0: + p_rank_data = dp_data.get(str(p_remote_rank), {}) + if p_rank_data: + p_metadata_bytes = p_rank_data.get("agent_metadata", None) + if p_metadata_bytes: + p_rank_data_copy = p_rank_data.copy() + p_rank_data_copy.pop("agent_metadata", None) + p_metadata = NixlAgentMetadata( + agent_metadata=base64.b64decode(p_metadata_bytes), + **p_rank_data_copy) + self.add_remote_agent(p_metadata, remote_tp_rank=p_remote_rank) + + agent_time = time.perf_counter() + + logger.debug("Finished registering remote agent for engine %s", + rank0_metadata.engine_id) + logger.debug("NIXL handshake: get metadata took: %s", + pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", + agent_time - pre_register) logger.debug("NIXL handshake method completed for %s:%s", host, port) From 83ec83a9a627ed6f51c157a0dd5a4eb03a4df5fe Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Fri, 27 Jun 2025 15:52:21 -0400 Subject: [PATCH 18/33] move nixl sidechannel to own entrypoint Signed-off-by: Will Eaton push missing side channel server Signed-off-by: Will Eaton --- vllm/entrypoints/nixl_side_channel_server.py | 112 +++++++++++++++++++ vllm/entrypoints/openai/api_server.py | 39 +++---- 2 files changed, 125 insertions(+), 26 deletions(-) create mode 100644 vllm/entrypoints/nixl_side_channel_server.py diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py new file mode 100644 index 00000000000..5004f5cde3c --- /dev/null +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import threading +from typing import Any, Optional + +import uvicorn +from fastapi import FastAPI + +from vllm import envs +from vllm.config import VllmConfig +from vllm.entrypoints.launcher import serve_http +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class NixlSideChannelServer: + + def __init__(self, vllm_config: VllmConfig, host: str, port: int): + self.vllm_config = vllm_config + self.host = host + self.port = port + self.app = FastAPI(title="vLLM NIXL Side Channel Server") + self.server = None + self.server_thread = None + self._setup_routes() + + def _setup_routes(self): + + @self.app.get("/get_kv_connector_metadata") + @self.app.get("/get_kv_connector_metadata/{dp_rank}") + @self.app.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") + async def 
get_kv_connector_metadata(dp_rank: Optional[int] = None, + tp_rank: Optional[int] = None): + kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = ( + self.vllm_config.cache_config.transfer_handshake_metadata) + + if kv_meta is None: + return None + + if dp_rank is not None: + if dp_rank not in kv_meta: + return {} + dp_data = kv_meta[dp_rank] + + if tp_rank is not None: + if tp_rank not in dp_data: + return {} + return {dp_rank: {tp_rank: dp_data[tp_rank]}} + else: + return {dp_rank: dp_data} + + return kv_meta + + async def start_async(self): + if self.server is not None: + logger.warning("Side channel server is already running") + return + + logger.info("Starting NIXL side channel server on %s:%s", + self.host, self.port) + + # use uvicorn directly to avoid dependency on engine_client + config = uvicorn.Config( + app=self.app, + host=self.host, + port=self.port, + log_level="info", + access_log=True, + ) + self.server = uvicorn.Server(config) + + # start the server in a background task + asyncio.create_task(self.server.serve()) + logger.info("NIXL side channel server started successfully") + + async def stop_async(self): + if self.server is not None: + logger.info("Stopping NIXL side channel server") + try: + self.server.should_exit = True + await asyncio.sleep(1) # give it time to shutdown + except Exception as e: + logger.warning("Error during side channel server shutdown: %s", e) + self.server = None + logger.info("NIXL side channel server stopped") + + +def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool: + if vllm_config.kv_transfer_config is None: + return False + + return vllm_config.kv_transfer_config.kv_connector == "NixlConnector" + + +async def start_nixl_side_channel_server_if_needed( + vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]: + if not should_start_nixl_side_channel_server(vllm_config): + return None + + side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + logger.info("Starting NIXL side channel metadata server on %s:%d", + side_channel_host, side_channel_port) + + server = NixlSideChannelServer( + vllm_config, side_channel_host, side_channel_port) + await server.start_async() + return server diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 3a4ddd82c08..21a20ac5936 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -93,6 +93,8 @@ from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.entrypoints.nixl_side_channel_server import ( + start_nixl_side_channel_server_if_needed) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, with_cancellation) from vllm.logger import init_logger @@ -916,32 +918,6 @@ async def show_server_info(raw_request: Request): server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} return JSONResponse(content=server_info) - @router.get("/get_kv_connector_metadata") - @router.get("/get_kv_connector_metadata/{dp_rank}") - @router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") - async def get_kv_connector_metadata(raw_request: Request, - dp_rank: Optional[int] = None, - tp_rank: Optional[int] = None): - kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = ( - raw_request.app.state.vllm_config.cache_config. 
- transfer_handshake_metadata) - - if kv_meta is None: - return None - - if dp_rank is not None: - if dp_rank not in kv_meta: - return {} - dp_data = kv_meta[dp_rank] - - if tp_rank is not None: - if tp_rank not in dp_data: - return {} - return {dp_rank: {tp_rank: dp_data[tp_rank]}} - else: - return {dp_rank: dp_data} - - return kv_meta @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): @@ -1474,6 +1450,12 @@ async def run_server_worker(listen_address, vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) + nixl_side_channel_server = None + try: + nixl_side_channel_server = await start_nixl_side_channel_server_if_needed(vllm_config) + except Exception as e: + logger.warning("Failed to start NIXL side channel server: %s", e) + logger.info("Starting vLLM API server %d on %s", server_index, listen_address) shutdown_task = await serve_http( @@ -1498,6 +1480,11 @@ async def run_server_worker(listen_address, try: await shutdown_task finally: + if nixl_side_channel_server is not None: + try: + await nixl_side_channel_server.stop_async() + except Exception as e: + logger.warning("Error stopping NIXL side channel server: %s", e) sock.close() From 9a86b3753d765ca671d82f9482cb67b68edc808a Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 10:25:47 -0400 Subject: [PATCH 19/33] allow configurable handshake strategy --- .../unit/test_nixl_handshake_strategies.py | 329 ++++++++++++++++++ .../kv_connector/v1/nixl_connector.py | 298 ++++++++++++---- vllm/entrypoints/nixl_side_channel_server.py | 11 +- vllm/envs.py | 5 + 4 files changed, 575 insertions(+), 68 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py diff --git a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py new file mode 100644 index 00000000000..8db1c8bff75 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py @@ -0,0 +1,329 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import json +from unittest.mock import MagicMock, patch +from urllib.error import URLError + +import pytest + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + HandshakeStrategy, HttpHandshakeStrategy, NixlAgentMetadata, + ZmqHandshakeStrategy) + + +class TestHandshakeStrategyAbstraction: + + def test_abstract_base_class(self): + with pytest.raises(TypeError): + HandshakeStrategy(None, 0, 1, 8080, "test-engine") + + def test_strategy_interface(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + zmq_strategy = ZmqHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + assert hasattr(zmq_strategy, 'initiate_handshake') + assert hasattr(zmq_strategy, 'setup_listener') + assert hasattr(zmq_strategy, 'cleanup') + + http_strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + assert hasattr(http_strategy, 'initiate_handshake') + assert hasattr(http_strategy, 'setup_listener') + assert hasattr(http_strategy, 'cleanup') + + +class TestZmqHandshakeStrategy: + + def create_test_metadata(self) -> NixlAgentMetadata: + return NixlAgentMetadata( + engine_id="test-engine", + agent_metadata=b"test-agent-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1" + ) + + 
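
For context, a minimal sketch (assumed, not part of this patch) of how the two strategies exercised by these tests get selected at runtime; the environment variable and its 'zmq'/'http' values are the ones this series adds in vllm/envs.py, everything else here is illustrative:

    # Minimal selection sketch, mirroring the logic this patch adds to
    # NixlConnectorWorker.__init__; the surrounding context is assumed.
    import os

    from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
        HttpHandshakeStrategy, ZmqHandshakeStrategy)

    method = os.getenv("VLLM_NIXL_HANDSHAKE_METHOD", "zmq").lower()
    if method == "zmq":
        strategy_cls = ZmqHandshakeStrategy   # ZMQ side-channel handshake (default)
    elif method == "http":
        strategy_cls = HttpHandshakeStrategy  # HTTP metadata endpoint handshake
    else:
        raise ValueError(f"Unknown handshake method: {method}. "
                         "Supported methods: 'zmq', 'http'")
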
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx') + @patch('vllm.utils.make_zmq_path') + def test_zmq_handshake_success(self, mock_make_path, mock_zmq_ctx): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="agent-name-0") + + strategy = ZmqHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_socket = MagicMock() + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + mock_make_path.return_value = "tcp://localhost:8080" + + test_metadata = self.create_test_metadata() + with patch('msgspec.msgpack.Decoder') as mock_decoder_class: + mock_decoder = MagicMock() + mock_decoder_class.return_value = mock_decoder + mock_decoder.decode.return_value = test_metadata + + result = strategy.initiate_handshake("localhost", 8080, 1) + + assert result == {0: "agent-name-0"} + mock_add_agent.assert_called_once() + mock_socket.send.assert_called() + mock_socket.recv.assert_called() + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx') + @patch('vllm.utils.make_zmq_path') + def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(side_effect=["agent-0", "agent-1"]) + + strategy = ZmqHandshakeStrategy( + mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent) + + mock_socket = MagicMock() + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + mock_make_path.side_effect = ["tcp://localhost:8080", "tcp://localhost:8081"] + + test_metadata = self.create_test_metadata() + with patch('msgspec.msgpack.Decoder') as mock_decoder_class: + mock_decoder = MagicMock() + mock_decoder_class.return_value = mock_decoder + mock_decoder.decode.return_value = test_metadata + + result = strategy.initiate_handshake("localhost", 8080, 2) + + assert result == {0: "agent-0", 1: "agent-1"} + assert mock_add_agent.call_count == 2 + + @patch('threading.Thread') + def test_setup_listener(self, mock_thread): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = ZmqHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_thread_instance = MagicMock() + mock_thread.return_value = mock_thread_instance + + test_metadata = self.create_test_metadata() + + with patch('threading.Event') as mock_event_class: + mock_event = MagicMock() + mock_event_class.return_value = mock_event + + strategy.setup_listener(test_metadata) + + mock_thread.assert_called_once() + mock_thread_instance.start.assert_called_once() + mock_event.wait.assert_called_once() + + def test_cleanup(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = ZmqHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_thread = MagicMock() + strategy._listener_thread = mock_thread + + strategy.cleanup() + + mock_thread.join.assert_called_once_with(timeout=0) + + +class TestHttpHandshakeStrategy: + + def create_test_metadata_response(self) -> dict: + return { + "0": { + "0": { + "engine_id": "3871ab24-6b5a-4ea5-a614-5381594bcdde", + "agent_metadata": base64.b64encode(b"nixl-prefill-agent-data").decode(), + "kv_caches_base_addr": [0x7f8b2c000000], + "num_blocks": 1000, + "block_len": 128, + "attn_backend_name": "FLASH_ATTN_VLLM_V1" + } + } + } + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_success(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="remote-agent-0") 
+ + strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps( + self.create_test_metadata_response()).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = strategy.initiate_handshake("localhost", 8080, 1) + + assert result == {0: "remote-agent-0"} + mock_add_agent.assert_called_once() + + call_args = mock_add_agent.call_args + metadata = call_args[0][0] + assert isinstance(metadata, NixlAgentMetadata) + assert metadata.engine_id == "3871ab24-6b5a-4ea5-a614-5381594bcdde" + assert metadata.agent_metadata == b"nixl-prefill-agent-data" + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_multi_rank(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="remote-agent-1") + + strategy = HttpHandshakeStrategy( + mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent) + + response_data = { + "0": { + "0": { + "engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": base64.b64encode(b"decode-agent-0-data").decode(), + "kv_caches_base_addr": [0x7f8b2c000000], + "num_blocks": 800, + "block_len": 128, + "attn_backend_name": "FLASH_ATTN_VLLM_V1" + }, + "1": { + "engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": base64.b64encode(b"decode-agent-1-data").decode(), + "kv_caches_base_addr": [0x7f8b2d000000], + "num_blocks": 800, + "block_len": 128, + "attn_backend_name": "FLASH_ATTN_VLLM_V1" + } + } + } + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(response_data).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = strategy.initiate_handshake("localhost", 8080, 2) + + assert result == {1: "remote-agent-1"} + + call_args = mock_add_agent.call_args + metadata = call_args[0][0] + assert metadata.agent_metadata == b"decode-agent-1-data" + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_url_error(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_urlopen.side_effect = URLError("Connection failed") + + with pytest.raises(URLError): + strategy.initiate_handshake("localhost", 8080, 1) + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_none_response(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(None).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + with pytest.raises(RuntimeError, match="Remote server returned None metadata"): + strategy.initiate_handshake("localhost", 8080, 1) + + @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_missing_rank(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy( + mock_nixl, 1, 2, 8080, "decode-engine", mock_add_agent) + + mock_response = MagicMock() + empty_response = {"0": {}} + mock_response.read.return_value = json.dumps(empty_response).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + with pytest.raises(RuntimeError, match="No 
metadata found for dp_rank 0"): + strategy.initiate_handshake("localhost", 8080, 1) + + def test_setup_listener_noop(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + test_metadata = NixlAgentMetadata( + engine_id="test-engine", + agent_metadata=b"test-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1" + ) + + strategy.setup_listener(test_metadata) + + def test_cleanup_noop(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy( + mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + strategy.cleanup() + + +class TestHandshakeStrategyIntegration: + + @patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'zmq'}) + @patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'zmq') + def test_zmq_strategy_selection(self): + assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'zmq' + + @patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'http'}) + @patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'http') + def test_http_strategy_selection(self): + assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'http' + + def test_strategy_polymorphism(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategies = [ + ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent), + HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent) + ] + + test_metadata = NixlAgentMetadata( + engine_id="test-engine", + agent_metadata=b"test-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1" + ) + + for strategy in strategies: + assert callable(strategy.initiate_handshake) + assert callable(strategy.setup_listener) + assert callable(strategy.cleanup) + + strategy.setup_listener(test_metadata) + strategy.cleanup() \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index df7693d370b..95548e2e22a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,22 +1,26 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import contextlib import json import math import queue import threading import time import uuid +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from urllib.error import HTTPError, URLError from urllib.request import Request as URLRequest from urllib.request import urlopen +import msgspec import torch +import zmq from vllm import envs from vllm.attention.selector import backend_name_to_enum, get_attn_backend @@ -31,7 +35,7 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import _Backend -from vllm.utils import build_uri, round_down +from vllm.utils import build_uri, make_zmq_path, make_zmq_socket, round_down from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -76,6 +80,209 @@ class ReqMeta: tp_size: int +class HandshakeStrategy(ABC): + + def __init__(self, nixl_wrapper, 
tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str): + self.nixl_wrapper = nixl_wrapper + self.tp_rank = tp_rank + self.tp_size = tp_size + self.side_channel_port = side_channel_port + self.engine_id = engine_id + + @abstractmethod + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> Dict[int, str]: + pass + + @abstractmethod + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + pass + + @abstractmethod + def cleanup(self) -> None: + pass + + +class ZmqHandshakeStrategy(HandshakeStrategy): + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str, + add_remote_agent_func): + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id) + self.add_remote_agent_func = add_remote_agent_func + self._listener_thread: Optional[threading.Thread] = None + self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> Dict[int, str]: + start_time = time.perf_counter() + + def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: + with self._zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent + agent_name = self.add_remote_agent_func(metadata, rank, remote_tp_size) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + return metadata, agent_name + + # Handshake with remote agent-rank0 first to get the tp_size of remote + path = make_zmq_path("tcp", host, port) + logger.debug("Querying master rank metadata on path: %s", path) + metadata, agent_name_0 = handshake(path, 0) + + agents = {0: agent_name_0} + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. 
+ tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + if p_remote_rank > 0: + path = make_zmq_path("tcp", host, port + p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", + path, p_remote_rank) + _, agent_name = handshake(path, p_remote_rank) + agents[p_remote_rank] = agent_name + + return agents + + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + ready_event = threading.Event() + self._listener_thread = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._listener_thread.start() + ready_event.wait() + + def cleanup(self) -> None: + if self._listener_thread: + self._listener_thread.join(timeout=0) + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, base_port: int, + tp_rank: int): + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", size_in_bytes) + + # Listen for new requests for metadata + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, base_port + tp_rank) + logger.debug("Starting listening on path: %s", path) + with ZmqHandshakeStrategy._zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning("Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + @staticmethod + @contextlib.contextmanager + def _zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() + yield make_zmq_socket(ctx=ctx, path=addr, socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +class HttpHandshakeStrategy(HandshakeStrategy): + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str, + add_remote_agent_func): + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id) + self.add_remote_agent_func = add_remote_agent_func + self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> Dict[int, str]: + start_time = time.perf_counter() + logger.debug("Starting NIXL handshake with %s:%s", host, port) + + url = build_uri("http", host, port, path="get_kv_connector_metadata") + + try: + req = URLRequest(url) + with urlopen(req, timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: + response_data = response.read().decode('utf-8') + res = json.loads(response_data) + except (URLError, HTTPError) as e: + logger.error("Failed to fetch metadata from %s: %s", url, e) + raise + + if res is None: + logger.warning("Remote server returned None metadata, skipping handshake") + raise RuntimeError("Remote server returned None metadata") + + # Get dp_rank 0 data (standard for disaggregated prefill-decode) + dp_data = res.get("0", {}) + if not dp_data: + raise RuntimeError("No metadata found for dp_rank 0") + + remote_tp_size = len(dp_data.keys()) + + # Handshake only with the remote TP rank that current local rank will + # pull from. 
With homogeneous TP it happens to be the same rank_i. + tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + + # Get data for the specific rank we need to connect to + rank_data = dp_data.get(str(p_remote_rank), {}) + if not rank_data: + raise RuntimeError(f"No metadata found for remote rank {p_remote_rank}") + + metadata_bytes = rank_data.get("agent_metadata", None) + if metadata_bytes is None: + raise RuntimeError(f"No agent metadata found for remote rank {p_remote_rank}") + + rank_data_copy = rank_data.copy() + rank_data_copy.pop("agent_metadata", None) + metadata = NixlAgentMetadata( + agent_metadata=base64.b64decode(metadata_bytes), **rank_data_copy) + + pre_register = time.perf_counter() + # Register Remote agent + remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank, remote_tp_size) + agent_time = time.perf_counter() + + logger.debug("Finished registering remote agent for engine %s", metadata.engine_id) + logger.debug("NIXL handshake: get metadata took: %s", pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", agent_time - pre_register) + + logger.debug("NIXL handshake method completed for %s:%s", host, port) + + # Return remote rank -> agent name mapping + return {p_remote_rank: remote_agent_name} + + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + pass + + def cleanup(self) -> None: + pass + + class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): @@ -461,78 +668,32 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) + # Initialize handshake strategy + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() + if handshake_method == "zmq": + self._handshake_strategy = ZmqHandshakeStrategy( + self.nixl_wrapper, self.tp_rank, self.world_size, + self.side_channel_port, self.engine_id, self.add_remote_agent) + elif handshake_method == "http": + self._handshake_strategy = HttpHandshakeStrategy( + self.nixl_wrapper, self.tp_rank, self.world_size, + self.side_channel_port, self.engine_id, self.add_remote_agent) + else: + raise ValueError(f"Unknown handshake method: {handshake_method}. 
" + "Supported methods: 'zmq', 'http'") + + logger.info("Using %s handshake strategy", handshake_method) + def __del__(self): - """Cleanup background threads on destruction.""" self._handshake_initiation_executor.shutdown(wait=False) + if hasattr(self, '_handshake_strategy'): + self._handshake_strategy.cleanup() if self._nixl_handshake_listener_t: self._nixl_handshake_listener_t.join(timeout=0) def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: - """Do a NIXL handshake with a remote instance.""" - - start_time = time.perf_counter() - logger.debug("Starting NIXL handshake with %s:%s", host, port) - - url = build_uri("http", host, port, path="get_kv_connector_metadata") - - try: - req = URLRequest(url) - with urlopen(req, - timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: - response_data = response.read().decode('utf-8') - res = json.loads(response_data) - except (URLError, HTTPError) as e: - logger.error("Failed to fetch metadata from %s: %s", url, e) - raise - - if res is None: - logger.warning( - "Remote server returned None metadata, skipping handshake") - raise RuntimeError("Remote server returned None metadata") - - # Get dp_rank 0 data (standard for disaggregated prefill-decode) - dp_data = res.get("0", {}) - if not dp_data: - raise RuntimeError("No metadata found for dp_rank 0") - - remote_tp_size = len(dp_data.keys()) - - # Handshake only with the remote TP rank that current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio - - # Get data for the specific rank we need to connect to - rank_data = dp_data.get(str(p_remote_rank), {}) - if not rank_data: - raise RuntimeError(f"No metadata found for remote rank {p_remote_rank}") - - metadata_bytes = rank_data.get("agent_metadata", None) - if metadata_bytes is None: - raise RuntimeError(f"No agent metadata found for remote rank {p_remote_rank}") - - rank_data_copy = rank_data.copy() - rank_data_copy.pop("agent_metadata", None) - metadata = NixlAgentMetadata( - agent_metadata=base64.b64decode(metadata_bytes), **rank_data_copy) - - pre_register = time.perf_counter() - # Register Remote agent. 
- remote_agent_name = self.add_remote_agent(metadata, p_remote_rank, remote_tp_size) - agent_time = time.perf_counter() - - logger.debug("Finished registering remote agent for engine %s", - metadata.engine_id) - logger.debug("NIXL handshake: get metadata took: %s", - pre_register - start_time) - logger.debug("NIXL handshake: add agent took: %s", - agent_time - pre_register) - - logger.debug("NIXL handshake method completed for %s:%s", host, port) - - # Return remote rank -> agent name mapping - return {p_remote_rank: remote_agent_name} + return self._handshake_strategy.initiate_handshake(host, port, remote_tp_size) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -663,6 +824,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): block_len=self.block_len, attn_backend_name=self.backend_name) + # Setup handshake strategy listener + self._handshake_strategy.setup_listener(self.xfer_metadata) + def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, remote_tp_rank: int = 0, diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py index 5004f5cde3c..13daf80090f 100644 --- a/vllm/entrypoints/nixl_side_channel_server.py +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -92,12 +92,21 @@ def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool: if vllm_config.kv_transfer_config is None: return False - return vllm_config.kv_transfer_config.kv_connector == "NixlConnector" + if vllm_config.kv_transfer_config.kv_connector != "NixlConnector": + return False + + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() + return handshake_method == "http" async def start_nixl_side_channel_server_if_needed( vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]: if not should_start_nixl_side_channel_server(vllm_config): + if (vllm_config.kv_transfer_config is not None and + vllm_config.kv_transfer_config.kv_connector == "NixlConnector"): + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() + logger.info("Skipping NIXL HTTP side channel server (handshake method: %s)", + handshake_method) return None side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST diff --git a/vllm/envs.py b/vllm/envs.py index b01f091869c..d6e9b9bee96 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -127,6 +127,7 @@ VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 VLLM_NIXL_HANDSHAKE_TIMEOUT: float = 2.0 + VLLM_NIXL_HANDSHAKE_METHOD: str = "zmq" VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 @@ -905,6 +906,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_NIXL_HANDSHAKE_TIMEOUT": lambda: float(os.getenv("VLLM_NIXL_HANDSHAKE_TIMEOUT", "2.0")), + # NIXL handshake method ("zmq" or "http") + "VLLM_NIXL_HANDSHAKE_METHOD": + lambda: os.getenv("VLLM_NIXL_HANDSHAKE_METHOD", "zmq"), + # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using all-reduce From cc887f5e07e61f13e98fe4b3991ca3ce47178cdc Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 14:22:00 -0400 Subject: [PATCH 20/33] pre-commit fixes Signed-off-by: Will Eaton --- .../unit/test_nixl_handshake_strategies.py | 260 ++++++++++-------- vllm/config.py | 6 +- .../kv_connector/v1/nixl_connector.py | 158 ++++++----- vllm/entrypoints/nixl_side_channel_server.py | 45 +-- vllm/entrypoints/openai/api_server.py | 11 +- 
vllm/utils.py | 65 ++--- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/executor/abstract.py | 3 +- vllm/v1/utils.py | 3 +- 9 files changed, 285 insertions(+), 268 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py index 8db1c8bff75..8ad39e4ad7e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py +++ b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py @@ -23,15 +23,15 @@ def test_abstract_base_class(self): def test_strategy_interface(self): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - zmq_strategy = ZmqHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + zmq_strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, + "test-engine", mock_add_agent) assert hasattr(zmq_strategy, 'initiate_handshake') assert hasattr(zmq_strategy, 'setup_listener') assert hasattr(zmq_strategy, 'cleanup') - - http_strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) + + http_strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, + "test-engine", mock_add_agent) assert hasattr(http_strategy, 'initiate_handshake') assert hasattr(http_strategy, 'setup_listener') assert hasattr(http_strategy, 'cleanup') @@ -40,62 +40,66 @@ def test_strategy_interface(self): class TestZmqHandshakeStrategy: def create_test_metadata(self) -> NixlAgentMetadata: - return NixlAgentMetadata( - engine_id="test-engine", - agent_metadata=b"test-agent-data", - kv_caches_base_addr=[12345], - num_blocks=100, - block_len=16, - attn_backend_name="FLASH_ATTN_VLLM_V1" - ) + return NixlAgentMetadata(engine_id="test-engine", + agent_metadata=b"test-agent-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1") - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' + ) @patch('vllm.utils.make_zmq_path') def test_zmq_handshake_success(self, mock_make_path, mock_zmq_ctx): mock_nixl = MagicMock() mock_add_agent = MagicMock(return_value="agent-name-0") - - strategy = ZmqHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_socket = MagicMock() mock_zmq_ctx.return_value.__enter__.return_value = mock_socket mock_make_path.return_value = "tcp://localhost:8080" - + test_metadata = self.create_test_metadata() with patch('msgspec.msgpack.Decoder') as mock_decoder_class: mock_decoder = MagicMock() mock_decoder_class.return_value = mock_decoder mock_decoder.decode.return_value = test_metadata - + result = strategy.initiate_handshake("localhost", 8080, 1) - + assert result == {0: "agent-name-0"} mock_add_agent.assert_called_once() mock_socket.send.assert_called() mock_socket.recv.assert_called() - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' + ) @patch('vllm.utils.make_zmq_path') def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx): mock_nixl = MagicMock() mock_add_agent = MagicMock(side_effect=["agent-0", "agent-1"]) - - strategy = ZmqHandshakeStrategy( - mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent) - + + strategy = ZmqHandshakeStrategy(mock_nixl, 1, 2, 
8080, "test-engine", + mock_add_agent) + mock_socket = MagicMock() mock_zmq_ctx.return_value.__enter__.return_value = mock_socket - mock_make_path.side_effect = ["tcp://localhost:8080", "tcp://localhost:8081"] - + mock_make_path.side_effect = [ + "tcp://localhost:8080", "tcp://localhost:8081" + ] + test_metadata = self.create_test_metadata() with patch('msgspec.msgpack.Decoder') as mock_decoder_class: mock_decoder = MagicMock() mock_decoder_class.return_value = mock_decoder mock_decoder.decode.return_value = test_metadata - + result = strategy.initiate_handshake("localhost", 8080, 2) - + assert result == {0: "agent-0", 1: "agent-1"} assert mock_add_agent.call_count == 2 @@ -103,21 +107,21 @@ def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx): def test_setup_listener(self, mock_thread): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = ZmqHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_thread_instance = MagicMock() mock_thread.return_value = mock_thread_instance - + test_metadata = self.create_test_metadata() - + with patch('threading.Event') as mock_event_class: mock_event = MagicMock() mock_event_class.return_value = mock_event - + strategy.setup_listener(test_metadata) - + mock_thread.assert_called_once() mock_thread_instance.start.assert_called_once() mock_event.wait.assert_called_once() @@ -125,15 +129,15 @@ def test_setup_listener(self, mock_thread): def test_cleanup(self): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = ZmqHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_thread = MagicMock() strategy._listener_thread = mock_thread - + strategy.cleanup() - + mock_thread.join.assert_called_once_with(timeout=0) @@ -143,150 +147,171 @@ def create_test_metadata_response(self) -> dict: return { "0": { "0": { - "engine_id": "3871ab24-6b5a-4ea5-a614-5381594bcdde", - "agent_metadata": base64.b64encode(b"nixl-prefill-agent-data").decode(), + "engine_id": + "3871ab24-6b5a-4ea5-a614-5381594bcdde", + "agent_metadata": + base64.b64encode(b"nixl-prefill-agent-data").decode(), "kv_caches_base_addr": [0x7f8b2c000000], - "num_blocks": 1000, - "block_len": 128, - "attn_backend_name": "FLASH_ATTN_VLLM_V1" + "num_blocks": + 1000, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" } } } - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') def test_http_handshake_success(self, mock_urlopen): mock_nixl = MagicMock() mock_add_agent = MagicMock(return_value="remote-agent-0") - - strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_response = MagicMock() mock_response.read.return_value = json.dumps( self.create_test_metadata_response()).encode() mock_urlopen.return_value.__enter__.return_value = mock_response - + result = strategy.initiate_handshake("localhost", 8080, 1) - + assert result == {0: "remote-agent-0"} mock_add_agent.assert_called_once() - + call_args = mock_add_agent.call_args metadata = call_args[0][0] assert isinstance(metadata, NixlAgentMetadata) assert metadata.engine_id == "3871ab24-6b5a-4ea5-a614-5381594bcdde" assert 
metadata.agent_metadata == b"nixl-prefill-agent-data" - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') def test_http_handshake_multi_rank(self, mock_urlopen): mock_nixl = MagicMock() mock_add_agent = MagicMock(return_value="remote-agent-1") - - strategy = HttpHandshakeStrategy( - mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 1, 2, 8080, "test-engine", + mock_add_agent) + response_data = { "0": { "0": { - "engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", - "agent_metadata": base64.b64encode(b"decode-agent-0-data").decode(), + "engine_id": + "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": + base64.b64encode(b"decode-agent-0-data").decode(), "kv_caches_base_addr": [0x7f8b2c000000], - "num_blocks": 800, - "block_len": 128, - "attn_backend_name": "FLASH_ATTN_VLLM_V1" + "num_blocks": + 800, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" }, "1": { - "engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", - "agent_metadata": base64.b64encode(b"decode-agent-1-data").decode(), + "engine_id": + "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": + base64.b64encode(b"decode-agent-1-data").decode(), "kv_caches_base_addr": [0x7f8b2d000000], - "num_blocks": 800, - "block_len": 128, - "attn_backend_name": "FLASH_ATTN_VLLM_V1" + "num_blocks": + 800, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" } } } - + mock_response = MagicMock() mock_response.read.return_value = json.dumps(response_data).encode() mock_urlopen.return_value.__enter__.return_value = mock_response - + result = strategy.initiate_handshake("localhost", 8080, 2) - + assert result == {1: "remote-agent-1"} - + call_args = mock_add_agent.call_args metadata = call_args[0][0] assert metadata.agent_metadata == b"decode-agent-1-data" - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') def test_http_handshake_url_error(self, mock_urlopen): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_urlopen.side_effect = URLError("Connection failed") - + with pytest.raises(URLError): strategy.initiate_handshake("localhost", 8080, 1) - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') def test_http_handshake_none_response(self, mock_urlopen): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + mock_response = MagicMock() mock_response.read.return_value = json.dumps(None).encode() mock_urlopen.return_value.__enter__.return_value = mock_response - - with pytest.raises(RuntimeError, match="Remote server returned None metadata"): + + with pytest.raises(RuntimeError, + match="Remote server returned None metadata"): strategy.initiate_handshake("localhost", 8080, 1) - @patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') def 
test_http_handshake_missing_rank(self, mock_urlopen): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = HttpHandshakeStrategy( - mock_nixl, 1, 2, 8080, "decode-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 1, 2, 8080, + "decode-engine", mock_add_agent) + mock_response = MagicMock() empty_response = {"0": {}} mock_response.read.return_value = json.dumps(empty_response).encode() mock_urlopen.return_value.__enter__.return_value = mock_response - - with pytest.raises(RuntimeError, match="No metadata found for dp_rank 0"): + + with pytest.raises(RuntimeError, + match="No metadata found for dp_rank 0"): strategy.initiate_handshake("localhost", 8080, 1) def test_setup_listener_noop(self): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + test_metadata = NixlAgentMetadata( engine_id="test-engine", agent_metadata=b"test-data", kv_caches_base_addr=[12345], num_blocks=100, block_len=16, - attn_backend_name="FLASH_ATTN_VLLM_V1" - ) - + attn_backend_name="FLASH_ATTN_VLLM_V1") + strategy.setup_listener(test_metadata) def test_cleanup_noop(self): mock_nixl = MagicMock() mock_add_agent = MagicMock() - - strategy = HttpHandshakeStrategy( - mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent) - + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + strategy.cleanup() @@ -305,25 +330,26 @@ def test_http_strategy_selection(self): def test_strategy_polymorphism(self): mock_nixl = MagicMock() mock_add_agent = MagicMock() - + strategies = [ - ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent), - HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent) + ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", + mock_add_agent), + HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", + mock_add_agent) ] - + test_metadata = NixlAgentMetadata( engine_id="test-engine", agent_metadata=b"test-data", kv_caches_base_addr=[12345], num_blocks=100, block_len=16, - attn_backend_name="FLASH_ATTN_VLLM_V1" - ) - + attn_backend_name="FLASH_ATTN_VLLM_V1") + for strategy in strategies: assert callable(strategy.initiate_handshake) assert callable(strategy.setup_listener) assert callable(strategy.cleanup) - + strategy.setup_listener(test_metadata) - strategy.cleanup() \ No newline at end of file + strategy.cleanup() diff --git a/vllm/config.py b/vllm/config.py index 0f984953336..cafc7d930f2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1542,7 +1542,7 @@ class CacheConfig: transfer_handshake_metadata: Optional[dict[int, dict[ int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False) - """Metadata for the KV connector handshake. Structure: dp_rank -> tp_rank -> metadata""" + """Metadata for KV connector handshake. Structure: dp_rank -> tp_rank""" def compute_hash(self) -> str: """ @@ -4633,8 +4633,8 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True - - if (self.kv_transfer_config is not None + + if (self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_transfer_instance): from collections import defaultdict self.cache_config.transfer_handshake_metadata = defaultdict(dict) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 95548e2e22a..fe2c256d993 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional from urllib.error import HTTPError, URLError from urllib.request import Request as URLRequest from urllib.request import urlopen @@ -81,43 +81,44 @@ class ReqMeta: class HandshakeStrategy(ABC): - - def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, side_channel_port: int, engine_id: str): self.nixl_wrapper = nixl_wrapper self.tp_rank = tp_rank self.tp_size = tp_size self.side_channel_port = side_channel_port self.engine_id = engine_id - + @abstractmethod - def initiate_handshake(self, host: str, port: int, - remote_tp_size: int) -> Dict[int, str]: + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: pass - + @abstractmethod def setup_listener(self, metadata: NixlAgentMetadata) -> None: pass - + @abstractmethod def cleanup(self) -> None: pass class ZmqHandshakeStrategy(HandshakeStrategy): - + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, - side_channel_port: int, engine_id: str, + side_channel_port: int, engine_id: str, add_remote_agent_func): - super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id) + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, + engine_id) self.add_remote_agent_func = add_remote_agent_func self._listener_thread: Optional[threading.Thread] = None - self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size} - - def initiate_handshake(self, host: str, port: int, - remote_tp_size: int) -> Dict[int, str]: + self._tp_size_mapping: dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: start_time = time.perf_counter() - + def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: with self._zmq_ctx(zmq.REQ, path) as sock: sock.send(GET_META_MSG) @@ -125,24 +126,25 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) metadata = decoder.decode(metadata_bytes) got_metadata_time = time.perf_counter() - + # Register Remote agent - agent_name = self.add_remote_agent_func(metadata, rank, remote_tp_size) + agent_name = self.add_remote_agent_func( + metadata, rank, remote_tp_size) setup_agent_time = time.perf_counter() - + logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) return metadata, agent_name - + # Handshake with remote agent-rank0 
first to get the tp_size of remote path = make_zmq_path("tcp", host, port) logger.debug("Querying master rank metadata on path: %s", path) metadata, agent_name_0 = handshake(path, 0) - + agents = {0: agent_name_0} - + # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size @@ -150,12 +152,12 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: if p_remote_rank > 0: path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", - path, p_remote_rank) + path, p_remote_rank) _, agent_name = handshake(path, p_remote_rank) agents[p_remote_rank] = agent_name - + return agents - + def setup_listener(self, metadata: NixlAgentMetadata) -> None: ready_event = threading.Event() self._listener_thread = threading.Thread( @@ -165,20 +167,21 @@ def setup_listener(self, metadata: NixlAgentMetadata) -> None: name="nixl_handshake_listener") self._listener_thread.start() ready_event.wait() - + def cleanup(self) -> None: if self._listener_thread: self._listener_thread.join(timeout=0) - + @staticmethod def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, base_port: int, - tp_rank: int): + ready_event: threading.Event, base_port: int, + tp_rank: int): encoder = msgspec.msgpack.Encoder() encoded_data = encoder.encode(metadata) size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", size_in_bytes) - + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + size_in_bytes) + # Listen for new requests for metadata host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST path = make_zmq_path("tcp", host, base_port + tp_rank) @@ -188,97 +191,109 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, while True: identity, _, msg = sock.recv_multipart() if msg != GET_META_MSG: - logger.warning("Connection listener got unexpected message %s", msg) + logger.warning( + "Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) - + @staticmethod @contextlib.contextmanager def _zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: if socket_type not in (zmq.ROUTER, zmq.REQ): raise ValueError(f"Unexpected socket type: {socket_type}") - + ctx: Optional[zmq.Context] = None try: ctx = zmq.Context() - yield make_zmq_socket(ctx=ctx, path=addr, socket_type=socket_type, - bind=socket_type == zmq.ROUTER) + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) finally: if ctx is not None: ctx.destroy(linger=0) class HttpHandshakeStrategy(HandshakeStrategy): - + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, side_channel_port: int, engine_id: str, add_remote_agent_func): - super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, engine_id) + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, + engine_id) self.add_remote_agent_func = add_remote_agent_func - self._tp_size_mapping: Dict[str, int] = {engine_id: tp_size} - - def initiate_handshake(self, host: str, port: int, - remote_tp_size: int) -> Dict[int, str]: + self._tp_size_mapping: dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: start_time = time.perf_counter() logger.debug("Starting NIXL handshake with %s:%s", host, port) - + url = build_uri("http", host, 
port, path="get_kv_connector_metadata") - + try: req = URLRequest(url) - with urlopen(req, timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: + with urlopen(req, + timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: response_data = response.read().decode('utf-8') res = json.loads(response_data) except (URLError, HTTPError) as e: logger.error("Failed to fetch metadata from %s: %s", url, e) raise - + if res is None: - logger.warning("Remote server returned None metadata, skipping handshake") + logger.warning( + "Remote server returned None metadata, skipping handshake") raise RuntimeError("Remote server returned None metadata") - + # Get dp_rank 0 data (standard for disaggregated prefill-decode) dp_data = res.get("0", {}) if not dp_data: raise RuntimeError("No metadata found for dp_rank 0") - + remote_tp_size = len(dp_data.keys()) - + # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio - + # Get data for the specific rank we need to connect to rank_data = dp_data.get(str(p_remote_rank), {}) if not rank_data: - raise RuntimeError(f"No metadata found for remote rank {p_remote_rank}") - + raise RuntimeError( + f"No metadata found for remote rank {p_remote_rank}") + metadata_bytes = rank_data.get("agent_metadata", None) if metadata_bytes is None: - raise RuntimeError(f"No agent metadata found for remote rank {p_remote_rank}") - + raise RuntimeError( + f"No agent metadata found for remote rank {p_remote_rank}") + rank_data_copy = rank_data.copy() rank_data_copy.pop("agent_metadata", None) metadata = NixlAgentMetadata( agent_metadata=base64.b64decode(metadata_bytes), **rank_data_copy) - + pre_register = time.perf_counter() # Register Remote agent - remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank, remote_tp_size) + remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank, + remote_tp_size) agent_time = time.perf_counter() - - logger.debug("Finished registering remote agent for engine %s", metadata.engine_id) - logger.debug("NIXL handshake: get metadata took: %s", pre_register - start_time) - logger.debug("NIXL handshake: add agent took: %s", agent_time - pre_register) - + + logger.debug("Finished registering remote agent for engine %s", + metadata.engine_id) + logger.debug("NIXL handshake: get metadata took: %s", + pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", + agent_time - pre_register) + logger.debug("NIXL handshake method completed for %s:%s", host, port) - + # Return remote rank -> agent name mapping return {p_remote_rank: remote_agent_name} - + def setup_listener(self, metadata: NixlAgentMetadata) -> None: pass - + def cleanup(self) -> None: pass @@ -680,8 +695,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.side_channel_port, self.engine_id, self.add_remote_agent) else: raise ValueError(f"Unknown handshake method: {handshake_method}. 
" - "Supported methods: 'zmq', 'http'") - + "Supported methods: 'zmq', 'http'") + logger.info("Using %s handshake strategy", handshake_method) def __del__(self): @@ -693,7 +708,8 @@ def __del__(self): def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: - return self._handshake_strategy.initiate_handshake(host, port, remote_tp_size) + return self._handshake_strategy.initiate_handshake( + host, port, remote_tp_size) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py index 13daf80090f..9e9f6119008 100644 --- a/vllm/entrypoints/nixl_side_channel_server.py +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import threading from typing import Any, Optional import uvicorn @@ -10,7 +9,6 @@ from vllm import envs from vllm.config import VllmConfig -from vllm.entrypoints.launcher import serve_http from vllm.logger import init_logger logger = init_logger(__name__) @@ -26,9 +24,9 @@ def __init__(self, vllm_config: VllmConfig, host: str, port: int): self.server = None self.server_thread = None self._setup_routes() - + def _setup_routes(self): - + @self.app.get("/get_kv_connector_metadata") @self.app.get("/get_kv_connector_metadata/{dp_rank}") @self.app.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") @@ -58,10 +56,10 @@ async def start_async(self): if self.server is not None: logger.warning("Side channel server is already running") return - - logger.info("Starting NIXL side channel server on %s:%s", - self.host, self.port) - + + logger.info("Starting NIXL side channel server on %s:%s", self.host, + self.port) + # use uvicorn directly to avoid dependency on engine_client config = uvicorn.Config( app=self.app, @@ -71,7 +69,7 @@ async def start_async(self): access_log=True, ) self.server = uvicorn.Server(config) - + # start the server in a background task asyncio.create_task(self.server.serve()) logger.info("NIXL side channel server started successfully") @@ -83,7 +81,8 @@ async def stop_async(self): self.server.should_exit = True await asyncio.sleep(1) # give it time to shutdown except Exception as e: - logger.warning("Error during side channel server shutdown: %s", e) + logger.warning("Error during side channel server shutdown: %s", + e) self.server = None logger.info("NIXL side channel server stopped") @@ -91,10 +90,10 @@ async def stop_async(self): def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool: if vllm_config.kv_transfer_config is None: return False - + if vllm_config.kv_transfer_config.kv_connector != "NixlConnector": return False - + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() return handshake_method == "http" @@ -102,20 +101,22 @@ def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool: async def start_nixl_side_channel_server_if_needed( vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]: if not should_start_nixl_side_channel_server(vllm_config): - if (vllm_config.kv_transfer_config is not None and - vllm_config.kv_transfer_config.kv_connector == "NixlConnector"): + if (vllm_config.kv_transfer_config is not None + and vllm_config.kv_transfer_config.kv_connector + == "NixlConnector"): handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() - logger.info("Skipping NIXL HTTP side channel server (handshake method: %s)", - 
handshake_method) + logger.info( + "Skipping NIXL HTTP side channel server (handshake method: %s)", + handshake_method) return None - + side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT - + logger.info("Starting NIXL side channel metadata server on %s:%d", - side_channel_host, side_channel_port) - - server = NixlSideChannelServer( - vllm_config, side_channel_host, side_channel_port) + side_channel_host, side_channel_port) + + server = NixlSideChannelServer(vllm_config, side_channel_host, + side_channel_port) await server.start_async() return server diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 21a20ac5936..cc40e1f727d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -47,6 +47,8 @@ resolve_mistral_chat_template) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.nixl_side_channel_server import ( + start_nixl_side_channel_server_if_needed) from vllm.entrypoints.openai.cli_args import (log_non_default_args, make_arg_parser, validate_parsed_serve_args) @@ -93,8 +95,6 @@ from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation) from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.entrypoints.nixl_side_channel_server import ( - start_nixl_side_channel_server_if_needed) from vllm.entrypoints.utils import (cli_env_setup, load_aware_call, with_cancellation) from vllm.logger import init_logger @@ -918,7 +918,6 @@ async def show_server_info(raw_request: Request): server_info = {"vllm_config": str(raw_request.app.state.vllm_config)} return JSONResponse(content=server_info) - @router.post("/reset_prefix_cache") async def reset_prefix_cache(raw_request: Request): """ @@ -1452,7 +1451,8 @@ async def run_server_worker(listen_address, nixl_side_channel_server = None try: - nixl_side_channel_server = await start_nixl_side_channel_server_if_needed(vllm_config) + nixl_side_channel_server = await \ + start_nixl_side_channel_server_if_needed(vllm_config) except Exception as e: logger.warning("Failed to start NIXL side channel server: %s", e) @@ -1484,7 +1484,8 @@ async def run_server_worker(listen_address, try: await nixl_side_channel_server.stop_async() except Exception as e: - logger.warning("Error stopping NIXL side channel server: %s", e) + logger.warning("Error stopping NIXL side channel server: %s", + e) sock.close() diff --git a/vllm/utils.py b/vllm/utils.py index eed27ef4df6..bec8bf91ce5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -33,47 +33,20 @@ import uuid import warnings import weakref -from argparse import ( - Action, - ArgumentDefaultsHelpFormatter, - ArgumentParser, - ArgumentTypeError, - RawDescriptionHelpFormatter, - _ArgumentGroup, -) +from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser, + ArgumentTypeError, RawDescriptionHelpFormatter, + _ArgumentGroup) from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict -from collections.abc import ( - AsyncGenerator, - Awaitable, - Collection, - Generator, - Hashable, - Iterable, - Iterator, - KeysView, - Mapping, -) +from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, + Hashable, Iterable, Iterator, KeysView, Mapping) from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, 
lru_cache, partial, wraps from types import MappingProxyType -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - NamedTuple, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, - cast, - overload, -) +from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, + Optional, Sequence, Tuple, Type, TypeVar, Union, cast, + overload) from urllib.parse import urlparse, urlunparse from uuid import uuid4 @@ -2952,15 +2925,13 @@ def is_torch_equal_or_newer(target: str) -> bool: return Version(importlib.metadata.version('torch')) >= Version(target) -def build_uri( - scheme: str, - host: str, - port: Optional[int] = None, - path: str = "", - params: str = "", - query: str = "", - fragment: str = "" -) -> str: +def build_uri(scheme: str, + host: str, + port: Optional[int] = None, + path: str = "", + params: str = "", + query: str = "", + fragment: str = "") -> str: """ Robustly build a URI that properly handles IPv6 addresses. @@ -2982,12 +2953,12 @@ def build_uri( # Check if it's an IPv6 address ip = ipaddress.ip_address(host) # Ensure IPv6 addresses are bracketed - if (isinstance(ip, ipaddress.IPv6Address) and - not (host.startswith('[') and host.endswith(']'))): + if (isinstance(ip, ipaddress.IPv6Address) + and not (host.startswith('[') and host.endswith(']'))): host = f'[{host}]' except ValueError: pass - + netloc = f"{host}:{port}" if port else host return urlunparse((scheme, netloc, path, params, query, fragment)) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 43a94c7b3e7..d41c783c4d1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -750,7 +750,7 @@ def update_from_output( for request in self.running: req_id = request.request_id num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - + if num_tokens_scheduled == 0: # The request was not scheduled in this step. 
new_running.append(request) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 8df5dbec580..135c03532a0 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -80,7 +80,8 @@ def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: output = self.collective_rpc("get_kv_cache_spec") return output - def get_kv_connector_handshake_metadata(self) -> list[Optional[dict[int, dict[int, dict]]]]: + def get_kv_connector_handshake_metadata( + self) -> list[Optional[dict[int, dict[int, dict]]]]: output = self.collective_rpc("get_kv_connector_handshake_metadata") return output diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index d0803209bf6..534f479bf89 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -542,7 +542,8 @@ def wait_for_engine_startup( "Received transfer handshake metadata from engine %s: %s", eng_index, txfer_metadata) if cache_config.transfer_handshake_metadata is None: - cache_config.transfer_handshake_metadata = defaultdict(dict) + cache_config.transfer_handshake_metadata = defaultdict( + dict) for dp_rank, tp_dict in txfer_metadata.items(): for tp_rank, metadata in tp_dict.items(): cache_config.transfer_handshake_metadata[dp_rank][ From b960355c5d23cb09a41dc3182141868a3a359f17 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 15:08:16 -0400 Subject: [PATCH 21/33] satisfy mypy Signed-off-by: Will Eaton --- vllm/entrypoints/nixl_side_channel_server.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py index 9e9f6119008..ecf0405f1ce 100644 --- a/vllm/entrypoints/nixl_side_channel_server.py +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -2,13 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -from typing import Any, Optional +from typing import Optional import uvicorn from fastapi import FastAPI from vllm import envs from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata) from vllm.logger import init_logger logger = init_logger(__name__) @@ -32,8 +34,9 @@ def _setup_routes(self): @self.app.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") async def get_kv_connector_metadata(dp_rank: Optional[int] = None, tp_rank: Optional[int] = None): - kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = ( - self.vllm_config.cache_config.transfer_handshake_metadata) + kv_meta: Optional[dict[int, dict[ + int, KVConnectorHandshakeMetadata]]] = ( + self.vllm_config.cache_config.transfer_handshake_metadata) if kv_meta is None: return None @@ -71,7 +74,8 @@ async def start_async(self): self.server = uvicorn.Server(config) # start the server in a background task - asyncio.create_task(self.server.serve()) + if self.server is not None: + asyncio.create_task(self.server.serve()) logger.info("NIXL side channel server started successfully") async def stop_async(self): From e0e88b201e0b3d02db2e38394c2989a957e00a2e Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 15:46:48 -0400 Subject: [PATCH 22/33] add comments; retrigger fastcheck Signed-off-by: Will Eaton --- .../kv_connector/v1/nixl_connector.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index fe2c256d993..0dc3df31ecf 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -81,6 +81,12 @@ class ReqMeta: class HandshakeStrategy(ABC): + """ + Abstract base class for handshake strategies. + + This class is used to abstract the handshake process for different + communication protocols. + """ def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, side_channel_port: int, engine_id: str): @@ -105,6 +111,12 @@ def cleanup(self) -> None: class ZmqHandshakeStrategy(HandshakeStrategy): + """ + Handshake strategy that uses a ZMQ socket at port defined by + VLLM_NIXL_SIDE_CHANNEL_PORT + tp_rank for communication. + + This is the default handshake strategy for NIXL, and is P2P. + """ def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, side_channel_port: int, engine_id: str, @@ -214,6 +226,11 @@ def _zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: class HttpHandshakeStrategy(HandshakeStrategy): + """ + Handshake strategy that uses HTTP requests to fetch metadata from a + remote server. This is done through the front-end, and is + North-South, not P2P. + """ def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, side_channel_port: int, engine_id: str, From 34e3e555e66f95e8dccbba86362a3673f0bcc189 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:00:02 -0400 Subject: [PATCH 23/33] add type hint in test Signed-off-by: Will Eaton --- .../unit/test_nixl_handshake_strategies.py | 47 +------------------ 1 file changed, 2 insertions(+), 45 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py index 8ad39e4ad7e..c2a398625b6 100644 --- a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py +++ b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py @@ -3,6 +3,7 @@ import base64 import json +from typing import Any from unittest.mock import MagicMock, patch from urllib.error import URLError @@ -20,22 +21,6 @@ def test_abstract_base_class(self): with pytest.raises(TypeError): HandshakeStrategy(None, 0, 1, 8080, "test-engine") - def test_strategy_interface(self): - mock_nixl = MagicMock() - mock_add_agent = MagicMock() - - zmq_strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, - "test-engine", mock_add_agent) - assert hasattr(zmq_strategy, 'initiate_handshake') - assert hasattr(zmq_strategy, 'setup_listener') - assert hasattr(zmq_strategy, 'cleanup') - - http_strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, - "test-engine", mock_add_agent) - assert hasattr(http_strategy, 'initiate_handshake') - assert hasattr(http_strategy, 'setup_listener') - assert hasattr(http_strategy, 'cleanup') - class TestZmqHandshakeStrategy: @@ -278,9 +263,8 @@ def test_http_handshake_missing_rank(self, mock_urlopen): strategy = HttpHandshakeStrategy(mock_nixl, 1, 2, 8080, "decode-engine", mock_add_agent) - mock_response = MagicMock() - empty_response = {"0": {}} + empty_response: dict[str, dict[str, dict[str, Any]]] = {"0": {}} mock_response.read.return_value = json.dumps(empty_response).encode() mock_urlopen.return_value.__enter__.return_value = mock_response @@ -326,30 +310,3 @@ def test_zmq_strategy_selection(self): @patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'http') def test_http_strategy_selection(self): assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'http' - - def test_strategy_polymorphism(self): - mock_nixl = MagicMock() - mock_add_agent = MagicMock() - - strategies = [ - 
ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", - mock_add_agent), - HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", - mock_add_agent) - ] - - test_metadata = NixlAgentMetadata( - engine_id="test-engine", - agent_metadata=b"test-data", - kv_caches_base_addr=[12345], - num_blocks=100, - block_len=16, - attn_backend_name="FLASH_ATTN_VLLM_V1") - - for strategy in strategies: - assert callable(strategy.initiate_handshake) - assert callable(strategy.setup_listener) - assert callable(strategy.cleanup) - - strategy.setup_listener(test_metadata) - strategy.cleanup() From 23254707c2f97af1ead892f8e3751cef680b6839 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:12:34 -0400 Subject: [PATCH 24/33] revert api change Signed-off-by: Will Eaton --- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 9adbdc0de07..cff232f1b59 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -189,7 +189,7 @@ def clear_connector_metadata(self) -> None: """ self._connector_metadata = KVConnectorMetadata() - def get_connector_metadata(self) -> KVConnectorMetadata: + def _get_connector_metadata(self) -> KVConnectorMetadata: """Get the connector metadata. This function should only be called inside the connector. From 430ad03020113edc394b4dc3f381197785b70aa7 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:36:32 -0400 Subject: [PATCH 25/33] fix usage Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/base.py | 69 ++----------------- .../kv_connector/v1/nixl_connector.py | 59 +++++----------- 2 files changed, 23 insertions(+), 105 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index cff232f1b59..2cd54ba6a6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -22,7 +22,6 @@ import enum from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional import msgspec @@ -42,64 +41,6 @@ logger = init_logger(__name__) -@dataclass -class KVTransferFinishedResult: - """Result of KV transfer get_finished operation.""" - - finished_sending: set[str] - finished_recving: set[str] - pending_handshake: set[str] - - def has_any_finished(self) -> bool: - """Check if any requests finished or are pending.""" - return bool(self.finished_sending or self.finished_recving - or self.pending_handshake) - - def is_empty(self) -> bool: - """Check if all sets are empty.""" - return not self.has_any_finished() - - def get_all_finished_req_ids(self) -> set[str]: - """Get all request IDs that have finished (sending or receiving).""" - return self.finished_sending.union(self.finished_recving) - - def merge(self, - other: 'KVTransferFinishedResult') -> 'KVTransferFinishedResult': - """Merge with another result, combining all sets.""" - return KVTransferFinishedResult( - finished_sending=self.finished_sending.union( - other.finished_sending), - finished_recving=self.finished_recving.union( - other.finished_recving), - pending_handshake=self.pending_handshake.union( - other.pending_handshake)) - - @classmethod - def empty(cls) -> 'KVTransferFinishedResult': - """Create an empty result.""" - return cls(finished_sending=set(), - 
finished_recving=set(), - pending_handshake=set()) - - @classmethod - def from_tuple( - cls, result_tuple: tuple[set[str], set[str], set[str]] - ) -> 'KVTransferFinishedResult': - """Create from the old tuple format for backward compatibility.""" - finished_sending, finished_recving, pending_handshake = result_tuple - return cls(finished_sending=finished_sending, - finished_recving=finished_recving, - pending_handshake=pending_handshake) - - def to_tuple(self) -> tuple[set[str], set[str], set[str]]: - """Convert to the old tuple format for backward compatibility.""" - return ( - self.finished_sending, - self.finished_recving, - self.pending_handshake, - ) - - class KVConnectorRole(enum.Enum): # Connector running in the scheduler process SCHEDULER = 0 @@ -270,19 +211,19 @@ def wait_for_save(self): """ pass - def get_finished(self, - finished_req_ids: set[str]) -> KVTransferFinishedResult: + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have finished generating tokens. Returns: - KVTransferFinishedResult containing sets of finished sending, - finished receiving, and pending handshake request IDs. + Tuple of (finished_sending, finished_recving) request ID sets. The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ - return KVTransferFinishedResult.empty() + return None, None def get_pending_handshake_req_ids(self) -> Optional[set[str]]: """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 0dc3df31ecf..1b0b7bbaa9f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -27,7 +27,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata, - KVConnectorRole, KVTransferFinishedResult) + KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -399,20 +399,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if hasattr(self.connector_worker, 'xfer_metadata'): self.set_handshake_metadata(self.connector_worker.xfer_metadata) - def get_finished(self, - finished_req_ids: set[str]) -> KVTransferFinishedResult: + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() - def get_pending_handshake_req_ids(self) -> Optional[set[str]]: - """Get request IDs that are currently pending handshake completion.""" - if self.connector_worker is not None: - result = self.connector_worker.get_finished() - return (result.pending_handshake - if result.pending_handshake else None) - return None - def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None @@ -988,7 +981,7 @@ def add_remote_agent(self, return remote_agent_name - def get_finished(self) -> KVTransferFinishedResult: + def get_finished(self) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Get requests that are done sending, done recving, and pending handshake. 
@@ -1003,61 +996,45 @@ def get_finished(self) -> KVTransferFinishedResult: done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) - with self._handshake_lock: - pending_handshake = set() - for engine_id in self._handshake_futures: - pending_handshake.add(engine_id) - - local_result = KVTransferFinishedResult( - finished_sending=done_sending, - finished_recving=done_recving, - pending_handshake=pending_handshake) - if self.world_size == 1: - return local_result + return done_sending, done_recving - return self._coordinate_multi_rank_results(local_result) + return self._coordinate_multi_rank_results(done_sending, done_recving) def _coordinate_multi_rank_results( - self, local_result: KVTransferFinishedResult - ) -> KVTransferFinishedResult: + self, local_sending: set[str], local_recving: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Coordinate results across multiple TP ranks.""" if self.tp_rank == 0: # Rank 0 collects results from all other ranks. - for req_id in local_result.finished_sending: + for req_id in local_sending: self._done_sending_count[req_id] += 1 - for req_id in local_result.finished_recving: + for req_id in local_recving: self._done_recving_count[req_id] += 1 - all_pending_handshake = local_result.pending_handshake.copy() for i in range(1, self.world_size): rank_data = self.tp_group.recv_object(src=i) - other_rank_result = KVTransferFinishedResult.from_tuple( - rank_data) + other_sending, other_recving = rank_data - for req_id in other_rank_result.get_all_finished_req_ids(): + sending_set = other_sending or set() + recving_set = other_recving or set() + for req_id in sending_set | recving_set: if (req_id in self._done_recving_count or req_id in self._recving_transfers): self._done_recving_count[req_id] += 1 else: self._done_sending_count[req_id] += 1 - all_pending_handshake.update( - other_rank_result.pending_handshake) - all_done_recving = self._get_globally_finished_requests( self._done_recving_count) all_done_sending = self._get_globally_finished_requests( self._done_sending_count) - return KVTransferFinishedResult( - finished_sending=all_done_sending, - finished_recving=all_done_recving, - pending_handshake=all_pending_handshake) + return all_done_sending, all_done_recving else: - self.tp_group.send_object(local_result.to_tuple(), dst=0) - return local_result + self.tp_group.send_object((local_sending, local_recving), dst=0) + return local_sending, local_recving def _get_globally_finished_requests( self, counter_dict: dict[str, int]) -> set[str]: From 4f3096677797601d33b18c47ec34491a217fef08 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:44:08 -0400 Subject: [PATCH 26/33] remove dead code Signed-off-by: Will Eaton --- .../kv_connector/v1/multi_connector.py | 56 +++++++------------ 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 5ce2a6e1523..2e507d16517 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -10,8 +10,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, - KVTransferFinishedResult) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.logger import 
init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput @@ -103,49 +102,36 @@ def wait_for_save(self): for c in self._connectors: c.wait_for_save() - def get_finished(self, - finished_req_ids: set[str]) -> KVTransferFinishedResult: + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: finished_sending: set[str] = set() finished_recving: set[str] = set() - pending_handshake: set[str] = set() for c in self._connectors: - result = c.get_finished(finished_req_ids) - if result.is_empty(): - continue + sending, recving = c.get_finished(finished_req_ids) # Aggregate finished recving request ids. - finished_recving.update(result.finished_recving) - - # Aggregate pending handshake request ids. - pending_handshake.update(result.pending_handshake) + if recving: + finished_recving.update(recving) # Aggregate finished sending request ids - only include # once we've drained the "extra" count (for cases where # more than one connector is async-saving the same request). - for req_id in result.finished_sending: - extra_pending = self._extra_async_saves.get(req_id) - if extra_pending is None: - finished_sending.add(req_id) - continue - assert extra_pending > 0 - if extra_pending == 1: - del self._extra_async_saves[req_id] - else: - self._extra_async_saves[req_id] = extra_pending - 1 - - return KVTransferFinishedResult(finished_sending=finished_sending, - finished_recving=finished_recving, - pending_handshake=pending_handshake) - - def get_pending_handshake_req_ids(self) -> Optional[set[str]]: - """Get request IDs that are currently pending handshake completion.""" - pending_handshake: set[str] = set() - for c in self._connectors: - connector_pending = c.get_pending_handshake_req_ids() - if connector_pending: - pending_handshake.update(connector_pending) - return pending_handshake or None + if sending: + for req_id in sending: + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return (finished_sending if finished_sending else None, + finished_recving if finished_recving else None) # ============================== # Scheduler-side methods From 54e088ee3bfab8299f85bdd1b90197f56eff699e Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:53:14 -0400 Subject: [PATCH 27/33] fix type declaration Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 1b0b7bbaa9f..8087cd9b757 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -60,14 +60,14 @@ NixlWrapper = None -class NixlAgentMetadata(KVConnectorHandshakeMetadata, kw_only=True): - connector_type: str = "nixl" +class NixlAgentMetadata(KVConnectorHandshakeMetadata): engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int block_len: int attn_backend_name: str + connector_type: str = "nixl" @dataclass @@ -696,7 +696,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # Initialize handshake strategy handshake_method = 
envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() if handshake_method == "zmq": - self._handshake_strategy = ZmqHandshakeStrategy( + self._handshake_strategy: HandshakeStrategy = ZmqHandshakeStrategy( self.nixl_wrapper, self.tp_rank, self.world_size, self.side_channel_port, self.engine_id, self.add_remote_agent) elif handshake_method == "http": From ead79d647537cac3ae6eecb8e13b5cc6fbba2767 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Mon, 30 Jun 2025 16:56:17 -0400 Subject: [PATCH 28/33] revert spurious changes to multi-connector Signed-off-by: Will Eaton --- .../kv_connector/v1/multi_connector.py | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 2e507d16517..be3c2339941 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -107,31 +107,27 @@ def get_finished( ) -> tuple[Optional[set[str]], Optional[set[str]]]: finished_sending: set[str] = set() finished_recving: set[str] = set() - for c in self._connectors: sending, recving = c.get_finished(finished_req_ids) - + if not recving and not sending: + continue # Aggregate finished recving request ids. - if recving: - finished_recving.update(recving) - + finished_recving.update(recving or ()) # Aggregate finished sending request ids - only include # once we've drained the "extra" count (for cases where # more than one connector is async-saving the same request). - if sending: - for req_id in sending: - extra_pending = self._extra_async_saves.get(req_id) - if extra_pending is None: - finished_sending.add(req_id) - continue - assert extra_pending > 0 - if extra_pending == 1: - del self._extra_async_saves[req_id] - else: - self._extra_async_saves[req_id] = extra_pending - 1 - - return (finished_sending if finished_sending else None, - finished_recving if finished_recving else None) + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_sending or None, finished_recving or None # ============================== # Scheduler-side methods From 1a6ac2c6a5d06f10f01c78e2c1489eb9b323a506 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 1 Jul 2025 09:56:51 -0400 Subject: [PATCH 29/33] remove dead code for threading Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 4 ---- vllm/entrypoints/nixl_side_channel_server.py | 7 ------- 2 files changed, 11 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8087cd9b757..dd5b7bd8121 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -654,8 +654,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._done_sending_count: defaultdict[ReqId, int] = defaultdict(lambda: 0) - # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None # Background thread for initializing new NIXL handshakes. 
self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -713,8 +711,6 @@ def __del__(self): self._handshake_initiation_executor.shutdown(wait=False) if hasattr(self, '_handshake_strategy'): self._handshake_strategy.cleanup() - if self._nixl_handshake_listener_t: - self._nixl_handshake_listener_t.join(timeout=0) def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py index ecf0405f1ce..97935ce0c5d 100644 --- a/vllm/entrypoints/nixl_side_channel_server.py +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -105,13 +105,6 @@ def should_start_nixl_side_channel_server(vllm_config: VllmConfig) -> bool: async def start_nixl_side_channel_server_if_needed( vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]: if not should_start_nixl_side_channel_server(vllm_config): - if (vllm_config.kv_transfer_config is not None - and vllm_config.kv_transfer_config.kv_connector - == "NixlConnector"): - handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() - logger.info( - "Skipping NIXL HTTP side channel server (handshake method: %s)", - handshake_method) return None side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST From f940b6d5c86ff4c77a8114471f96375b2408fdb3 Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 1 Jul 2025 14:13:48 -0400 Subject: [PATCH 30/33] remove dead engine code Signed-off-by: Will Eaton --- vllm/v1/core/sched/scheduler.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d41c783c4d1..00b0844a566 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -741,8 +741,6 @@ def update_from_output( new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None - num_requests_to_reschedule = 0 - num_tokens_to_reschedule = 0 # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid @@ -750,7 +748,6 @@ def update_from_output( for request in self.running: req_id = request.request_id num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: # The request was not scheduled in this step. new_running.append(request) @@ -874,12 +871,6 @@ def update_from_output( if not stopped: new_running.append(request) - if num_requests_to_reschedule: - logger.info( - "Recovered from handshake failure: " - "%d request(s) rescheduled (%d tokens affected).", - num_requests_to_reschedule, num_tokens_to_reschedule) - # KV Connector: update state for finished KV Transfers. 
self._update_from_kv_xfer_finished(model_runner_output) From 7768f2184632358a689f2e1c3ff3db2be53859ec Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Tue, 1 Jul 2025 14:18:12 -0400 Subject: [PATCH 31/33] fix optional ordering Signed-off-by: Will Eaton --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index dd5b7bd8121..ac1519d7c8b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -61,13 +61,13 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): + connector_type: str = "nixl" engine_id: str agent_metadata: bytes kv_caches_base_addr: list[int] num_blocks: int block_len: int attn_backend_name: str - connector_type: str = "nixl" @dataclass From 35684226ad5c3060cdd56e15fb6bd341ff73c58a Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 9 Jul 2025 10:01:07 -0400 Subject: [PATCH 32/33] naming, fix engine default param issue; pr feedback Signed-off-by: Will Eaton --- .../kv_transfer/kv_connector/v1/base.py | 20 ++++++++++++++++++- .../kv_connector/v1/nixl_connector.py | 18 ++++++++--------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2cd54ba6a6d..eda9a0eba40 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -7,9 +7,15 @@ Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. get_num_new_matched_tokens() - get number of new tokens - that exist in the remote KV cache + that exist in the remote KV cache. Might be called multiple + times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. + request_finished() - called when a request is finished, with + the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or will be + freed asynchronously and optionally returns KV transfer + params. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -18,6 +24,8 @@ save_kv_layer() - starts saving KV for layer i (maybe async) wait_for_save() - blocks until all saves are done + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. """ import enum @@ -277,6 +285,16 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens: int): """ Update KVConnector state after block allocation. + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + Args: + request (Request): the request object. + blocks (KVCacheBlocks): the blocks allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. 
""" pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ac1519d7c8b..8d0619dfaa1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -12,7 +12,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional from urllib.error import HTTPError, URLError from urllib.request import Request as URLRequest @@ -61,13 +61,13 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): + engine_id: str = field() + agent_metadata: bytes = field() + kv_caches_base_addr: list[int] = field() + num_blocks: int = field() + block_len: int = field() + attn_backend_name: str = field() connector_type: str = "nixl" - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: list[int] - num_blocks: int - block_len: int - attn_backend_name: str @dataclass @@ -125,7 +125,7 @@ def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, engine_id) self.add_remote_agent_func = add_remote_agent_func self._listener_thread: Optional[threading.Thread] = None - self._tp_size_mapping: dict[str, int] = {engine_id: tp_size} + self._tp_size: dict[str, int] = {engine_id: tp_size} def initiate_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: @@ -159,7 +159,7 @@ def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: # Handshake only with the other TP remote the current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size + tp_ratio = self._tp_size[self.engine_id] // remote_tp_size p_remote_rank = self.tp_rank // tp_ratio if p_remote_rank > 0: path = make_zmq_path("tcp", host, port + p_remote_rank) From 332723af05207c676270b9dc3b6f0ff1769c64be Mon Sep 17 00:00:00 2001 From: Will Eaton Date: Wed, 9 Jul 2025 11:35:05 -0400 Subject: [PATCH 33/33] use built in util Signed-off-by: Will Eaton --- vllm/utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index bec8bf91ce5..307876107e9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2947,17 +2947,11 @@ def build_uri(scheme: str, Returns: Complete URI string """ - # Handle IPv6 addresses - if host: - try: - # Check if it's an IPv6 address - ip = ipaddress.ip_address(host) - # Ensure IPv6 addresses are bracketed - if (isinstance(ip, ipaddress.IPv6Address) - and not (host.startswith('[') and host.endswith(']'))): - host = f'[{host}]' - except ValueError: - pass + + # Ensure IPv6 addresses are bracketed + if (is_valid_ipv6_address(host) + and not (host.startswith('[') and host.endswith(']'))): + host = f'[{host}]' netloc = f"{host}:{port}" if port else host return urlunparse((scheme, netloc, path, params, query, fragment))