1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+
4
+ import base64
5
+ import json
6
+ from unittest .mock import MagicMock , patch
7
+ from urllib .error import URLError
8
+
9
+ import pytest
10
+
11
+ from vllm import envs
12
+ from vllm .distributed .kv_transfer .kv_connector .v1 .nixl_connector import (
13
+ HandshakeStrategy , HttpHandshakeStrategy , NixlAgentMetadata ,
14
+ ZmqHandshakeStrategy )
15
+
16
+
17
+ class TestHandshakeStrategyAbstraction :
18
+
19
+ def test_abstract_base_class (self ):
20
+ with pytest .raises (TypeError ):
21
+ HandshakeStrategy (None , 0 , 1 , 8080 , "test-engine" )
22
+
23
+ def test_strategy_interface (self ):
24
+ mock_nixl = MagicMock ()
25
+ mock_add_agent = MagicMock ()
26
+
27
+ zmq_strategy = ZmqHandshakeStrategy (
28
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
29
+ assert hasattr (zmq_strategy , 'initiate_handshake' )
30
+ assert hasattr (zmq_strategy , 'setup_listener' )
31
+ assert hasattr (zmq_strategy , 'cleanup' )
32
+
33
+ http_strategy = HttpHandshakeStrategy (
34
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
35
+ assert hasattr (http_strategy , 'initiate_handshake' )
36
+ assert hasattr (http_strategy , 'setup_listener' )
37
+ assert hasattr (http_strategy , 'cleanup' )
38
+
39
+
40
+ class TestZmqHandshakeStrategy :
41
+
42
+ def create_test_metadata (self ) -> NixlAgentMetadata :
43
+ return NixlAgentMetadata (
44
+ engine_id = "test-engine" ,
45
+ agent_metadata = b"test-agent-data" ,
46
+ kv_caches_base_addr = [12345 ],
47
+ num_blocks = 100 ,
48
+ block_len = 16 ,
49
+ attn_backend_name = "FLASH_ATTN_VLLM_V1"
50
+ )
51
+
52
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' )
53
+ @patch ('vllm.utils.make_zmq_path' )
54
+ def test_zmq_handshake_success (self , mock_make_path , mock_zmq_ctx ):
55
+ mock_nixl = MagicMock ()
56
+ mock_add_agent = MagicMock (return_value = "agent-name-0" )
57
+
58
+ strategy = ZmqHandshakeStrategy (
59
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
60
+
61
+ mock_socket = MagicMock ()
62
+ mock_zmq_ctx .return_value .__enter__ .return_value = mock_socket
63
+ mock_make_path .return_value = "tcp://localhost:8080"
64
+
65
+ test_metadata = self .create_test_metadata ()
66
+ with patch ('msgspec.msgpack.Decoder' ) as mock_decoder_class :
67
+ mock_decoder = MagicMock ()
68
+ mock_decoder_class .return_value = mock_decoder
69
+ mock_decoder .decode .return_value = test_metadata
70
+
71
+ result = strategy .initiate_handshake ("localhost" , 8080 , 1 )
72
+
73
+ assert result == {0 : "agent-name-0" }
74
+ mock_add_agent .assert_called_once ()
75
+ mock_socket .send .assert_called ()
76
+ mock_socket .recv .assert_called ()
77
+
78
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.ZmqHandshakeStrategy._zmq_ctx' )
79
+ @patch ('vllm.utils.make_zmq_path' )
80
+ def test_zmq_handshake_multi_rank (self , mock_make_path , mock_zmq_ctx ):
81
+ mock_nixl = MagicMock ()
82
+ mock_add_agent = MagicMock (side_effect = ["agent-0" , "agent-1" ])
83
+
84
+ strategy = ZmqHandshakeStrategy (
85
+ mock_nixl , 1 , 2 , 8080 , "test-engine" , mock_add_agent )
86
+
87
+ mock_socket = MagicMock ()
88
+ mock_zmq_ctx .return_value .__enter__ .return_value = mock_socket
89
+ mock_make_path .side_effect = ["tcp://localhost:8080" , "tcp://localhost:8081" ]
90
+
91
+ test_metadata = self .create_test_metadata ()
92
+ with patch ('msgspec.msgpack.Decoder' ) as mock_decoder_class :
93
+ mock_decoder = MagicMock ()
94
+ mock_decoder_class .return_value = mock_decoder
95
+ mock_decoder .decode .return_value = test_metadata
96
+
97
+ result = strategy .initiate_handshake ("localhost" , 8080 , 2 )
98
+
99
+ assert result == {0 : "agent-0" , 1 : "agent-1" }
100
+ assert mock_add_agent .call_count == 2
101
+
102
+ @patch ('threading.Thread' )
103
+ def test_setup_listener (self , mock_thread ):
104
+ mock_nixl = MagicMock ()
105
+ mock_add_agent = MagicMock ()
106
+
107
+ strategy = ZmqHandshakeStrategy (
108
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
109
+
110
+ mock_thread_instance = MagicMock ()
111
+ mock_thread .return_value = mock_thread_instance
112
+
113
+ test_metadata = self .create_test_metadata ()
114
+
115
+ with patch ('threading.Event' ) as mock_event_class :
116
+ mock_event = MagicMock ()
117
+ mock_event_class .return_value = mock_event
118
+
119
+ strategy .setup_listener (test_metadata )
120
+
121
+ mock_thread .assert_called_once ()
122
+ mock_thread_instance .start .assert_called_once ()
123
+ mock_event .wait .assert_called_once ()
124
+
125
+ def test_cleanup (self ):
126
+ mock_nixl = MagicMock ()
127
+ mock_add_agent = MagicMock ()
128
+
129
+ strategy = ZmqHandshakeStrategy (
130
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
131
+
132
+ mock_thread = MagicMock ()
133
+ strategy ._listener_thread = mock_thread
134
+
135
+ strategy .cleanup ()
136
+
137
+ mock_thread .join .assert_called_once_with (timeout = 0 )
138
+
139
+
140
+ class TestHttpHandshakeStrategy :
141
+
142
+ def create_test_metadata_response (self ) -> dict :
143
+ return {
144
+ "0" : {
145
+ "0" : {
146
+ "engine_id" : "3871ab24-6b5a-4ea5-a614-5381594bcdde" ,
147
+ "agent_metadata" : base64 .b64encode (b"nixl-prefill-agent-data" ).decode (),
148
+ "kv_caches_base_addr" : [0x7f8b2c000000 ],
149
+ "num_blocks" : 1000 ,
150
+ "block_len" : 128 ,
151
+ "attn_backend_name" : "FLASH_ATTN_VLLM_V1"
152
+ }
153
+ }
154
+ }
155
+
156
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen' )
157
+ def test_http_handshake_success (self , mock_urlopen ):
158
+ mock_nixl = MagicMock ()
159
+ mock_add_agent = MagicMock (return_value = "remote-agent-0" )
160
+
161
+ strategy = HttpHandshakeStrategy (
162
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
163
+
164
+ mock_response = MagicMock ()
165
+ mock_response .read .return_value = json .dumps (
166
+ self .create_test_metadata_response ()).encode ()
167
+ mock_urlopen .return_value .__enter__ .return_value = mock_response
168
+
169
+ result = strategy .initiate_handshake ("localhost" , 8080 , 1 )
170
+
171
+ assert result == {0 : "remote-agent-0" }
172
+ mock_add_agent .assert_called_once ()
173
+
174
+ call_args = mock_add_agent .call_args
175
+ metadata = call_args [0 ][0 ]
176
+ assert isinstance (metadata , NixlAgentMetadata )
177
+ assert metadata .engine_id == "3871ab24-6b5a-4ea5-a614-5381594bcdde"
178
+ assert metadata .agent_metadata == b"nixl-prefill-agent-data"
179
+
180
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen' )
181
+ def test_http_handshake_multi_rank (self , mock_urlopen ):
182
+ mock_nixl = MagicMock ()
183
+ mock_add_agent = MagicMock (return_value = "remote-agent-1" )
184
+
185
+ strategy = HttpHandshakeStrategy (
186
+ mock_nixl , 1 , 2 , 8080 , "test-engine" , mock_add_agent )
187
+
188
+ response_data = {
189
+ "0" : {
190
+ "0" : {
191
+ "engine_id" : "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d" ,
192
+ "agent_metadata" : base64 .b64encode (b"decode-agent-0-data" ).decode (),
193
+ "kv_caches_base_addr" : [0x7f8b2c000000 ],
194
+ "num_blocks" : 800 ,
195
+ "block_len" : 128 ,
196
+ "attn_backend_name" : "FLASH_ATTN_VLLM_V1"
197
+ },
198
+ "1" : {
199
+ "engine_id" : "339a1bdd-e9ad-4c6e-a3e3-e0e7cca2238d" ,
200
+ "agent_metadata" : base64 .b64encode (b"decode-agent-1-data" ).decode (),
201
+ "kv_caches_base_addr" : [0x7f8b2d000000 ],
202
+ "num_blocks" : 800 ,
203
+ "block_len" : 128 ,
204
+ "attn_backend_name" : "FLASH_ATTN_VLLM_V1"
205
+ }
206
+ }
207
+ }
208
+
209
+ mock_response = MagicMock ()
210
+ mock_response .read .return_value = json .dumps (response_data ).encode ()
211
+ mock_urlopen .return_value .__enter__ .return_value = mock_response
212
+
213
+ result = strategy .initiate_handshake ("localhost" , 8080 , 2 )
214
+
215
+ assert result == {1 : "remote-agent-1" }
216
+
217
+ call_args = mock_add_agent .call_args
218
+ metadata = call_args [0 ][0 ]
219
+ assert metadata .agent_metadata == b"decode-agent-1-data"
220
+
221
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen' )
222
+ def test_http_handshake_url_error (self , mock_urlopen ):
223
+ mock_nixl = MagicMock ()
224
+ mock_add_agent = MagicMock ()
225
+
226
+ strategy = HttpHandshakeStrategy (
227
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
228
+
229
+ mock_urlopen .side_effect = URLError ("Connection failed" )
230
+
231
+ with pytest .raises (URLError ):
232
+ strategy .initiate_handshake ("localhost" , 8080 , 1 )
233
+
234
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen' )
235
+ def test_http_handshake_none_response (self , mock_urlopen ):
236
+ mock_nixl = MagicMock ()
237
+ mock_add_agent = MagicMock ()
238
+
239
+ strategy = HttpHandshakeStrategy (
240
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
241
+
242
+ mock_response = MagicMock ()
243
+ mock_response .read .return_value = json .dumps (None ).encode ()
244
+ mock_urlopen .return_value .__enter__ .return_value = mock_response
245
+
246
+ with pytest .raises (RuntimeError , match = "Remote server returned None metadata" ):
247
+ strategy .initiate_handshake ("localhost" , 8080 , 1 )
248
+
249
+ @patch ('vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.urlopen' )
250
+ def test_http_handshake_missing_rank (self , mock_urlopen ):
251
+ mock_nixl = MagicMock ()
252
+ mock_add_agent = MagicMock ()
253
+
254
+ strategy = HttpHandshakeStrategy (
255
+ mock_nixl , 1 , 2 , 8080 , "decode-engine" , mock_add_agent )
256
+
257
+ mock_response = MagicMock ()
258
+ empty_response = {"0" : {}}
259
+ mock_response .read .return_value = json .dumps (empty_response ).encode ()
260
+ mock_urlopen .return_value .__enter__ .return_value = mock_response
261
+
262
+ with pytest .raises (RuntimeError , match = "No metadata found for dp_rank 0" ):
263
+ strategy .initiate_handshake ("localhost" , 8080 , 1 )
264
+
265
+ def test_setup_listener_noop (self ):
266
+ mock_nixl = MagicMock ()
267
+ mock_add_agent = MagicMock ()
268
+
269
+ strategy = HttpHandshakeStrategy (
270
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
271
+
272
+ test_metadata = NixlAgentMetadata (
273
+ engine_id = "test-engine" ,
274
+ agent_metadata = b"test-data" ,
275
+ kv_caches_base_addr = [12345 ],
276
+ num_blocks = 100 ,
277
+ block_len = 16 ,
278
+ attn_backend_name = "FLASH_ATTN_VLLM_V1"
279
+ )
280
+
281
+ strategy .setup_listener (test_metadata )
282
+
283
+ def test_cleanup_noop (self ):
284
+ mock_nixl = MagicMock ()
285
+ mock_add_agent = MagicMock ()
286
+
287
+ strategy = HttpHandshakeStrategy (
288
+ mock_nixl , 0 , 1 , 8080 , "test-engine" , mock_add_agent )
289
+
290
+ strategy .cleanup ()
291
+
292
+
293
+ class TestHandshakeStrategyIntegration :
294
+
295
+ @patch .dict ('os.environ' , {'VLLM_NIXL_HANDSHAKE_METHOD' : 'zmq' })
296
+ @patch ('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD' , 'zmq' )
297
+ def test_zmq_strategy_selection (self ):
298
+ assert envs .VLLM_NIXL_HANDSHAKE_METHOD .lower () == 'zmq'
299
+
300
+ @patch .dict ('os.environ' , {'VLLM_NIXL_HANDSHAKE_METHOD' : 'http' })
301
+ @patch ('vllm.envs.VLLM_NIXL_HANDSHAKE_METHOD' , 'http' )
302
+ def test_http_strategy_selection (self ):
303
+ assert envs .VLLM_NIXL_HANDSHAKE_METHOD .lower () == 'http'
304
+
305
+ def test_strategy_polymorphism (self ):
306
+ mock_nixl = MagicMock ()
307
+ mock_add_agent = MagicMock ()
308
+
309
+ strategies = [
310
+ ZmqHandshakeStrategy (mock_nixl , 0 , 1 , 8080 , "test" , mock_add_agent ),
311
+ HttpHandshakeStrategy (mock_nixl , 0 , 1 , 8080 , "test" , mock_add_agent )
312
+ ]
313
+
314
+ test_metadata = NixlAgentMetadata (
315
+ engine_id = "test-engine" ,
316
+ agent_metadata = b"test-data" ,
317
+ kv_caches_base_addr = [12345 ],
318
+ num_blocks = 100 ,
319
+ block_len = 16 ,
320
+ attn_backend_name = "FLASH_ATTN_VLLM_V1"
321
+ )
322
+
323
+ for strategy in strategies :
324
+ assert callable (strategy .initiate_handshake )
325
+ assert callable (strategy .setup_listener )
326
+ assert callable (strategy .cleanup )
327
+
328
+ strategy .setup_listener (test_metadata )
329
+ strategy .cleanup ()
0 commit comments