Skip to content

Draft: WIP NixlConnector allow configurable handshake backend +HTTP #19447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
efbd791
initial noodling
wseaton Jun 2, 2025
9dd2c1c
attempt to background agent registration
wseaton Jun 10, 2025
a0c03cd
more simple immediate retry
wseaton Jun 11, 2025
cf616b5
fix bad merge
wseaton Jun 11, 2025
505f586
implement nicks suggestions
wseaton Jun 16, 2025
a1a0918
change retry logic and move to scheduler; simply background handshake…
wseaton Jun 16, 2025
518a59f
fixup by changing types at callsite
wseaton Jun 16, 2025
6426e3f
fixup popping requests
wseaton Jun 16, 2025
29885af
debug logging
wseaton Jun 16, 2025
504bba7
working checkpoint
wseaton Jun 17, 2025
dd49c96
fix unreachable
wseaton Jun 17, 2025
f65c1b3
remove uv.lock
wseaton Jun 17, 2025
efd655b
remove unused
wseaton Jun 17, 2025
dba3835
flip protocol; fix scheduling order bug
wseaton Jun 18, 2025
85855d1
fix bug in case of no kvconnectorgroup
wseaton Jun 18, 2025
1fc1af4
actually use handshake timeout; simplify route
wseaton Jun 18, 2025
9d8c15c
revert back to working het TP logic
wseaton Jun 19, 2025
fbf0630
Merge branch 'main' into kv-xfer-updates
wseaton Jun 26, 2025
69af91c
Resolve merge conflicts and remove deprecated pending handshake logic
wseaton Jun 26, 2025
91af1cd
Merge remote-tracking branch 'origin/main' into kv-xfer-updates
wseaton Jun 27, 2025
83ec83a
move nixl sidechannel to own entrypoint
wseaton Jun 27, 2025
9a86b37
allow configurable handshake strategy
wseaton Jun 30, 2025
cc887f5
pre-commit fixes
wseaton Jun 30, 2025
b960355
satisfy mypy
wseaton Jun 30, 2025
e0e88b2
add comments; retrigger fastcheck
wseaton Jun 30, 2025
34e3e55
add type hint in test
wseaton Jun 30, 2025
2325470
revert api change
wseaton Jun 30, 2025
430ad03
fix usage
wseaton Jun 30, 2025
4f30966
remove dead code
wseaton Jun 30, 2025
54e088e
fix type declaration
wseaton Jun 30, 2025
ead79d6
revert spurious changes to multi-connector
wseaton Jun 30, 2025
1a6ac2c
remove dead code for threading
wseaton Jul 1, 2025
f940b6d
remove dead engine code
wseaton Jul 1, 2025
7768f21
fix optional ordering
wseaton Jul 1, 2025
3568422
naming, fix engine default param issue; pr feedback
wseaton Jul 9, 2025
332723a
use built in util
wseaton Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 312 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_handshake_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import json
from typing import Any
from unittest.mock import MagicMock, patch
from urllib.error import URLError

import pytest

from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
HandshakeStrategy, HttpHandshakeStrategy, NixlAgentMetadata,
ZmqHandshakeStrategy)


class TestHandshakeStrategyAbstraction:
    """Checks on the ``HandshakeStrategy`` base class itself."""

    def test_abstract_base_class(self):
        """Instantiating the base class directly must be rejected.

        The error message is matched so the test only passes for the
        abstract-class instantiation error. A bare ``TypeError`` check would
        also be satisfied by an unrelated ``TypeError`` (for example a
        constructor-argument mismatch), letting the test pass for the wrong
        reason.
        """
        with pytest.raises(TypeError, match="abstract"):
            HandshakeStrategy(None, 0, 1, 8080, "test-engine")


