Skip to content

Commit 9a86b37

Browse files
committed
allow configurable handshake strategy
1 parent 83ec83a commit 9a86b37

File tree

4 files changed

+575
-68
lines changed

4 files changed

+575
-68
lines changed
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import base64
5+
import json
6+
from unittest.mock import MagicMock, patch
7+
from urllib.error import URLError
8+
9+
import pytest
10+
11+
from vllm import envs
12+
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
13+
HandshakeStrategy, HttpHandshakeStrategy, NixlAgentMetadata,
14+
ZmqHandshakeStrategy)
15+
16+
17+
class TestHandshakeStrategyAbstraction:
18+
19+
def test_abstract_base_class(self):
20+
with pytest.raises(TypeError):
21+
HandshakeStrategy(None, 0, 1, 8080, "test-engine")
22+
23+
def test_strategy_interface(self):
24+
mock_nixl = MagicMock()
25+
mock_add_agent = MagicMock()
26+
27+
zmq_strategy = ZmqHandshakeStrategy(
28+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
29+
assert hasattr(zmq_strategy, 'initiate_handshake')
30+
assert hasattr(zmq_strategy, 'setup_listener')
31+
assert hasattr(zmq_strategy, 'cleanup')
32+
33+
http_strategy = HttpHandshakeStrategy(
34+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
35+
assert hasattr(http_strategy, 'initiate_handshake')
36+
assert hasattr(http_strategy, 'setup_listener')
37+
assert hasattr(http_strategy, 'cleanup')
38+
39+
40+
class TestZmqHandshakeStrategy:
41+
42+
def create_test_metadata(self) -> NixlAgentMetadata:
43+
return NixlAgentMetadata(
44+
engine_id="test-engine",
45+
agent_metadata=b"test-agent-data",
46+
kv_caches_base_addr=[12345],
47+
num_blocks=100,
48+
block_len=16,
49+
attn_backend_name="FLASH_ATTN_VLLM_V1"
50+
)
51+
52+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx')
53+
@patch('vllm.utils.make_zmq_path')
54+
def test_zmq_handshake_success(self, mock_make_path, mock_zmq_ctx):
55+
mock_nixl = MagicMock()
56+
mock_add_agent = MagicMock(return_value="agent-name-0")
57+
58+
strategy = ZmqHandshakeStrategy(
59+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
60+
61+
mock_socket = MagicMock()
62+
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
63+
mock_make_path.return_value = "tcp://localhost:8080"
64+
65+
test_metadata = self.create_test_metadata()
66+
with patch('msgspec.msgpack.Decoder') as mock_decoder_class:
67+
mock_decoder = MagicMock()
68+
mock_decoder_class.return_value = mock_decoder
69+
mock_decoder.decode.return_value = test_metadata
70+
71+
result = strategy.initiate_handshake("localhost", 8080, 1)
72+
73+
assert result == {0: "agent-name-0"}
74+
mock_add_agent.assert_called_once()
75+
mock_socket.send.assert_called()
76+
mock_socket.recv.assert_called()
77+
78+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx')
79+
@patch('vllm.utils.make_zmq_path')
80+
def test_zmq_handshake_multi_rank(self, mock_make_path, mock_zmq_ctx):
81+
mock_nixl = MagicMock()
82+
mock_add_agent = MagicMock(side_effect=["agent-0", "agent-1"])
83+
84+
strategy = ZmqHandshakeStrategy(
85+
mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent)
86+
87+
mock_socket = MagicMock()
88+
mock_zmq_ctx.return_value.__enter__.return_value = mock_socket
89+
mock_make_path.side_effect = ["tcp://localhost:8080", "tcp://localhost:8081"]
90+
91+
test_metadata = self.create_test_metadata()
92+
with patch('msgspec.msgpack.Decoder') as mock_decoder_class:
93+
mock_decoder = MagicMock()
94+
mock_decoder_class.return_value = mock_decoder
95+
mock_decoder.decode.return_value = test_metadata
96+
97+
result = strategy.initiate_handshake("localhost", 8080, 2)
98+
99+
assert result == {0: "agent-0", 1: "agent-1"}
100+
assert mock_add_agent.call_count == 2
101+
102+
@patch('threading.Thread')
103+
def test_setup_listener(self, mock_thread):
104+
mock_nixl = MagicMock()
105+
mock_add_agent = MagicMock()
106+
107+
strategy = ZmqHandshakeStrategy(
108+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
109+
110+
mock_thread_instance = MagicMock()
111+
mock_thread.return_value = mock_thread_instance
112+
113+
test_metadata = self.create_test_metadata()
114+
115+
with patch('threading.Event') as mock_event_class:
116+
mock_event = MagicMock()
117+
mock_event_class.return_value = mock_event
118+
119+
strategy.setup_listener(test_metadata)
120+
121+
mock_thread.assert_called_once()
122+
mock_thread_instance.start.assert_called_once()
123+
mock_event.wait.assert_called_once()
124+
125+
def test_cleanup(self):
126+
mock_nixl = MagicMock()
127+
mock_add_agent = MagicMock()
128+
129+
strategy = ZmqHandshakeStrategy(
130+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
131+
132+
mock_thread = MagicMock()
133+
strategy._listener_thread = mock_thread
134+
135+
strategy.cleanup()
136+
137+
mock_thread.join.assert_called_once_with(timeout=0)
138+
139+
140+
class TestHttpHandshakeStrategy:
141+
142+
def create_test_metadata_response(self) -> dict:
143+
return {
144+
"0": {
145+
"0": {
146+
"engine_id": "3871ab24-6b5a-4ea5-a614-5381594bcdde",
147+
"agent_metadata": base64.b64encode(b"nixl-prefill-agent-data").decode(),
148+
"kv_caches_base_addr": [0x7f8b2c000000],
149+
"num_blocks": 1000,
150+
"block_len": 128,
151+
"attn_backend_name": "FLASH_ATTN_VLLM_V1"
152+
}
153+
}
154+
}
155+
156+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
157+
def test_http_handshake_success(self, mock_urlopen):
158+
mock_nixl = MagicMock()
159+
mock_add_agent = MagicMock(return_value="remote-agent-0")
160+
161+
strategy = HttpHandshakeStrategy(
162+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
163+
164+
mock_response = MagicMock()
165+
mock_response.read.return_value = json.dumps(
166+
self.create_test_metadata_response()).encode()
167+
mock_urlopen.return_value.__enter__.return_value = mock_response
168+
169+
result = strategy.initiate_handshake("localhost", 8080, 1)
170+
171+
assert result == {0: "remote-agent-0"}
172+
mock_add_agent.assert_called_once()
173+
174+
call_args = mock_add_agent.call_args
175+
metadata = call_args[0][0]
176+
assert isinstance(metadata, NixlAgentMetadata)
177+
assert metadata.engine_id == "3871ab24-6b5a-4ea5-a614-5381594bcdde"
178+
assert metadata.agent_metadata == b"nixl-prefill-agent-data"
179+
180+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
181+
def test_http_handshake_multi_rank(self, mock_urlopen):
182+
mock_nixl = MagicMock()
183+
mock_add_agent = MagicMock(return_value="remote-agent-1")
184+
185+
strategy = HttpHandshakeStrategy(
186+
mock_nixl, 1, 2, 8080, "test-engine", mock_add_agent)
187+
188+
response_data = {
189+
"0": {
190+
"0": {
191+
"engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d",
192+
"agent_metadata": base64.b64encode(b"decode-agent-0-data").decode(),
193+
"kv_caches_base_addr": [0x7f8b2c000000],
194+
"num_blocks": 800,
195+
"block_len": 128,
196+
"attn_backend_name": "FLASH_ATTN_VLLM_V1"
197+
},
198+
"1": {
199+
"engine_id": "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d",
200+
"agent_metadata": base64.b64encode(b"decode-agent-1-data").decode(),
201+
"kv_caches_base_addr": [0x7f8b2d000000],
202+
"num_blocks": 800,
203+
"block_len": 128,
204+
"attn_backend_name": "FLASH_ATTN_VLLM_V1"
205+
}
206+
}
207+
}
208+
209+
mock_response = MagicMock()
210+
mock_response.read.return_value = json.dumps(response_data).encode()
211+
mock_urlopen.return_value.__enter__.return_value = mock_response
212+
213+
result = strategy.initiate_handshake("localhost", 8080, 2)
214+
215+
assert result == {1: "remote-agent-1"}
216+
217+
call_args = mock_add_agent.call_args
218+
metadata = call_args[0][0]
219+
assert metadata.agent_metadata == b"decode-agent-1-data"
220+
221+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
222+
def test_http_handshake_url_error(self, mock_urlopen):
223+
mock_nixl = MagicMock()
224+
mock_add_agent = MagicMock()
225+
226+
strategy = HttpHandshakeStrategy(
227+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
228+
229+
mock_urlopen.side_effect = URLError("Connection failed")
230+
231+
with pytest.raises(URLError):
232+
strategy.initiate_handshake("localhost", 8080, 1)
233+
234+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
235+
def test_http_handshake_none_response(self, mock_urlopen):
236+
mock_nixl = MagicMock()
237+
mock_add_agent = MagicMock()
238+
239+
strategy = HttpHandshakeStrategy(
240+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
241+
242+
mock_response = MagicMock()
243+
mock_response.read.return_value = json.dumps(None).encode()
244+
mock_urlopen.return_value.__enter__.return_value = mock_response
245+
246+
with pytest.raises(RuntimeError, match="Remote server returned None metadata"):
247+
strategy.initiate_handshake("localhost", 8080, 1)
248+
249+
@patch('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen')
250+
def test_http_handshake_missing_rank(self, mock_urlopen):
251+
mock_nixl = MagicMock()
252+
mock_add_agent = MagicMock()
253+
254+
strategy = HttpHandshakeStrategy(
255+
mock_nixl, 1, 2, 8080, "decode-engine", mock_add_agent)
256+
257+
mock_response = MagicMock()
258+
empty_response = {"0": {}}
259+
mock_response.read.return_value = json.dumps(empty_response).encode()
260+
mock_urlopen.return_value.__enter__.return_value = mock_response
261+
262+
with pytest.raises(RuntimeError, match="No metadata found for dp_rank 0"):
263+
strategy.initiate_handshake("localhost", 8080, 1)
264+
265+
def test_setup_listener_noop(self):
266+
mock_nixl = MagicMock()
267+
mock_add_agent = MagicMock()
268+
269+
strategy = HttpHandshakeStrategy(
270+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
271+
272+
test_metadata = NixlAgentMetadata(
273+
engine_id="test-engine",
274+
agent_metadata=b"test-data",
275+
kv_caches_base_addr=[12345],
276+
num_blocks=100,
277+
block_len=16,
278+
attn_backend_name="FLASH_ATTN_VLLM_V1"
279+
)
280+
281+
strategy.setup_listener(test_metadata)
282+
283+
def test_cleanup_noop(self):
284+
mock_nixl = MagicMock()
285+
mock_add_agent = MagicMock()
286+
287+
strategy = HttpHandshakeStrategy(
288+
mock_nixl, 0, 1, 8080, "test-engine", mock_add_agent)
289+
290+
strategy.cleanup()
291+
292+
293+
class TestHandshakeStrategyIntegration:
294+
295+
@patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'zmq'})
296+
@patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'zmq')
297+
def test_zmq_strategy_selection(self):
298+
assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'zmq'
299+
300+
@patch.dict('os.environ', {'VLLM_NIXL_HANDSHAKE_METHOD': 'http'})
301+
@patch('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD', 'http')
302+
def test_http_strategy_selection(self):
303+
assert envs.VLLM_NIXL_HANDSHAKE_METHOD.lower() == 'http'
304+
305+
def test_strategy_polymorphism(self):
306+
mock_nixl = MagicMock()
307+
mock_add_agent = MagicMock()
308+
309+
strategies = [
310+
ZmqHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent),
311+
HttpHandshakeStrategy(mock_nixl, 0, 1, 8080, "test", mock_add_agent)
312+
]
313+
314+
test_metadata = NixlAgentMetadata(
315+
engine_id="test-engine",
316+
agent_metadata=b"test-data",
317+
kv_caches_base_addr=[12345],
318+
num_blocks=100,
319+
block_len=16,
320+
attn_backend_name="FLASH_ATTN_VLLM_V1"
321+
)
322+
323+
for strategy in strategies:
324+
assert callable(strategy.initiate_handshake)
325+
assert callable(strategy.setup_listener)
326+
assert callable(strategy.cleanup)
327+
328+
strategy.setup_listener(test_metadata)
329+
strategy.cleanup()

0 commit comments

Comments
 (0)