     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group, get_world_group)
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -172,6 +172,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

         # Requests that need to start recv.
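As a worked example of the base-port arithmetic added above (all numbers are hypothetical, not vLLM defaults), a minimal sketch:

```python
# Sketch only: reproduces the side-channel base-port arithmetic from the
# hunk above with assumed values; not part of the vLLM code itself.
VLLM_NIXL_SIDE_CHANNEL_PORT = 5600   # assumed env value
tensor_parallel_size = 4             # assumed TP degree
data_parallel_rank_local = 1         # assumed local DP rank of this engine

side_channel_port = (VLLM_NIXL_SIDE_CHANNEL_PORT +
                     data_parallel_rank_local * tensor_parallel_size)
print(side_channel_port)  # 5604: DP rank 0 owns 5600-5603, DP rank 1 owns 5604-5607
```

Consecutive DP ranks therefore get non-overlapping port ranges of width `tensor_parallel_size`.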
@@ -310,8 +315,8 @@ def request_finished(
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
-            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
-            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
         )

@@ -330,11 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Map of engine_id -> agent_name.
         self._remote_agents: dict[str, str] = {}

+        # NIXL handshake port.
+        # NOTE(rob): Within a DP group, each DP rank gets its own
+        # base port (which is sent in the KVTransferParams).
+        # Each TP rank listens/queries on the base_port + tp_rank.
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
+
         # Metadata.
         self.engine_id = engine_id
-        self.rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
-        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
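The NOTE above describes the full scheme: each DP rank owns a contiguous range of `tensor_parallel_size` ports starting at its base port, and each TP rank listens (and is queried) at `base_port + tp_rank`. A small sketch enumerating the resulting port layout, with assumed parallel sizes:

```python
# Sketch only: which side-channel port each (dp_rank_local, tp_rank) worker
# binds under the scheme in the NOTE above, using assumed sizes.
VLLM_NIXL_SIDE_CHANNEL_PORT = 5600   # assumed env value
tp_size, dp_local_ranks = 4, 2       # assumed parallel layout

for dp_rank in range(dp_local_ranks):
    base_port = VLLM_NIXL_SIDE_CHANNEL_PORT + dp_rank * tp_size
    for tp_rank in range(tp_size):
        # The worker at (dp_rank, tp_rank) listens here; a remote peer with
        # the same tp_rank queries the same offset from the advertised base.
        print(f"dp={dp_rank} tp={tp_rank} -> port {base_port + tp_rank}")
```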
@@ -383,16 +396,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event,
-                                 world_rank: int):
+                                 ready_event: threading.Event, base_port: int,
+                                 tp_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach like an ETCD server in the future.
-
-        # NOTE(rob): to support heterogeneous TP, we will have to
-        # move this into the scheduler rather than worker, since
-        # each rank needs the metadata of all other ranks (whereas
-        # in this setup, each rank only gets one other rank's meta.
+        # to a better approach via HTTP endpoint soon.

         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(metadata)
@@ -402,11 +410,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        # NOTE(rob): we need each rank to have a unique port. This
-        # hack to keeps us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
-        path = make_zmq_path("tcp", host, port)
+        path = make_zmq_path("tcp", host, base_port + tp_rank)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
@@ -421,10 +425,10 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
-        # NOTE(rob): we need each rank to have a unique port. This is
-        # a hack to keep us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.world_rank)
+        # NOTE(rob): we need each tp_rank to have a unique port.
+        # This is a hack to keep us moving. We will switch when
+        # we switch to HTTP-based NIXL metadata exchange.
+        path = make_zmq_path("tcp", host, port + self.tp_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
@@ -532,7 +536,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.world_rank),
+            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
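Taken together, the hunks above wire up a ZMQ ROUTER listener per TP rank (serving msgpack-encoded agent metadata), a REQ-side query in `_nixl_handshake`, and a daemon thread guarded by a `ready_event` in `register_kv_caches`. Below is a self-contained sketch of that request/reply pattern using plain pyzmq and msgspec; the `AgentMeta` struct, the `b"get_meta"` payload, and the port number are illustrative stand-ins, not the connector's actual wire format.

```python
import threading

import msgspec
import zmq


class AgentMeta(msgspec.Struct):
    # Illustrative stand-in for NixlAgentMetadata; fields are invented.
    agent_name: str
    num_blocks: int


def listener(meta: AgentMeta, port: int, ready: threading.Event) -> None:
    """ROUTER side: serve msgpack-encoded metadata to one querier."""
    payload = msgspec.msgpack.Encoder().encode(meta)
    sock = zmq.Context.instance().socket(zmq.ROUTER)
    sock.bind(f"tcp://127.0.0.1:{port}")
    ready.set()
    # A REQ peer's message arrives as [identity, empty delimiter, body].
    identity, _, _query = sock.recv_multipart()
    sock.send_multipart([identity, b"", payload])
    sock.close()


def handshake(host: str, port: int) -> AgentMeta:
    """REQ side: ask the remote listener for its metadata."""
    sock = zmq.Context.instance().socket(zmq.REQ)
    sock.connect(f"tcp://{host}:{port}")
    sock.send(b"get_meta")  # illustrative query payload
    reply = sock.recv()
    sock.close()
    return msgspec.msgpack.Decoder(AgentMeta).decode(reply)


ready = threading.Event()
t = threading.Thread(target=listener,
                     args=(AgentMeta("prefill-tp0", 16), 5600, ready),
                     daemon=True)
t.start()
ready.wait()  # same pattern as register_kv_caches(): wait until the socket is bound
print(handshake("127.0.0.1", 5600))
```

The `ready_event` keeps the worker from advertising its side channel before the ROUTER socket is actually listening, which is why `register_kv_caches` waits on it before proceeding.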
@@ -556,9 +560,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
                 block_offset = block_id * self.block_len
                 # (addr, len, device id)
                 blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
+                    (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -573,9 +577,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
                 block_offset = block_id * self.block_len
                 # (addr, len, device id)
                 blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
+                    (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
+                     len(blocks_data), engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
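Both loops above build flat lists of `(address, length, device_id)` tuples, one per KV block, by offsetting a region's base address by `block_id * block_len` before handing the list to `nixl_wrapper.get_xfer_descs`. A minimal sketch of that layout arithmetic with made-up numbers:

```python
# Sketch only: reproduce the (addr, len, device_id) descriptor construction
# from the hunks above with assumed values.
base_addr = 0x7F00_0000_0000   # assumed start address of one KV-cache region
block_len = 2 * 1024 * 1024    # assumed bytes per KV block
tp_rank = 1                    # device id recorded for this rank's descriptors
block_ids = [0, 1, 5]          # assumed blocks involved in the transfer

blocks_data = []
for block_id in block_ids:
    block_offset = block_id * block_len
    # (addr, len, device id) -- same tuple shape as in add_remote_agent().
    blocks_data.append((base_addr + block_offset, block_len, tp_rank))

for addr, length, dev in blocks_data:
    print(f"addr={hex(addr)} len={length} device={dev}")
```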
@@ -600,14 +604,14 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
-                "and %s requests done recving", self.rank, len(done_sending),
-                len(done_recving))
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))

         if self.world_size == 1:
             return done_sending, done_recving

         # Rank 0: get finished from all other ranks.
-        if self.rank == 0:
+        if self.tp_rank == 0:
             for req_id in done_sending:
                 self._done_sending_count[req_id] += 1
             for req_id in done_recving:
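In the tail of `get_finished()`, TP rank 0 tallies per-request completions reported by the other ranks; the counters suggest a request is only surfaced once every rank has reported it (a count-to-`world_size` rule, though the release condition itself is outside this hunk). A toy sketch of that counting idea, with the cross-rank communication replaced by a plain list of per-rank report sets:

```python
from collections import defaultdict

# Sketch only: rank-0-style aggregation of "done" request ids across TP
# ranks; the real connector gathers these reports over the TP group.
world_size = 4                              # assumed TP world size
reports = [                                 # assumed per-rank done sets
    {"req-A", "req-B"},   # tp_rank 0
    {"req-A"},            # tp_rank 1
    {"req-A", "req-B"},   # tp_rank 2
    {"req-A", "req-B"},   # tp_rank 3
]

done_count: dict[str, int] = defaultdict(int)
fully_done: set[str] = set()
for rank_done in reports:
    for req_id in rank_done:
        done_count[req_id] += 1
        if done_count[req_id] == world_size:
            fully_done.add(req_id)

print(fully_done)  # only "req-A" was reported by all 4 ranks
```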