class TestZmqHandshakeStrategy:
    """Unit tests for ``ZmqHandshakeStrategy``.

    The ZMQ context, the path helper, the msgpack decoder and the threading
    primitives are replaced with mocks in the tests that touch them.

    NOTE(review): ``make_zmq_path`` is patched on ``vllm.utils``. If the
    connector module binds it with ``from vllm.utils import make_zmq_path``
    this patch would not affect the name the connector uses -- confirm.
    """

    @staticmethod
    def _build_strategy(tp_rank: int, tp_size: int, add_agent: MagicMock):
        """Create a strategy on port 8080 for engine ``"test-engine"``.

        The NIXL wrapper argument is a fresh ``MagicMock``.
        """
        return ZmqHandshakeStrategy(MagicMock(), tp_rank, tp_size, 8080,
                                    "test-engine", add_agent)

    def create_test_metadata(self) -> NixlAgentMetadata:
        """Return a fixed ``NixlAgentMetadata`` instance used as test input."""
        fields = {
            "engine_id": "test-engine",
            "agent_metadata": b"test-agent-data",
            "kv_caches_base_addr": [12345],
            "num_blocks": 100,
            "block_len": 16,
            "attn_backend_name": "FLASH_ATTN_VLLM_V1",
        }
        return NixlAgentMetadata(**fields)

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx'
    )
    @patch('vllm.utils.make_zmq_path')
    def test_zmq_handshake_success(self, mock_make_path, mock_zmq_ctx):
        """A single-rank handshake registers one agent and uses the socket."""
        add_agent = MagicMock(return_value="agent-name-0")
        strategy = self._build_strategy(0, 1, add_agent)

        sock = MagicMock()
        mock_zmq_ctx.return_value.__enter__.return_value = sock
        mock_make_path.return_value = "tcp://localhost:8080"

        metadata = self.create_test_metadata()
        with patch('msgspec.msgpack.Decoder') as decoder_cls:
            # Every decode() call on the patched decoder yields the metadata.
            decoder_cls.return_value.decode.return_value = metadata
            agents = strategy.initiate_handshake("localhost", 8080, 1)

        assert agents == {0: "agent-name-0"}
        add_agent.assert_called_once()
        sock.send.assert_called()
        sock.recv.assert_called()

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx'
    )
    @patch('vllm.utils.make_zmq_path')
    def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx):
        """With a remote TP size of 2, both remote ranks get registered."""
        add_agent = MagicMock(side_effect=["agent-0", "agent-1"])
        strategy = self._build_strategy(1, 2, add_agent)

        mock_zmq_ctx.return_value.__enter__.return_value = MagicMock()
        mock_make_path.side_effect = [
            "tcp://localhost:8080", "tcp://localhost:8081"
        ]

        metadata = self.create_test_metadata()
        with patch('msgspec.msgpack.Decoder') as decoder_cls:
            decoder_cls.return_value.decode.return_value = metadata
            agents = strategy.initiate_handshake("localhost", 8080, 2)

        assert agents == {0: "agent-0", 1: "agent-1"}
        assert add_agent.call_count == 2

    @patch('threading.Thread')
    def test_setup_listener(self, mock_thread):
        """``setup_listener`` creates and starts one thread, then waits on
        an event."""
        strategy = self._build_strategy(0, 1, MagicMock())
        thread_instance = mock_thread.return_value
        metadata = self.create_test_metadata()

        with patch('threading.Event') as event_cls:
            event = event_cls.return_value
            strategy.setup_listener(metadata)

        mock_thread.assert_called_once()
        thread_instance.start.assert_called_once()
        event.wait.assert_called_once()

    def test_cleanup(self):
        """``cleanup`` joins the listener thread with ``timeout=0``."""
        strategy = self._build_strategy(0, 1, MagicMock())

        listener = MagicMock()
        strategy._listener_thread = listener

        strategy.cleanup()

        listener.join.assert_called_once_with(timeout=0)


