13
13
from collections .abc import Iterator
14
14
from concurrent .futures import Future , ThreadPoolExecutor
15
15
from dataclasses import dataclass
16
- from typing import TYPE_CHECKING , Any , Dict , Optional
16
+ from typing import TYPE_CHECKING , Any , Optional
17
17
from urllib .error import HTTPError , URLError
18
18
from urllib .request import Request as URLRequest
19
19
from urllib .request import urlopen
@@ -81,81 +81,83 @@ class ReqMeta:
81
81
82
82
83
83
class HandshakeStrategy (ABC ):
84
-
85
- def __init__ (self , nixl_wrapper , tp_rank : int , tp_size : int ,
84
+
85
+ def __init__ (self , nixl_wrapper , tp_rank : int , tp_size : int ,
86
86
side_channel_port : int , engine_id : str ):
87
87
self .nixl_wrapper = nixl_wrapper
88
88
self .tp_rank = tp_rank
89
89
self .tp_size = tp_size
90
90
self .side_channel_port = side_channel_port
91
91
self .engine_id = engine_id
92
-
92
+
93
93
@abstractmethod
94
- def initiate_handshake (self , host : str , port : int ,
95
- remote_tp_size : int ) -> Dict [int , str ]:
94
+ def initiate_handshake (self , host : str , port : int ,
95
+ remote_tp_size : int ) -> dict [int , str ]:
96
96
pass
97
-
97
+
98
98
@abstractmethod
99
99
def setup_listener (self , metadata : NixlAgentMetadata ) -> None :
100
100
pass
101
-
101
+
102
102
@abstractmethod
103
103
def cleanup (self ) -> None :
104
104
pass
105
105
106
106
107
107
class ZmqHandshakeStrategy (HandshakeStrategy ):
108
-
108
+
109
109
def __init__ (self , nixl_wrapper , tp_rank : int , tp_size : int ,
110
- side_channel_port : int , engine_id : str ,
110
+ side_channel_port : int , engine_id : str ,
111
111
add_remote_agent_func ):
112
- super ().__init__ (nixl_wrapper , tp_rank , tp_size , side_channel_port , engine_id )
112
+ super ().__init__ (nixl_wrapper , tp_rank , tp_size , side_channel_port ,
113
+ engine_id )
113
114
self .add_remote_agent_func = add_remote_agent_func
114
115
self ._listener_thread : Optional [threading .Thread ] = None
115
- self ._tp_size_mapping : Dict [str , int ] = {engine_id : tp_size }
116
-
117
- def initiate_handshake (self , host : str , port : int ,
118
- remote_tp_size : int ) -> Dict [int , str ]:
116
+ self ._tp_size_mapping : dict [str , int ] = {engine_id : tp_size }
117
+
118
+ def initiate_handshake (self , host : str , port : int ,
119
+ remote_tp_size : int ) -> dict [int , str ]:
119
120
start_time = time .perf_counter ()
120
-
121
+
121
122
def handshake (path : str , rank : int ) -> tuple [NixlAgentMetadata , str ]:
122
123
with self ._zmq_ctx (zmq .REQ , path ) as sock :
123
124
sock .send (GET_META_MSG )
124
125
metadata_bytes = sock .recv ()
125
126
decoder = msgspec .msgpack .Decoder (NixlAgentMetadata )
126
127
metadata = decoder .decode (metadata_bytes )
127
128
got_metadata_time = time .perf_counter ()
128
-
129
+
129
130
# Register Remote agent
130
- agent_name = self .add_remote_agent_func (metadata , rank , remote_tp_size )
131
+ agent_name = self .add_remote_agent_func (
132
+ metadata , rank , remote_tp_size )
131
133
setup_agent_time = time .perf_counter ()
132
-
134
+
133
135
logger .debug ("NIXL handshake: get metadata took: %s" ,
134
- got_metadata_time - start_time )
135
- logger .debug ("NIXL handshake: add agent took: %s" ,
136
- setup_agent_time - got_metadata_time )
136
+ got_metadata_time - start_time )
137
+ logger .debug ("NIXL handshake: add agent took: %s" ,
138
+ setup_agent_time - got_metadata_time )
137
139
return metadata , agent_name
138
-
140
+
139
141
# Handshake with remote agent-rank0 first to get the tp_size of remote
140
142
path = make_zmq_path ("tcp" , host , port )
141
143
logger .debug ("Querying master rank metadata on path: %s" , path )
142
144
metadata , agent_name_0 = handshake (path , 0 )
143
-
145
+
144
146
agents = {0 : agent_name_0 }
145
-
147
+
146
148
# Handshake only with the other TP remote the current local rank will
147
149
# pull from. With homogeneous TP it happens to be the same rank_i.
148
150
tp_ratio = self ._tp_size_mapping [self .engine_id ] // remote_tp_size
149
151
p_remote_rank = self .tp_rank // tp_ratio
150
152
if p_remote_rank > 0 :
151
153
path = make_zmq_path ("tcp" , host , port + p_remote_rank )
152
154
logger .debug ("Querying metadata on path: %s at remote rank %s" ,
153
- path , p_remote_rank )
155
+ path , p_remote_rank )
154
156
_ , agent_name = handshake (path , p_remote_rank )
155
157
agents [p_remote_rank ] = agent_name
156
-
158
+
157
159
return agents
158
-
160
+
159
161
def setup_listener (self , metadata : NixlAgentMetadata ) -> None :
160
162
ready_event = threading .Event ()
161
163
self ._listener_thread = threading .Thread (
@@ -165,20 +167,21 @@ def setup_listener(self, metadata: NixlAgentMetadata) -> None:
165
167
name = "nixl_handshake_listener" )
166
168
self ._listener_thread .start ()
167
169
ready_event .wait ()
168
-
170
+
169
171
def cleanup (self ) -> None :
170
172
if self ._listener_thread :
171
173
self ._listener_thread .join (timeout = 0 )
172
-
174
+
173
175
@staticmethod
174
176
def _nixl_handshake_listener (metadata : NixlAgentMetadata ,
175
- ready_event : threading .Event , base_port : int ,
176
- tp_rank : int ):
177
+ ready_event : threading .Event , base_port : int ,
178
+ tp_rank : int ):
177
179
encoder = msgspec .msgpack .Encoder ()
178
180
encoded_data = encoder .encode (metadata )
179
181
size_in_bytes = len (encoded_data )
180
- logger .debug ("Size of encoded NixlAgentMetadata: %s bytes" , size_in_bytes )
181
-
182
+ logger .debug ("Size of encoded NixlAgentMetadata: %s bytes" ,
183
+ size_in_bytes )
184
+
182
185
# Listen for new requests for metadata
183
186
host = envs .VLLM_NIXL_SIDE_CHANNEL_HOST
184
187
path = make_zmq_path ("tcp" , host , base_port + tp_rank )
@@ -188,97 +191,109 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
188
191
while True :
189
192
identity , _ , msg = sock .recv_multipart ()
190
193
if msg != GET_META_MSG :
191
- logger .warning ("Connection listener got unexpected message %s" , msg )
194
+ logger .warning (
195
+ "Connection listener got unexpected message %s" , msg )
192
196
sock .send_multipart ((identity , b"" , encoded_data ))
193
-
197
+
194
198
@staticmethod
195
199
@contextlib .contextmanager
196
200
def _zmq_ctx (socket_type : Any , addr : str ) -> Iterator [zmq .Socket ]:
197
201
if socket_type not in (zmq .ROUTER , zmq .REQ ):
198
202
raise ValueError (f"Unexpected socket type: { socket_type } " )
199
-
203
+
200
204
ctx : Optional [zmq .Context ] = None
201
205
try :
202
206
ctx = zmq .Context ()
203
- yield make_zmq_socket (ctx = ctx , path = addr , socket_type = socket_type ,
204
- bind = socket_type == zmq .ROUTER )
207
+ yield make_zmq_socket (ctx = ctx ,
208
+ path = addr ,
209
+ socket_type = socket_type ,
210
+ bind = socket_type == zmq .ROUTER )
205
211
finally :
206
212
if ctx is not None :
207
213
ctx .destroy (linger = 0 )
208
214
209
215
210
216
class HttpHandshakeStrategy (HandshakeStrategy ):
211
-
217
+
212
218
def __init__ (self , nixl_wrapper , tp_rank : int , tp_size : int ,
213
219
side_channel_port : int , engine_id : str ,
214
220
add_remote_agent_func ):
215
- super ().__init__ (nixl_wrapper , tp_rank , tp_size , side_channel_port , engine_id )
221
+ super ().__init__ (nixl_wrapper , tp_rank , tp_size , side_channel_port ,
222
+ engine_id )
216
223
self .add_remote_agent_func = add_remote_agent_func
217
- self ._tp_size_mapping : Dict [str , int ] = {engine_id : tp_size }
218
-
219
- def initiate_handshake (self , host : str , port : int ,
220
- remote_tp_size : int ) -> Dict [int , str ]:
224
+ self ._tp_size_mapping : dict [str , int ] = {engine_id : tp_size }
225
+
226
+ def initiate_handshake (self , host : str , port : int ,
227
+ remote_tp_size : int ) -> dict [int , str ]:
221
228
start_time = time .perf_counter ()
222
229
logger .debug ("Starting NIXL handshake with %s:%s" , host , port )
223
-
230
+
224
231
url = build_uri ("http" , host , port , path = "get_kv_connector_metadata" )
225
-
232
+
226
233
try :
227
234
req = URLRequest (url )
228
- with urlopen (req , timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
235
+ with urlopen (req ,
236
+ timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
229
237
response_data = response .read ().decode ('utf-8' )
230
238
res = json .loads (response_data )
231
239
except (URLError , HTTPError ) as e :
232
240
logger .error ("Failed to fetch metadata from %s: %s" , url , e )
233
241
raise
234
-
242
+
235
243
if res is None :
236
- logger .warning ("Remote server returned None metadata, skipping handshake" )
244
+ logger .warning (
245
+ "Remote server returned None metadata, skipping handshake" )
237
246
raise RuntimeError ("Remote server returned None metadata" )
238
-
247
+
239
248
# Get dp_rank 0 data (standard for disaggregated prefill-decode)
240
249
dp_data = res .get ("0" , {})
241
250
if not dp_data :
242
251
raise RuntimeError ("No metadata found for dp_rank 0" )
243
-
252
+
244
253
remote_tp_size = len (dp_data .keys ())
245
-
254
+
246
255
# Handshake only with the remote TP rank that current local rank will
247
256
# pull from. With homogeneous TP it happens to be the same rank_i.
248
257
tp_ratio = self ._tp_size_mapping [self .engine_id ] // remote_tp_size
249
258
p_remote_rank = self .tp_rank // tp_ratio
250
-
259
+
251
260
# Get data for the specific rank we need to connect to
252
261
rank_data = dp_data .get (str (p_remote_rank ), {})
253
262
if not rank_data :
254
- raise RuntimeError (f"No metadata found for remote rank { p_remote_rank } " )
255
-
263
+ raise RuntimeError (
264
+ f"No metadata found for remote rank { p_remote_rank } " )
265
+
256
266
metadata_bytes = rank_data .get ("agent_metadata" , None )
257
267
if metadata_bytes is None :
258
- raise RuntimeError (f"No agent metadata found for remote rank { p_remote_rank } " )
259
-
268
+ raise RuntimeError (
269
+ f"No agent metadata found for remote rank { p_remote_rank } " )
270
+
260
271
rank_data_copy = rank_data .copy ()
261
272
rank_data_copy .pop ("agent_metadata" , None )
262
273
metadata = NixlAgentMetadata (
263
274
agent_metadata = base64 .b64decode (metadata_bytes ), ** rank_data_copy )
264
-
275
+
265
276
pre_register = time .perf_counter ()
266
277
# Register Remote agent
267
- remote_agent_name = self .add_remote_agent_func (metadata , p_remote_rank , remote_tp_size )
278
+ remote_agent_name = self .add_remote_agent_func (metadata , p_remote_rank ,
279
+ remote_tp_size )
268
280
agent_time = time .perf_counter ()
269
-
270
- logger .debug ("Finished registering remote agent for engine %s" , metadata .engine_id )
271
- logger .debug ("NIXL handshake: get metadata took: %s" , pre_register - start_time )
272
- logger .debug ("NIXL handshake: add agent took: %s" , agent_time - pre_register )
273
-
281
+
282
+ logger .debug ("Finished registering remote agent for engine %s" ,
283
+ metadata .engine_id )
284
+ logger .debug ("NIXL handshake: get metadata took: %s" ,
285
+ pre_register - start_time )
286
+ logger .debug ("NIXL handshake: add agent took: %s" ,
287
+ agent_time - pre_register )
288
+
274
289
logger .debug ("NIXL handshake method completed for %s:%s" , host , port )
275
-
290
+
276
291
# Return remote rank -> agent name mapping
277
292
return {p_remote_rank : remote_agent_name }
278
-
293
+
279
294
def setup_listener (self , metadata : NixlAgentMetadata ) -> None :
280
295
pass
281
-
296
+
282
297
def cleanup (self ) -> None :
283
298
pass
284
299
@@ -680,8 +695,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
680
695
self .side_channel_port , self .engine_id , self .add_remote_agent )
681
696
else :
682
697
raise ValueError (f"Unknown handshake method: { handshake_method } . "
683
- "Supported methods: 'zmq', 'http'" )
684
-
698
+ "Supported methods: 'zmq', 'http'" )
699
+
685
700
logger .info ("Using %s handshake strategy" , handshake_method )
686
701
687
702
def __del__ (self ):
@@ -693,7 +708,8 @@ def __del__(self):
693
708
694
709
def _nixl_handshake (self , host : str , port : int ,
695
710
remote_tp_size : int ) -> dict [int , str ]:
696
- return self ._handshake_strategy .initiate_handshake (host , port , remote_tp_size )
711
+ return self ._handshake_strategy .initiate_handshake (
712
+ host , port , remote_tp_size )
697
713
698
714
def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
699
715
"""Register the KV Cache data in nixl."""
0 commit comments