diff --git a/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py new file mode 100644 index 00000000000..c2a398625b6 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import base64 +import json +from typing import Any +from unittest.mock import MagicMock, patch +from urllib.error import URLError + +import pytest + +from vllm import envs +from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( + HandshakeStrategy, HttpHandshakeStrategy, NixlAgentMetadata, + ZmqHandshakeStrategy) + + +class TestHandshakeStrategyAbstraction: + + def test_abstract_base_class(self): + with pytest.raises(TypeError): + HandshakeStrategy(None, 0, 1, 8080, "test-engine") + + +class TestZmqHandshakeStrategy: + + def create_test_metadata(self) -> NixlAgentMetadata: + return NixlAgentMetadata(engine_id="test-engine", + agent_metadata=b"test-agent-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1") + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' + ) + @patch('vllm.utils.make_zmq_path') + def test_zmq_handshake_success(self, mock_make_path, mock_zmq_ctx): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="agent-name-0") + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + mock_socket = MagicMock() + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + mock_make_path.return_value = "tcp://localhost:8080" + + test_metadata = self.create_test_metadata() + with patch('msgspec.msgpack.Decoder') as mock_decoder_class: + mock_decoder = MagicMock() + mock_decoder_class.return_value = mock_decoder + mock_decoder.decode.return_value = test_metadata + + result = strategy.initiate_handshake("localhost", 8080, 1) + + assert result == {0: "agent-name-0"} + mock_add_agent.assert_called_once() + mock_socket.send.assert_called() + mock_socket.recv.assert_called() + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' + ) + @patch('vllm.utils.make_zmq_path') + def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(side_effect=["agent-0", "agent-1"]) + + strategy = ZmqHandshakeStrategy(mock_nixl, 1, 2, 8080, "test-engine", + mock_add_agent) + + mock_socket = MagicMock() + mock_zmq_ctx.return_value.__enter__.return_value = mock_socket + mock_make_path.side_effect = [ + "tcp://localhost:8080", "tcp://localhost:8081" + ] + + test_metadata = self.create_test_metadata() + with patch('msgspec.msgpack.Decoder') as mock_decoder_class: + mock_decoder = MagicMock() + mock_decoder_class.return_value = mock_decoder + mock_decoder.decode.return_value = test_metadata + + result = strategy.initiate_handshake("localhost", 8080, 2) + + assert result == {0: "agent-0", 1: "agent-1"} + assert mock_add_agent.call_count == 2 + + @patch('threading.Thread') + def test_setup_listener(self, mock_thread): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + mock_thread_instance = MagicMock() + mock_thread.return_value = mock_thread_instance + + test_metadata = self.create_test_metadata() + + with 
patch('threading.Event') as mock_event_class: + mock_event = MagicMock() + mock_event_class.return_value = mock_event + + strategy.setup_listener(test_metadata) + + mock_thread.assert_called_once() + mock_thread_instance.start.assert_called_once() + mock_event.wait.assert_called_once() + + def test_cleanup(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + mock_thread = MagicMock() + strategy._listener_thread = mock_thread + + strategy.cleanup() + + mock_thread.join.assert_called_once_with(timeout=0) + + +class TestHttpHandshakeStrategy: + + def create_test_metadata_response(self) -> dict: + return { + "0": { + "0": { + "engine_id": + "3871ab24-6b5a-4ea5-a614-5381594bcdde", + "agent_metadata": + base64.b64encode(b"nixl-prefill-agent-data").decode(), + "kv_caches_base_addr": [0x7f8b2c000000], + "num_blocks": + 1000, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" + } + } + } + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_success(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="remote-agent-0") + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps( + self.create_test_metadata_response()).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = strategy.initiate_handshake("localhost", 8080, 1) + + assert result == {0: "remote-agent-0"} + mock_add_agent.assert_called_once() + + call_args = mock_add_agent.call_args + metadata = call_args[0][0] + assert isinstance(metadata, NixlAgentMetadata) + assert metadata.engine_id == "3871ab24-6b5a-4ea5-a614-5381594bcdde" + assert metadata.agent_metadata == b"nixl-prefill-agent-data" + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_multi_rank(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock(return_value="remote-agent-1") + + strategy = HttpHandshakeStrategy(mock_nixl, 1, 2, 8080, "test-engine", + mock_add_agent) + + response_data = { + "0": { + "0": { + "engine_id": + "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": + base64.b64encode(b"decode-agent-0-data").decode(), + "kv_caches_base_addr": [0x7f8b2c000000], + "num_blocks": + 800, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" + }, + "1": { + "engine_id": + "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d", + "agent_metadata": + base64.b64encode(b"decode-agent-1-data").decode(), + "kv_caches_base_addr": [0x7f8b2d000000], + "num_blocks": + 800, + "block_len": + 128, + "attn_backend_name": + "FLASH_ATTN_VLLM_V1" + } + } + } + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(response_data).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + result = strategy.initiate_handshake("localhost", 8080, 2) + + assert result == {1: "remote-agent-1"} + + call_args = mock_add_agent.call_args + metadata = call_args[0][0] + assert metadata.agent_metadata == b"decode-agent-1-data" + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_url_error(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + 
mock_urlopen.side_effect = URLError("Connection failed") + + with pytest.raises(URLError): + strategy.initiate_handshake("localhost", 8080, 1) + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_none_response(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + mock_response = MagicMock() + mock_response.read.return_value = json.dumps(None).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + with pytest.raises(RuntimeError, + match="Remote server returned None metadata"): + strategy.initiate_handshake("localhost", 8080, 1) + + @patch( + 'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen') + def test_http_handshake_missing_rank(self, mock_urlopen): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy(mock_nixl, 1, 2, 8080, + "decode-engine", mock_add_agent) + mock_response = MagicMock() + empty_response: dict[str, dict[str, dict[str, Any]]] = {"0": {}} + mock_response.read.return_value = json.dumps(empty_response).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + with pytest.raises(RuntimeError, + match="No metadata found for dp_rank 0"): + strategy.initiate_handshake("localhost", 8080, 1) + + def test_setup_listener_noop(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + test_metadata = NixlAgentMetadata( + engine_id="test-engine", + agent_metadata=b"test-data", + kv_caches_base_addr=[12345], + num_blocks=100, + block_len=16, + attn_backend_name="FLASH_ATTN_VLLM_V1") + + strategy.setup_listener(test_metadata) + + def test_cleanup_noop(self): + mock_nixl = MagicMock() + mock_add_agent = MagicMock() + + strategy = HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test-engine", + mock_add_agent) + + strategy.cleanup() + + +class TestHandshakeStrategyIntegration: + + @patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'zmq'}) + @patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'zmq') + def test_zmq_strategy_selection(self): + assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'zmq' + + @patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'http'}) + @patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'http') + def test_http_strategy_selection(self): + assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'http' diff --git a/vllm/config.py b/vllm/config.py index 623ba3aaf10..cafc7d930f2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -32,6 +32,8 @@ import vllm.envs as envs from vllm import version from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.transformers_utils.config import ( @@ -1538,6 +1540,10 @@ class CacheConfig: num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" + transfer_handshake_metadata: Optional[dict[int, dict[ + int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False) + """Metadata for KV connector handshake. 
Structure: dp_rank -> tp_rank""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -4628,6 +4634,11 @@ def __post_init__(self): # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True + if (self.kv_transfer_config is not None + and self.kv_transfer_config.is_kv_transfer_instance): + from collections import defaultdict + self.cache_config.transfer_handshake_metadata = defaultdict(dict) + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f80b5eba235..eda9a0eba40 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State communication in vLLM v1 @@ -25,16 +24,17 @@ save_kv_layer() - starts saving KV for layer i (maybe async) wait_for_save() - blocks until all saves are done - get_finished() - called with ids of finished requests, returns ids of requests that have completed async sending/recving. """ import enum from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional +import msgspec import torch +from pydantic_core import core_schema from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -65,6 +65,39 @@ class KVConnectorMetadata: pass +class KVConnectorHandshakeMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): + """ + Metadata optionally used for out of band connector handshake between + P/D workers. + """ + connector_type: str = "base" + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Callable[[Any], + core_schema.CoreSchema] + ) -> core_schema.CoreSchema: + """bridge msgspec.Struct with pydantic for schema generation""" + return core_schema.no_info_after_validator_function( + cls, core_schema.dict_schema()) + + +class KVConnectorTransferMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + dict=True): + """ + Wrapper for transfer handshake metadata sent between engine and utils. + """ + tensor_parallel_rank: int + data_parallel_rank: int + content: Optional[dict] + + class KVConnectorBase_V1(ABC): def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): @@ -74,6 +107,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._connector_metadata = KVConnectorMetadata() self._vllm_config = vllm_config self._role = role + self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None @property def role(self) -> KVConnectorRole: @@ -193,14 +227,30 @@ def get_finished( finished generating tokens. Returns: - ids of requests that have finished asynchronous transfer - (requests that previously returned True from request_finished()), - tuple of (sending/saving ids, recving/loading ids). + Tuple of (finished_sending, finished_recving) request ID sets. The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). 
""" return None, None + def get_pending_handshake_req_ids(self) -> Optional[set[str]]: + """ + Get request IDs that are currently pending handshake completion. + + Returns: + Set of request IDs waiting for handshake, or None if not applicable. + """ + return None + + def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]: + """ + Get the handshake metadata for the connector. + + Returns: + KVConnectorHandshakeMetadata: the handshake metadata. + """ + return self._handshake_metadata + # ============================== # Scheduler-side methods # ============================== @@ -225,8 +275,7 @@ def get_num_new_matched_tokens( - The number of tokens that can be loaded from the external KV cache beyond what is already computed. - `True` if external KV cache tokens will be loaded - asynchronously (between scheduler steps). Must be - 'False' if the first element is 0. + asynchronously (between scheduler steps). """ pass @@ -236,13 +285,11 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens: int): """ Update KVConnector state after block allocation. - If get_num_new_matched_tokens previously returned True for a request, this function may be called twice for that same request - first when blocks are allocated for the connector tokens to be asynchronously loaded into, and second when any additional blocks are allocated, after the load/transfer is complete. - Args: request (Request): the request object. blocks (KVCacheBlocks): the blocks allocated for the request. diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7a077dce770..8d0619dfaa1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1,16 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 import contextlib +import json import math import queue import threading import time import uuid +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional +from urllib.error import HTTPError, URLError +from urllib.request import Request as URLRequest +from urllib.request import urlopen import msgspec import torch @@ -20,7 +26,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata, + KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -28,7 +35,7 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import _Backend -from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.utils import build_uri, make_zmq_path, make_zmq_socket, round_down from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -53,17 +60,14 @@ NixlWrapper = None -class NixlAgentMetadata( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. 
- dict=True): - engine_id: str - agent_metadata: bytes - kv_caches_base_addr: list[int] - num_blocks: int - block_len: int - attn_backend_name: str +class NixlAgentMetadata(KVConnectorHandshakeMetadata): + engine_id: str = field() + agent_metadata: bytes = field() + kv_caches_base_addr: list[int] = field() + num_blocks: int = field() + block_len: int = field() + attn_backend_name: str = field() + connector_type: str = "nixl" @dataclass @@ -76,6 +80,241 @@ class ReqMeta: tp_size: int +class HandshakeStrategy(ABC): + """ + Abstract base class for handshake strategies. + + This class is used to abstract the handshake process for different + communication protocols. + """ + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str): + self.nixl_wrapper = nixl_wrapper + self.tp_rank = tp_rank + self.tp_size = tp_size + self.side_channel_port = side_channel_port + self.engine_id = engine_id + + @abstractmethod + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: + pass + + @abstractmethod + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + pass + + @abstractmethod + def cleanup(self) -> None: + pass + + +class ZmqHandshakeStrategy(HandshakeStrategy): + """ + Handshake strategy that uses a ZMQ socket at port defined by + VLLM_NIXL_SIDE_CHANNEL_PORT + tp_rank for communication. + + This is the default handshake strategy for NIXL, and is P2P. + """ + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str, + add_remote_agent_func): + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, + engine_id) + self.add_remote_agent_func = add_remote_agent_func + self._listener_thread: Optional[threading.Thread] = None + self._tp_size: dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: + start_time = time.perf_counter() + + def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]: + with self._zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + + # Register Remote agent + agent_name = self.add_remote_agent_func( + metadata, rank, remote_tp_size) + setup_agent_time = time.perf_counter() + + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + return metadata, agent_name + + # Handshake with remote agent-rank0 first to get the tp_size of remote + path = make_zmq_path("tcp", host, port) + logger.debug("Querying master rank metadata on path: %s", path) + metadata, agent_name_0 = handshake(path, 0) + + agents = {0: agent_name_0} + + # Handshake only with the other TP remote the current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. 
+ tp_ratio = self._tp_size[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + if p_remote_rank > 0: + path = make_zmq_path("tcp", host, port + p_remote_rank) + logger.debug("Querying metadata on path: %s at remote rank %s", + path, p_remote_rank) + _, agent_name = handshake(path, p_remote_rank) + agents[p_remote_rank] = agent_name + + return agents + + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + ready_event = threading.Event() + self._listener_thread = threading.Thread( + target=self._nixl_handshake_listener, + args=(metadata, ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="nixl_handshake_listener") + self._listener_thread.start() + ready_event.wait() + + def cleanup(self) -> None: + if self._listener_thread: + self._listener_thread.join(timeout=0) + + @staticmethod + def _nixl_handshake_listener(metadata: NixlAgentMetadata, + ready_event: threading.Event, base_port: int, + tp_rank: int): + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded NixlAgentMetadata: %s bytes", + size_in_bytes) + + # Listen for new requests for metadata + host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + path = make_zmq_path("tcp", host, base_port + tp_rank) + logger.debug("Starting listening on path: %s", path) + with ZmqHandshakeStrategy._zmq_ctx(zmq.ROUTER, path) as sock: + ready_event.set() + while True: + identity, _, msg = sock.recv_multipart() + if msg != GET_META_MSG: + logger.warning( + "Connection listener got unexpected message %s", msg) + sock.send_multipart((identity, b"", encoded_data)) + + @staticmethod + @contextlib.contextmanager + def _zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None + try: + ctx = zmq.Context() + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +class HttpHandshakeStrategy(HandshakeStrategy): + """ + Handshake strategy that uses HTTP requests to fetch metadata from a + remote server. This is done through the front-end, and is + North-South, not P2P. 
+ """ + + def __init__(self, nixl_wrapper, tp_rank: int, tp_size: int, + side_channel_port: int, engine_id: str, + add_remote_agent_func): + super().__init__(nixl_wrapper, tp_rank, tp_size, side_channel_port, + engine_id) + self.add_remote_agent_func = add_remote_agent_func + self._tp_size_mapping: dict[str, int] = {engine_id: tp_size} + + def initiate_handshake(self, host: str, port: int, + remote_tp_size: int) -> dict[int, str]: + start_time = time.perf_counter() + logger.debug("Starting NIXL handshake with %s:%s", host, port) + + url = build_uri("http", host, port, path="get_kv_connector_metadata") + + try: + req = URLRequest(url) + with urlopen(req, + timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response: + response_data = response.read().decode('utf-8') + res = json.loads(response_data) + except (URLError, HTTPError) as e: + logger.error("Failed to fetch metadata from %s: %s", url, e) + raise + + if res is None: + logger.warning( + "Remote server returned None metadata, skipping handshake") + raise RuntimeError("Remote server returned None metadata") + + # Get dp_rank 0 data (standard for disaggregated prefill-decode) + dp_data = res.get("0", {}) + if not dp_data: + raise RuntimeError("No metadata found for dp_rank 0") + + remote_tp_size = len(dp_data.keys()) + + # Handshake only with the remote TP rank that current local rank will + # pull from. With homogeneous TP it happens to be the same rank_i. + tp_ratio = self._tp_size_mapping[self.engine_id] // remote_tp_size + p_remote_rank = self.tp_rank // tp_ratio + + # Get data for the specific rank we need to connect to + rank_data = dp_data.get(str(p_remote_rank), {}) + if not rank_data: + raise RuntimeError( + f"No metadata found for remote rank {p_remote_rank}") + + metadata_bytes = rank_data.get("agent_metadata", None) + if metadata_bytes is None: + raise RuntimeError( + f"No agent metadata found for remote rank {p_remote_rank}") + + rank_data_copy = rank_data.copy() + rank_data_copy.pop("agent_metadata", None) + metadata = NixlAgentMetadata( + agent_metadata=base64.b64decode(metadata_bytes), **rank_data_copy) + + pre_register = time.perf_counter() + # Register Remote agent + remote_agent_name = self.add_remote_agent_func(metadata, p_remote_rank, + remote_tp_size) + agent_time = time.perf_counter() + + logger.debug("Finished registering remote agent for engine %s", + metadata.engine_id) + logger.debug("NIXL handshake: get metadata took: %s", + pre_register - start_time) + logger.debug("NIXL handshake: add agent took: %s", + agent_time - pre_register) + + logger.debug("NIXL handshake method completed for %s:%s", host, port) + + # Return remote rank -> agent name mapping + return {p_remote_rank: remote_agent_name} + + def setup_listener(self, metadata: NixlAgentMetadata) -> None: + pass + + def cleanup(self) -> None: + pass + + class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): @@ -101,6 +340,8 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config, role) + assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id @@ -154,8 +395,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + # Set handshake 
metadata using the base class method + if hasattr(self.connector_worker, 'xfer_metadata'): + self.set_handshake_metadata(self.connector_worker.xfer_metadata) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() @@ -179,6 +425,12 @@ def wait_for_save(self): """NixlConnector does not save explicitly.""" pass + def set_handshake_metadata(self, handshake_metadata): + logger.debug("Setting handshake metadata for NIXL connector: %s", + handshake_metadata) + assert self.connector_worker is not None + self._handshake_metadata = handshake_metadata + class NixlConnectorScheduler: """Implementation of Scheduler side methods""" @@ -362,7 +614,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): envs.VLLM_NIXL_SIDE_CHANNEL_PORT + vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.tensor_parallel_size) - # Metadata. self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() @@ -403,8 +654,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self._done_sending_count: defaultdict[ReqId, int] = defaultdict(lambda: 0) - # Background thread for handling new handshake requests. - self._nixl_handshake_listener_t: Optional[threading.Thread] = None # Background thread for initializing new NIXL handshakes. self._handshake_initiation_executor = ThreadPoolExecutor( # NIXL is not guaranteed to be thread-safe, limit 1 worker. @@ -442,78 +691,31 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # finish reading before safely freeing the blocks. self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) - def __del__(self): - """Cleanup background threads on destruction.""" - self._handshake_initiation_executor.shutdown(wait=False) - if self._nixl_handshake_listener_t: - self._nixl_handshake_listener_t.join(timeout=0) - - @staticmethod - def _nixl_handshake_listener(metadata: NixlAgentMetadata, - ready_event: threading.Event, base_port: int, - tp_rank: int): - """Background thread for getting new NIXL handshakes.""" - # NOTE(rob): this is a simple implementation. We will move - # to a better approach via HTTP endpoint soon. + # Initialize handshake strategy + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() + if handshake_method == "zmq": + self._handshake_strategy: HandshakeStrategy = ZmqHandshakeStrategy( + self.nixl_wrapper, self.tp_rank, self.world_size, + self.side_channel_port, self.engine_id, self.add_remote_agent) + elif handshake_method == "http": + self._handshake_strategy = HttpHandshakeStrategy( + self.nixl_wrapper, self.tp_rank, self.world_size, + self.side_channel_port, self.engine_id, self.add_remote_agent) + else: + raise ValueError(f"Unknown handshake method: {handshake_method}. " + "Supported methods: 'zmq', 'http'") - encoder = msgspec.msgpack.Encoder() - encoded_data = encoder.encode(metadata) - size_in_bytes = len(encoded_data) - logger.debug("Size of encoded NixlAgentMetadata: %s bytes", - str(size_in_bytes)) + logger.info("Using %s handshake strategy", handshake_method) - # Listen for new requests for metadata. 
- host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST - path = make_zmq_path("tcp", host, base_port + tp_rank) - logger.debug("Starting listening on path: %s", path) - with zmq_ctx(zmq.ROUTER, path) as sock: - ready_event.set() - while True: - identity, _, msg = sock.recv_multipart() - if msg != GET_META_MSG: - logger.warning( - "Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data)) + def __del__(self): + self._handshake_initiation_executor.shutdown(wait=False) + if hasattr(self, '_handshake_strategy'): + self._handshake_strategy.cleanup() def _nixl_handshake(self, host: str, port: int, remote_tp_size: int) -> dict[int, str]: - """Do a NIXL handshake with a remote instance.""" - - start_time = time.perf_counter() - - # NOTE(rob): we need each rank to have a unique port. This is - # a hack to keep us moving. We will switch when moving to etcd - # or where we have a single ZMQ socket in the scheduler. - - def handshake(path: str, rank: int) -> str: - # Send query for the request. - with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. - remote_agent_name = self.add_remote_agent( - metadata, rank, remote_tp_size) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) - return remote_agent_name - - # Handshake only with the remote TP rank that current local rank will - # pull from. With homogeneous TP it happens to be the same rank_i. - tp_ratio = self._tp_size[self.engine_id] // remote_tp_size - p_remote_rank = self.tp_rank // tp_ratio - path = make_zmq_path("tcp", host, port + p_remote_rank) - logger.debug("Querying metadata on path: %s at remote rank %s", path, - p_remote_rank) - # Remote rank -> agent name. - return {p_remote_rank: handshake(path, p_remote_rank)} + return self._handshake_strategy.initiate_handshake( + host, port, remote_tp_size) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data in nixl.""" @@ -591,6 +793,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): # Optimization for models with local attention (Llama 4) if self.vllm_config.model_config.hf_config.model_type == "llama4": from transformers import Llama4TextConfig + assert isinstance(self.vllm_config.model_config.hf_text_config, Llama4TextConfig) llama4_config = self.vllm_config.model_config.hf_text_config @@ -634,22 +837,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", descs) - # After KV Caches registered, listen for new connections. 
- metadata = NixlAgentMetadata( + # Store metadata on worker instance for main connector access + self.xfer_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, block_len=self.block_len, attn_backend_name=self.backend_name) - ready_event = threading.Event() - self._nixl_handshake_listener_t = threading.Thread( - target=self._nixl_handshake_listener, - args=(metadata, ready_event, self.side_channel_port, self.tp_rank), - daemon=True, - name="nixl_handshake_listener") - self._nixl_handshake_listener_t.start() - ready_event.wait() # Wait for listener ZMQ socket to be ready. + + # Setup handshake strategy listener + self._handshake_strategy.setup_listener(self.xfer_metadata) def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata, @@ -779,9 +977,9 @@ def add_remote_agent(self, return remote_agent_name - def get_finished(self) -> tuple[set[str], set[str]]: + def get_finished(self) -> tuple[Optional[set[str]], Optional[set[str]]]: """ - Get requests that are done sending or recving. + Get requests that are done sending, done recving, and pending handshake. In TP>1 setup, each rank exchanges KVs with its counterpart ranks independently. get_finished() runs in a worker creates @@ -793,56 +991,56 @@ def get_finished(self) -> tuple[set[str], set[str]]: """ done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) - if len(done_sending) > 0 or len(done_recving) > 0: - logger.debug( - "Rank %s, get_finished: %s requests done sending " - "and %s requests done recving", self.tp_rank, - len(done_sending), len(done_recving)) if self.world_size == 1: return done_sending, done_recving - # Rank 0: get finished from all other ranks. + return self._coordinate_multi_rank_results(done_sending, done_recving) + + def _coordinate_multi_rank_results( + self, local_sending: set[str], local_recving: set[str] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + """Coordinate results across multiple TP ranks.""" + if self.tp_rank == 0: - for req_id in done_sending: + # Rank 0 collects results from all other ranks. + for req_id in local_sending: self._done_sending_count[req_id] += 1 - for req_id in done_recving: + for req_id in local_recving: self._done_recving_count[req_id] += 1 - # Keep track of how many other ranks have finished. - other_ranks_finished_ids: list[str] = [] for i in range(1, self.world_size): - other_ranks_finished_ids.extend( - self.tp_group.recv_object(src=i)) - for req_id in other_ranks_finished_ids: - if (req_id in self._done_recving_count - or req_id in self._recving_transfers): - self._done_recving_count[req_id] += 1 - else: - self._done_sending_count[req_id] += 1 - - # Return ids that finished on all ranks to the scheduler. 
- all_done_recving: set[str] = set() - for req_id in list(self._done_recving_count.keys()): - if self._done_recving_count[req_id] == self.world_size: - del self._done_recving_count[req_id] - all_done_recving.add(req_id) - - all_done_sending: set[str] = set() - for req_id in list(self._done_sending_count.keys()): - if self._done_sending_count[req_id] == self.world_size: - del self._done_sending_count[req_id] - all_done_sending.add(req_id) + rank_data = self.tp_group.recv_object(src=i) + other_sending, other_recving = rank_data + + sending_set = other_sending or set() + recving_set = other_recving or set() + for req_id in sending_set | recving_set: + if (req_id in self._done_recving_count + or req_id in self._recving_transfers): + self._done_recving_count[req_id] += 1 + else: + self._done_sending_count[req_id] += 1 + + all_done_recving = self._get_globally_finished_requests( + self._done_recving_count) + all_done_sending = self._get_globally_finished_requests( + self._done_sending_count) return all_done_sending, all_done_recving - - # Ranks 1 to N-1: send finished ids to Rank 0. else: - finished_req_ids = list(done_recving.union(done_sending)) - self.tp_group.send_object(finished_req_ids, dst=0) - - # Unused as only Rank 0 results are sent to scheduler. - return done_sending, done_recving + self.tp_group.send_object((local_sending, local_recving), dst=0) + return local_sending, local_recving + + def _get_globally_finished_requests( + self, counter_dict: dict[str, int]) -> set[str]: + """Get request IDs that have finished on all ranks.""" + finished_req_ids = set() + for req_id in list(counter_dict.keys()): + if counter_dict[req_id] == self.world_size: + del counter_dict[req_id] + finished_req_ids.add(req_id) + return finished_req_ids def _get_new_notifs(self) -> set[str]: """ @@ -894,6 +1092,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). 
""" + for req_id, meta in metadata.requests.items(): remote_engine_id = meta.remote_engine_id logger.debug( @@ -1078,22 +1277,3 @@ def _get_block_descs_ids(self, for block_id in block_ids: descs_ids.append(reg_id * num_blocks + block_id) return descs_ids - - -@contextlib.contextmanager -def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: - """Context manager for a ZMQ socket""" - - if socket_type not in (zmq.ROUTER, zmq.REQ): - raise ValueError(f"Unexpected socket type: {socket_type}") - - ctx: Optional[zmq.Context] = None - try: - ctx = zmq.Context() # type: ignore[attr-defined] - yield make_zmq_socket(ctx=ctx, - path=addr, - socket_type=socket_type, - bind=socket_type == zmq.ROUTER) - finally: - if ctx is not None: - ctx.destroy(linger=0) diff --git a/vllm/entrypoints/nixl_side_channel_server.py b/vllm/entrypoints/nixl_side_channel_server.py new file mode 100644 index 00000000000..97935ce0c5d --- /dev/null +++ b/vllm/entrypoints/nixl_side_channel_server.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +from typing import Optional + +import uvicorn +from fastapi import FastAPI + +from vllm import envs +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorHandshakeMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class NixlSideChannelServer: + + def __init__(self, vllm_config: VllmConfig, host: str, port: int): + self.vllm_config = vllm_config + self.host = host + self.port = port + self.app = FastAPI(title="vLLM NIXL Side Channel Server") + self.server = None + self.server_thread = None + self._setup_routes() + + def _setup_routes(self): + + @self.app.get("/get_kv_connector_metadata") + @self.app.get("/get_kv_connector_metadata/{dp_rank}") + @self.app.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}") + async def get_kv_connector_metadata(dp_rank: Optional[int] = None, + tp_rank: Optional[int] = None): + kv_meta: Optional[dict[int, dict[ + int, KVConnectorHandshakeMetadata]]] = ( + self.vllm_config.cache_config.transfer_handshake_metadata) + + if kv_meta is None: + return None + + if dp_rank is not None: + if dp_rank not in kv_meta: + return {} + dp_data = kv_meta[dp_rank] + + if tp_rank is not None: + if tp_rank not in dp_data: + return {} + return {dp_rank: {tp_rank: dp_data[tp_rank]}} + else: + return {dp_rank: dp_data} + + return kv_meta + + async def start_async(self): + if self.server is not None: + logger.warning("Side channel server is already running") + return + + logger.info("Starting NIXL side channel server on %s:%s", self.host, + self.port) + + # use uvicorn directly to avoid dependency on engine_client + config = uvicorn.Config( + app=self.app, + host=self.host, + port=self.port, + log_level="info", + access_log=True, + ) + self.server = uvicorn.Server(config) + + # start the server in a background task + if self.server is not None: + asyncio.create_task(self.server.serve()) + logger.info("NIXL side channel server started successfully") + + async def stop_async(self): + if self.server is not None: + logger.info("Stopping NIXL side channel server") + try: + self.server.should_exit = True + await asyncio.sleep(1) # give it time to shutdown + except Exception as e: + logger.warning("Error during side channel server shutdown: %s", + e) + self.server = None + logger.info("NIXL side channel server stopped") + + +def should_start_nixl_side_channel_server(vllm_config: 
VllmConfig) -> bool: + if vllm_config.kv_transfer_config is None: + return False + + if vllm_config.kv_transfer_config.kv_connector != "NixlConnector": + return False + + handshake_method = envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() + return handshake_method == "http" + + +async def start_nixl_side_channel_server_if_needed( + vllm_config: VllmConfig) -> Optional[NixlSideChannelServer]: + if not should_start_nixl_side_channel_server(vllm_config): + return None + + side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST + side_channel_port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + + logger.info("Starting NIXL side channel metadata server on %s:%d", + side_channel_host, side_channel_port) + + server = NixlSideChannelServer(vllm_config, side_channel_host, + side_channel_port) + await server.start_async() + return server diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index f3fd1548627..cc40e1f727d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -47,6 +47,8 @@ resolve_mistral_chat_template) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.nixl_side_channel_server import ( + start_nixl_side_channel_server_if_needed) from vllm.entrypoints.openai.cli_args import (log_non_default_args, make_arg_parser, validate_parsed_serve_args) @@ -1447,6 +1449,13 @@ async def run_server_worker(listen_address, vllm_config = await engine_client.get_vllm_config() await init_app_state(engine_client, vllm_config, app.state, args) + nixl_side_channel_server = None + try: + nixl_side_channel_server = await \ + start_nixl_side_channel_server_if_needed(vllm_config) + except Exception as e: + logger.warning("Failed to start NIXL side channel server: %s", e) + logger.info("Starting vLLM API server %d on %s", server_index, listen_address) shutdown_task = await serve_http( @@ -1471,6 +1480,12 @@ async def run_server_worker(listen_address, try: await shutdown_task finally: + if nixl_side_channel_server is not None: + try: + await nixl_side_channel_server.stop_async() + except Exception as e: + logger.warning("Error stopping NIXL side channel server: %s", + e) sock.close() diff --git a/vllm/envs.py b/vllm/envs.py index a3f19c7ee5c..d6e9b9bee96 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -126,6 +126,8 @@ VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557 + VLLM_NIXL_HANDSHAKE_TIMEOUT: float = 2.0 + VLLM_NIXL_HANDSHAKE_METHOD: str = "zmq" VLLM_ALL2ALL_BACKEND: str = "naive" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 @@ -899,6 +901,15 @@ def get_vllm_port() -> Optional[int]: "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), + # Timeout in seconds for NIXL HTTP handshake requests. 
+ # Default is 2 seconds + "VLLM_NIXL_HANDSHAKE_TIMEOUT": + lambda: float(os.getenv("VLLM_NIXL_HANDSHAKE_TIMEOUT", "2.0")), + + # NIXL handshake method ("zmq" or "http") + "VLLM_NIXL_HANDSHAKE_METHOD": + lambda: os.getenv("VLLM_NIXL_HANDSHAKE_METHOD", "zmq"), + # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using all-reduce diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7ebeb4a2255..94d1f597bb2 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -57,6 +57,10 @@ def collective_rpc(self, answer = run_method(self.driver_worker, method, args, kwargs) return [answer] + def get_kv_connector_handshake_metadata(self) -> List[Optional[Dict]]: + """Get KV connector handshake metadata from all workers.""" + return self.collective_rpc("get_kv_connector_handshake_metadata") + def check_health(self) -> None: # UniProcExecutor will always be healthy as long as # it's running. diff --git a/vllm/utils.py b/vllm/utils.py index fdefda901c4..307876107e9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -47,7 +47,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, overload) -from urllib.parse import urlparse +from urllib.parse import urlparse, urlunparse from uuid import uuid4 import cachetools @@ -2925,6 +2925,38 @@ def is_torch_equal_or_newer(target: str) -> bool: return Version(importlib.metadata.version('torch')) >= Version(target) +def build_uri(scheme: str, + host: str, + port: Optional[int] = None, + path: str = "", + params: str = "", + query: str = "", + fragment: str = "") -> str: + """ + Robustly build a URI that properly handles IPv6 addresses. + + Args: + scheme: URI scheme (e.g., 'http', 'https') + host: hostname or IP address + port: port number (optional) + path: path component + params: parameters component + query: query string + fragment: fragment identifier + + Returns: + Complete URI string + """ + + # Ensure IPv6 addresses are bracketed + if (is_valid_ipv6_address(host) + and not (host.startswith('[') and host.endswith(']'))): + host = f'[{host}]' + + netloc = f"{host}:{port}" if port else host + return urlunparse((scheme, netloc, path, params, query, fragment)) + + # Helper function used in testing. def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: torch_version = version.parse(torch_version) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 453ed364dc8..f5ad95cac22 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -78,14 +78,18 @@ def __init__(self, executor_fail_callback) # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ - self._initialize_kv_caches(vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config, \ + transfer_handshake_metadata = self._initialize_kv_caches( + vllm_config) vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) + # Store KV connector metadata for handshake + self.transfer_handshake_metadata = transfer_handshake_metadata + self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. 
@@ -131,7 +135,8 @@ def __init__(self, self.batch_queue = queue.Queue(self.batch_queue_size) def _initialize_kv_caches( - self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: + self, vllm_config: VllmConfig + ) -> tuple[int, int, KVCacheConfig, Optional[list[Optional[dict]]]]: start = time.time() # Get all kv cache needed by the model @@ -168,10 +173,16 @@ def _initialize_kv_caches( # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) + # Collect KV connector xfer metadata from workers + # (after KV cache registration) + transfer_handshake_metadata = ( + self.model_executor.get_kv_connector_handshake_metadata()) + elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) - return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config + return (num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config, + transfer_handshake_metadata) def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" @@ -436,12 +447,29 @@ def _perform_handshake( # Send ready message. num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks - handshake_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": on_head_node, - "num_gpu_blocks": num_gpu_blocks, - })) + handshake_message = { + "status": "READY", + "local": on_head_node, + "num_gpu_blocks": num_gpu_blocks, + } + + # Include KV connector metadata if available + if hasattr(self, 'transfer_handshake_metadata' + ) and self.transfer_handshake_metadata: + # self.transfer_handshake_metadata is list of dicts from workers + # Each dict already has structure {dp_rank: {tp_rank: metadata}} + # Merge all worker dicts into a single dict + content: dict[str, dict[str, dict[str, Any]]] = {} + for worker_dict in self.transfer_handshake_metadata: + if worker_dict is not None: + # Deep merge nested dictionaries instead of overwrite + for dp_rank, tp_dict in worker_dict.items(): + if dp_rank not in content: + content[dp_rank] = {} + content[dp_rank].update(tp_dict) + handshake_message["transfer_handshake_metadata"] = content + + handshake_socket.send(msgspec.msgpack.encode(handshake_message)) @staticmethod def startup_handshake( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49e..135c03532a0 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -80,6 +80,11 @@ def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: output = self.collective_rpc("get_kv_cache_spec") return output + def get_kv_connector_handshake_metadata( + self) -> list[Optional[dict[int, dict[int, dict]]]]: + output = self.collective_rpc("get_kv_connector_handshake_metadata") + return output + def execute_model( self, scheduler_output, diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 192c9067740..534f479bf89 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -536,6 +536,18 @@ def wait_for_engine_startup( num_gpu_blocks = cache_config.num_gpu_blocks or 0 num_gpu_blocks += msg["num_gpu_blocks"] cache_config.num_gpu_blocks = num_gpu_blocks + # stash KV connector metadata in vllm_config if passed in. 
+ if txfer_metadata := msg.get("transfer_handshake_metadata"): + logger.debug( + "Received transfer handshake metadata from engine %s: %s", + eng_index, txfer_metadata) + if cache_config.transfer_handshake_metadata is None: + cache_config.transfer_handshake_metadata = defaultdict( + dict) + for dp_rank, tp_dict in txfer_metadata.items(): + for tp_rank, metadata in tp_dict.items(): + cache_config.transfer_handshake_metadata[dp_rank][ + tp_rank] = metadata start_pending[0 if local else 1] -= 1 engine.state = CoreEngineState.READY diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3c9de572040..3eae0022536 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1679,8 +1679,11 @@ def get_finished_kv_transfers( scheduler_output: "SchedulerOutput", ) -> tuple[Optional[set[str]], Optional[set[str]]]: if has_kv_transfer_group(): - return get_kv_transfer_group().get_finished( + result = get_kv_transfer_group().get_finished( scheduler_output.finished_req_ids) + return ( + result.finished_sending if result.finished_sending else None, + result.finished_recving if result.finished_recving else None) return None, None def generate_draft_token_ids( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 9e7e44d0686..aa2e190ad39 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -5,6 +5,7 @@ import os from typing import TYPE_CHECKING, Optional +import msgspec import torch import torch.distributed import torch.nn as nn @@ -15,7 +16,10 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + get_kv_transfer_group, + has_kv_transfer_group, + is_v1_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -235,6 +239,30 @@ def determine_available_memory(self) -> int: return int(available_kv_cache_memory) + def get_kv_connector_handshake_metadata(self) -> Optional[dict]: + """Get KV connector metadata from this worker if available.""" + + if not has_kv_transfer_group(): + return None + + connector = get_kv_transfer_group() + if not is_v1_kv_transfer_group(connector): + logger.warning("The KV connector is not a v1 connector. " + "This method is only supported for v1 connectors.") + return None + + metadata = connector.get_handshake_metadata() + if metadata is None: + logger.warning( + "KV connector metadata is not available. " + "This may happen if the KV connector is not initialized " + "or the worker is not part of a disaggregated KV cache setup.") + return None + + tp_rank = get_tp_group().rank_in_group + dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + return {dp_rank: {tp_rank: msgspec.to_builtins(metadata)}} + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec()
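---

For readers trying the new HTTP handshake path end to end, the sketch below shows one way to query the side-channel endpoint added by this patch (/get_kv_connector_metadata) by hand and decode the per-rank NIXL agent metadata. It is illustrative only, not part of the patch: it assumes a prefill instance launched with VLLM_NIXL_HANDSHAKE_METHOD=http so that NixlSideChannelServer is running, and it uses the documented defaults for VLLM_NIXL_SIDE_CHANNEL_HOST/VLLM_NIXL_SIDE_CHANNEL_PORT, which may differ in a real deployment.

    # Minimal sketch, assuming a prefill server started with
    # VLLM_NIXL_HANDSHAKE_METHOD=http; host/port are the defaults from envs.py
    # and are placeholders for a real deployment.
    import base64
    import json
    from urllib.request import urlopen

    host, port = "localhost", 5557  # VLLM_NIXL_SIDE_CHANNEL_HOST / _PORT defaults
    with urlopen(f"http://{host}:{port}/get_kv_connector_metadata",
                 timeout=2.0) as resp:  # VLLM_NIXL_HANDSHAKE_TIMEOUT default
        payload = json.loads(resp.read().decode("utf-8"))

    # The payload mirrors CacheConfig.transfer_handshake_metadata:
    # dp_rank -> tp_rank -> NixlAgentMetadata fields, with agent_metadata
    # base64-encoded for JSON transport.
    rank_meta = payload["0"]["0"]
    agent_blob = base64.b64decode(rank_meta["agent_metadata"])
    print(rank_meta["engine_id"], rank_meta["num_blocks"], len(agent_blob))

This is essentially what HttpHandshakeStrategy.initiate_handshake does internally before passing the decoded NixlAgentMetadata to add_remote_agent; an empty dict for a given dp_rank/tp_rank simply means that worker has not registered its KV caches yet, matching the fallback behavior in nixl_side_channel_server.py.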