class TestHttpHandshakeStrategy:
    """Unit tests for ``HttpHandshakeStrategy`` with ``urlopen`` mocked.

    Response payloads are nested dicts keyed by string rank indices; the
    missing-rank test refers to the outer key as the dp_rank.
    """

    @staticmethod
    def _rank_entry(engine_id: str, agent_bytes: bytes, base_addr: int,
                    num_blocks: int) -> dict[str, Any]:
        """Build the JSON-serialisable metadata entry for one rank.

        ``agent_bytes`` is base64-encoded because JSON cannot carry raw
        bytes.
        """
        return {
            "engine_id": engine_id,
            "agent_metadata": base64.b64encode(agent_bytes).decode(),
            "kv_caches_base_addr": [base_addr],
            "num_blocks": num_blocks,
            "block_len": 128,
            "attn_backend_name": "FLASH_ATTN_VLLM_V1",
        }

    @staticmethod
    def _stub_response(mock_urlopen: MagicMock, payload: Any) -> None:
        """Make the patched ``urlopen`` yield ``payload`` as encoded JSON."""
        response = MagicMock()
        response.read.return_value = json.dumps(payload).encode()
        mock_urlopen.return_value.__enter__.return_value = response

    @staticmethod
    def _build_strategy(tp_rank: int,
                        tp_size: int,
                        add_agent: MagicMock,
                        engine_id: str = "test-engine"):
        """Create a strategy on port 8080 with a mocked NIXL wrapper."""
        return HttpHandshakeStrategy(MagicMock(), tp_rank, tp_size, 8080,
                                     engine_id, add_agent)

    def create_test_metadata_response(self) -> dict:
        """Return a response payload holding one entry under ``"0"``/``"0"``."""
        entry = self._rank_entry("3871ab24-6b5a-4ea5-a614-5381594bcdde",
                                 b"nixl-prefill-agent-data", 0x7f8b2c000000,
                                 1000)
        return {"0": {"0": entry}}

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
    def test_http_handshake_success(self, mock_urlopen):
        """A single-rank handshake decodes the payload and registers it."""
        add_agent = MagicMock(return_value="remote-agent-0")
        strategy = self._build_strategy(0, 1, add_agent)
        self._stub_response(mock_urlopen,
                            self.create_test_metadata_response())

        agents = strategy.initiate_handshake("localhost", 8080, 1)

        assert agents == {0: "remote-agent-0"}
        add_agent.assert_called_once()

        # First positional argument handed to the registration callback.
        sent_metadata = add_agent.call_args[0][0]
        assert isinstance(sent_metadata, NixlAgentMetadata)
        assert sent_metadata.engine_id == (
            "3871ab24-6b5a-4ea5-a614-5381594bcdde")
        assert sent_metadata.agent_metadata == b"nixl-prefill-agent-data"

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
    def test_http_handshake_multi_rank(self, mock_urlopen):
        """Local rank 1 of 2 registers the remote entry keyed ``"1"``."""
        add_agent = MagicMock(return_value="remote-agent-1")
        strategy = self._build_strategy(1, 2, add_agent)

        engine_id = "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d"
        payload = {
            "0": {
                "0":
                self._rank_entry(engine_id, b"decode-agent-0-data",
                                 0x7f8b2c000000, 800),
                "1":
                self._rank_entry(engine_id, b"decode-agent-1-data",
                                 0x7f8b2d000000, 800),
            }
        }
        self._stub_response(mock_urlopen, payload)

        agents = strategy.initiate_handshake("localhost", 8080, 2)

        assert agents == {1: "remote-agent-1"}

        sent_metadata = add_agent.call_args[0][0]
        assert sent_metadata.agent_metadata == b"decode-agent-1-data"

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
    def test_http_handshake_url_error(self, mock_urlopen):
        """A ``URLError`` raised by ``urlopen`` reaches the caller."""
        strategy = self._build_strategy(0, 1, MagicMock())
        mock_urlopen.side_effect = URLError("Connection failed")

        with pytest.raises(URLError):
            strategy.initiate_handshake("localhost", 8080, 1)

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
    def test_http_handshake_none_response(self, mock_urlopen):
        """A JSON ``null`` body is reported as a ``RuntimeError``."""
        strategy = self._build_strategy(0, 1, MagicMock())
        self._stub_response(mock_urlopen, None)

        with pytest.raises(RuntimeError,
                           match="Remote server returned None metadata"):
            strategy.initiate_handshake("localhost", 8080, 1)

    @patch(
        'vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
    def test_http_handshake_missing_rank(self, mock_urlopen):
        """An empty mapping under ``"0"`` is reported as a ``RuntimeError``."""
        strategy = self._build_strategy(1,
                                        2,
                                        MagicMock(),
                                        engine_id="decode-engine")
        empty_response: dict[str, dict[str, dict[str, Any]]] = {"0": {}}
        self._stub_response(mock_urlopen, empty_response)

        with pytest.raises(RuntimeError,
                           match="No metadata found for dp_rank 0"):
            strategy.initiate_handshake("localhost", 8080, 1)

    def test_setup_listener_noop(self):
        """``setup_listener`` can be called with metadata without raising."""
        strategy = self._build_strategy(0, 1, MagicMock())

        metadata = NixlAgentMetadata(engine_id="test-engine",
                                     agent_metadata=b"test-data",
                                     kv_caches_base_addr=[12345],
                                     num_blocks=100,
                                     block_len=16,
                                     attn_backend_name="FLASH_ATTN_VLLM_V1")

        strategy.setup_listener(metadata)

    def test_cleanup_noop(self):
        """``cleanup`` can be called on a fresh strategy without raising."""
        strategy = self._build_strategy(0, 1, MagicMock())

        strategy.cleanup()


class TestHandshakeStrategyIntegration:
    """Checks that ``VLLM_NIXL_HANDSHAKE_METHOD`` reads back as configured.

    NOTE(review): each test patches the setting and then asserts the patched
    value; no strategy object is constructed here, so the connector's own
    selection logic is not exercised by these tests.
    """

    def test_zmq_strategy_selection(self):
        """The ``envs`` attribute reads ``'zmq'`` while patched to ``'zmq'``."""
        # The environment patch is entered first and the attribute patch
        # second, so the attribute patch is active inside the innermost block.
        with patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'zmq'}):
            with patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'zmq'):
                assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'zmq'

    def test_http_strategy_selection(self):
        """The ``envs`` attribute reads ``'http'`` while patched to ``'http'``."""
        with patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'http'}):
            with patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'http'):
                assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'http'
11 changes: 11 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
Expand Down Expand Up @@ -1538,6 +1540,10 @@ class CacheConfig:
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""

transfer_handshake_metadata: Optional[dict[int, dict[
int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False)
"""Metadata for KV connector handshake. Structure: dp_rank -> tp_rank"""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down Expand Up @@ -4628,6 +4634,11 @@ def __post_init__(self):
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

if (self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance):
from collections import defaultdict
self.cache_config.transfer_handshake_metadata = defaultdict(dict)

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
Expand Down
Loading