Commit b9f61e1

[Bugfix][Nixl] Fix DP Metadata Handshake (#19008)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
1 parent d6fd3a3 commit b9f61e1

File tree: 1 file changed (+36 −32 lines changed)


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 36 additions & 32 deletions
@@ -19,7 +19,7 @@
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group, get_world_group)
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -172,6 +172,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)

         # Requests that need to start recv.
@@ -310,8 +315,8 @@ def request_finished(
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
-            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
-            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
         )

@@ -330,11 +335,19 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Map of engine_id -> agent_name.
         self._remote_agents: dict[str, str] = {}

+        # NIXL handshake port.
+        # NOTE(rob): Within a DP group, each DP rank gets its own
+        # base port (which is sent in the KVTransferParams).
+        # Each TP rank listens/queries on the base_port + tp_rank.
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
+
         # Metadata.
         self.engine_id = engine_id
-        self.rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
-        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
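For clarity, the base_port arithmetic above lays side-channel ports out contiguously per DP rank, with one port per TP rank inside each group. A small standalone sketch of that layout; the env-var value and the DP/TP sizes below are made-up example numbers, not anything from the patch.

# Illustration of the side-channel port layout implied by this change.
# Assumed example values: env base port 5600, DP size 2, TP size 4.
VLLM_NIXL_SIDE_CHANNEL_PORT = 5600  # hypothetical value of the env var
tensor_parallel_size = 4
data_parallel_size = 2

for dp_rank_local in range(data_parallel_size):
    # Each DP rank gets its own base port (advertised via KVTransferParams).
    base_port = (VLLM_NIXL_SIDE_CHANNEL_PORT +
                 dp_rank_local * tensor_parallel_size)
    # Each TP rank within that DP rank listens/queries on base_port + tp_rank.
    ports = [base_port + tp_rank for tp_rank in range(tensor_parallel_size)]
    print(f"DP rank {dp_rank_local}: ports {ports}")
# DP rank 0: ports [5600, 5601, 5602, 5603]
# DP rank 1: ports [5604, 5605, 5606, 5607]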
@@ -383,16 +396,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event,
-                                 world_rank: int):
+                                 ready_event: threading.Event, base_port: int,
+                                 tp_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach like an ETCD server in the future.
-
-        # NOTE(rob): to support heterogeneous TP, we will have to
-        # move this into the scheduler rather than worker, since
-        # each rank needs the metadata of all other ranks (whereas
-        # in this setup, each rank only gets one other rank's meta.
+        # to a better approach via HTTP endpoint soon.

         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(metadata)
@@ -402,11 +410,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,

         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        # NOTE(rob): we need each rank to have a unique port. This
-        # hack to keeps us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
-        path = make_zmq_path("tcp", host, port)
+        path = make_zmq_path("tcp", host, base_port + tp_rank)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
@@ -421,10 +425,10 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""

         start_time = time.perf_counter()
-        # NOTE(rob): we need each rank to have a unique port. This is
-        # a hack to keep us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.world_rank)
+        # NOTE(rob): we need each tp_rank to have a unique port.
+        # This is a hack to keep us moving. We will switch when
+        # we switch to HTTP-based NIXL metadata exchange.
+        path = make_zmq_path("tcp", host, port + self.tp_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
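The two sides of this handshake meet on the same TCP path: the listener binds a ZMQ ROUTER socket at base_port + tp_rank, and _nixl_handshake connects a REQ socket to that same port on the remote host. A minimal standalone sketch of that pattern with plain pyzmq, not using vLLM's zmq_ctx or make_zmq_path helpers; the host, port, and payload are made up for illustration.

import threading
import zmq

# Hypothetical path; stands in for make_zmq_path("tcp", host, base_port + tp_rank).
PATH = "tcp://127.0.0.1:5601"

def listener(payload: bytes, ready: threading.Event) -> None:
    # Listener side: bind a ROUTER socket and answer one metadata query.
    ctx = zmq.Context()
    sock = ctx.socket(zmq.ROUTER)
    sock.bind(PATH)
    ready.set()
    identity, _, _query = sock.recv_multipart()
    sock.send_multipart((identity, b"", payload))

ready = threading.Event()
threading.Thread(target=listener, args=(b"agent-metadata", ready), daemon=True).start()
ready.wait()

# Querying side: connect a REQ socket to the same path and ask for metadata.
req = zmq.Context().socket(zmq.REQ)
req.connect(PATH)
req.send(b"get_metadata")
print(req.recv())  # b'agent-metadata'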
@@ -532,7 +536,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.world_rank),
+            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
@@ -556,9 +560,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
                 block_offset = block_id * self.block_len
                 # (addr, len, device id)
                 blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
+                    (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -573,9 +577,9 @@ def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
                 block_offset = block_id * self.block_len
                 # (addr, len, device id)
                 blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
+                    (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
+                     len(blocks_data), engine_id, self.tp_rank)

         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -600,14 +604,14 @@ def get_finished(self) -> tuple[set[str], set[str]]:
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
-                "and %s requests done recving", self.rank, len(done_sending),
-                len(done_recving))
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))

         if self.world_size == 1:
             return done_sending, done_recving

         # Rank 0: get finished from all other ranks.
-        if self.rank == 0:
+        if self.tp_rank == 0:
             for req_id in done_sending:
                 self._done_sending_count[req_id] += 1
             for req_id in done_recving:
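Beyond what this hunk shows, tp_rank 0 tallies these per-request counters and only reports a request finished once every TP rank has reported it. A rough standalone sketch of that counting idea, with made-up world_size and request IDs; it is an illustration, not the connector's actual code.

from collections import defaultdict

world_size = 4  # assumed TP size for the example
done_sending_count: defaultdict[str, int] = defaultdict(int)

def record_report(req_ids: set[str]) -> set[str]:
    """Tally one TP rank's report; return requests now finished on all ranks."""
    finished: set[str] = set()
    for req_id in req_ids:
        done_sending_count[req_id] += 1
        if done_sending_count[req_id] == world_size:
            finished.add(req_id)
            del done_sending_count[req_id]
    return finished

for _ in range(world_size):
    finished = record_report({"req-a"})
print(finished)  # {'req-a'}: reported only after all 4 ranks have finished sending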
