Skip to content

Commit cd587c9

Browse files
Missmiaomleiyimingnjhill
authored
[BugFix]: Properly set engine_id when using multi connector (#19487)
Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: leiyiming <leiyiming@kingsoft.com> Co-authored-by: Nick Hill <nhill@redhat.com>
1 parent 332d4cb commit cd587c9

File tree

4 files changed

+48
-31
lines changed

4 files changed

+48
-31
lines changed

tests/v1/kv_connector/unit/test_multi_connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def wrapper(*args, **kwargs):
7676
return attr
7777

7878

79+
# This relies on "fork" multiprocessing method being used.
80+
# It's the default but vLLM may fall back to spawn if for example CUDA
81+
# is already initialized.
7982
KVConnectorFactory.register_connector("TestSharedStorageConnector",
8083
TestSharedStorageConnector.__module__,
8184
TestSharedStorageConnector.__name__)

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
166166
super().__init__(*args, **kwargs)
167167
self._hand_shake_latency = hand_shake_latency
168168

169-
def _nixl_handshake(self, host: str, port: int,
170-
remote_tp_size: int) -> dict[int, str]:
169+
def _nixl_handshake(self, host: str, port: int, remote_tp_size: int,
170+
expected_engine_id: str) -> dict[int, str]:
171171
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
172172
time.sleep(self._hand_shake_latency)
173173
# These should've been done in register_kv_caches(), called by
@@ -177,6 +177,8 @@ def _nixl_handshake(self, host: str, port: int,
177177
self.num_blocks = 1
178178
self.dst_num_blocks[self.engine_id] = self.num_blocks
179179

180+
assert expected_engine_id == self.REMOTE_ENGINE_ID
181+
180182
remote_agent_name = self.add_remote_agent(
181183
NixlAgentMetadata(
182184
engine_id=self.REMOTE_ENGINE_ID,

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
4747
assert ktcs is not None
4848
for ktc in ktcs:
4949
temp_config = copy.copy(vllm_config)
50-
temp_config.kv_transfer_config = KVTransferConfig(**ktc)
50+
engine_id = ktc.get("engine_id",
51+
vllm_config.kv_transfer_config.engine_id)
52+
temp_config.kv_transfer_config = KVTransferConfig(
53+
**ktc, engine_id=engine_id)
5154
self._connectors.append(
5255
KVConnectorFactory.create_connector_v1(temp_config, role))
5356

@@ -187,7 +190,7 @@ def request_finished(
187190
async_saves += 1
188191
if txfer_params is not None:
189192
if kv_txfer_params is not None:
190-
#TODO we can probably change this to merge the dicts here,
193+
# TODO we can probably change this to merge the dicts here,
191194
# checking for key clashes.
192195
raise RuntimeError(
193196
"Only one connector can produce KV transfer params")

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -488,8 +488,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
488488
"Connection listener got unexpected message %s", msg)
489489
sock.send_multipart((identity, b"", encoded_data))
490490

491-
def _nixl_handshake(self, host: str, port: int,
492-
remote_tp_size: int) -> dict[int, str]:
491+
def _nixl_handshake(
492+
self,
493+
host: str,
494+
port: int,
495+
remote_tp_size: int,
496+
expected_engine_id: str,
497+
) -> dict[int, str]:
493498
"""Do a NIXL handshake with a remote instance."""
494499

495500
start_time = time.perf_counter()
@@ -498,35 +503,39 @@ def _nixl_handshake(self, host: str, port: int,
498503
# a hack to keep us moving. We will switch when moving to etcd
499504
# or where we have a single ZMQ socket in the scheduler.
500505

501-
def handshake(path: str, rank: int) -> str:
502-
# Send query for the request.
503-
with zmq_ctx(zmq.REQ, path) as sock:
504-
sock.send(GET_META_MSG)
505-
metadata_bytes = sock.recv()
506-
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
507-
metadata = decoder.decode(metadata_bytes)
508-
got_metadata_time = time.perf_counter()
509-
510-
# Register Remote agent.
511-
remote_agent_name = self.add_remote_agent(
512-
metadata, rank, remote_tp_size)
513-
setup_agent_time = time.perf_counter()
514-
515-
logger.debug("NIXL handshake: get metadata took: %s",
516-
got_metadata_time - start_time)
517-
logger.debug("NIXL handshake: add agent took: %s",
518-
setup_agent_time - got_metadata_time)
519-
return remote_agent_name
520-
521506
# Handshake only with the remote TP rank that current local rank will
522507
# pull from. With homogeneous TP it happens to be the same rank_i.
523508
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
524509
p_remote_rank = self.tp_rank // tp_ratio
525510
path = make_zmq_path("tcp", host, port + p_remote_rank)
526511
logger.debug("Querying metadata on path: %s at remote rank %s", path,
527512
p_remote_rank)
513+
514+
# Send query for the request.
515+
with zmq_ctx(zmq.REQ, path) as sock:
516+
sock.send(GET_META_MSG)
517+
metadata_bytes = sock.recv()
518+
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
519+
metadata = decoder.decode(metadata_bytes)
520+
got_metadata_time = time.perf_counter()
521+
logger.debug("NIXL handshake: get metadata took: %s",
522+
got_metadata_time - start_time)
523+
524+
# Ensure engine id matches.
525+
if metadata.engine_id != expected_engine_id:
526+
raise RuntimeError(f"Remote NIXL agent engine ID mismatch. "
527+
f"Expected {expected_engine_id},"
528+
f"received {metadata.engine_id}.")
529+
530+
# Register Remote agent.
531+
remote_agent_name = self.add_remote_agent(metadata, p_remote_rank,
532+
remote_tp_size)
533+
setup_agent_time = time.perf_counter()
534+
logger.debug("NIXL handshake: add agent took: %s",
535+
setup_agent_time - got_metadata_time)
536+
528537
# Remote rank -> agent name.
529-
return {p_remote_rank: handshake(path, p_remote_rank)}
538+
return {p_remote_rank: remote_agent_name}
530539

531540
def _background_nixl_handshake(self, req_id: str,
532541
remote_engine_id: EngineId, meta: ReqMeta):
@@ -535,7 +544,7 @@ def _background_nixl_handshake(self, req_id: str,
535544
if fut is None:
536545
fut = self._handshake_initiation_executor.submit(
537546
self._nixl_handshake, meta.remote_host, meta.remote_port,
538-
meta.tp_size)
547+
meta.tp_size, remote_engine_id)
539548
self._handshake_futures[remote_engine_id] = fut
540549

541550
def done_callback(f: Future[dict[int, str]], eid=remote_engine_id):
@@ -738,10 +747,10 @@ def add_remote_agent(self,
738747
if remote_tp_rank in self._remote_agents.get(engine_id, {}):
739748
return self._remote_agents[engine_id][remote_tp_rank]
740749

741-
if engine_id in self._tp_size:
742-
assert self._tp_size[engine_id] == remote_tp_size
743-
else:
750+
if engine_id not in self._tp_size:
744751
self._tp_size[engine_id] = remote_tp_size
752+
else:
753+
assert self._tp_size[engine_id] == remote_tp_size
745754
# We may eventually enable this after asserting equality in cache
746755
# layout and close outputs.
747756
assert nixl_agent_meta.attn_backend_name == self.backend_name

0 commit comments

Comments
 (0)