Skip to content

Commit 9dd2c1c

Browse files
committed
attempt to background agent registration
Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent efbd791 commit 9dd2c1c

File tree

2 files changed

+161
-30
lines changed

2 files changed

+161
-30
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 161 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
# SPDX-License-Identifier: Apache-2.0
2+
<<<<<<< HEAD
23
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
34
import asyncio
5+
=======
6+
import base64
7+
import json
8+
>>>>>>> a1eaf5a5e (attempt to background agent registration)
49
import math
510
import threading
611
import time
712
import uuid
813
from collections import defaultdict
914
from dataclasses import dataclass
1015
from typing import TYPE_CHECKING, Any, Optional
11-
import json
12-
import base64
13-
import aiohttp
14-
import msgspec
16+
from urllib.request import Request, urlopen
17+
from urllib.error import URLError, HTTPError
18+
from urllib.parse import urljoin
1519
import torch
1620

1721
from vllm import envs
@@ -379,6 +383,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
379383
# [req_id -> list[handle]]
380384
self._recving_transfers = defaultdict[str, list[Transfer]](list)
381385

386+
387+
# Pending requests waiting for handshake completion
388+
# [engine_id -> list[(req_id, meta)]]
389+
self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {}
390+
391+
382392
# Complete transfer tracker. Used by the rank 0 to track finished
383393
# transactions on ranks 1 to N-1.
384394
# [req_id -> count]
@@ -387,8 +397,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
387397
self._done_sending_count: defaultdict[str,
388398
int] = defaultdict(lambda: 0)
389399

390-
# Background thread for establishing new connections.
391-
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
400+
# Background handshake threads for remote engines
401+
self._handshake_threads: dict[str, threading.Thread] = {}
402+
403+
# Thread results for handshake completion tracking
404+
self._handshake_results: dict[str, bool] = {}
392405

393406
self.vllm_config = vllm_config
394407
self.block_size = vllm_config.cache_config.block_size
@@ -417,19 +430,45 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
417430
# finish reading before safely freeing the blocks.
418431
self.consumer_notification_counts_by_req = defaultdict[str, int](int)
419432

420-
async def _nixl_handshake(self, host: str, port: int):
433+
def _run_handshake_in_thread(self, engine_id: str, host: str, port: int):
434+
"""Run handshake in background thread."""
435+
436+
def handshake_worker():
437+
logger.debug("Starting handshake worker for engine %s", engine_id)
438+
try:
439+
self._nixl_handshake(host, port)
440+
self._handshake_results[engine_id] = True
441+
logger.debug("Handshake succeeded for engine %s", engine_id)
442+
except Exception as e:
443+
self._handshake_results[engine_id] = False
444+
logger.warning("Handshake failed for engine %s: %s", engine_id, e)
445+
finally:
446+
logger.debug("Handshake worker finished for engine %s", engine_id)
447+
448+
thread = threading.Thread(target=handshake_worker, daemon=True)
449+
thread._start_time = time.time() # track when thread started
450+
self._handshake_threads[engine_id] = thread
451+
thread.start()
452+
return thread
453+
454+
def _nixl_handshake(self, host: str, port: int):
421455
"""Do a NIXL handshake with a remote instance."""
422456

423457
start_time = time.perf_counter()
424458

425-
url = build_uri(host, port, path="get_kv_connector_metadata")
459+
# TODO: make the scheme dynamic, and/or implement https on both sides.
460+
url = build_uri("http", host, port, path="get_kv_connector_metadata")
426461
logger.debug("Querying metadata on path: %s", url)
427462

428-
timeout = aiohttp.ClientTimeout(total=30.0)
429-
async with aiohttp.ClientSession(timeout=timeout) as session:
430-
async with session.get(url) as response:
431-
res = await response.json()
463+
try:
464+
req = Request(url)
465+
with urlopen(req, timeout=5.0) as response:
466+
response_data = response.read().decode('utf-8')
467+
res = json.loads(response_data)
432468
logger.debug("NIXL handshake response: %s", res)
469+
except (URLError, HTTPError) as e:
470+
logger.error("Failed to fetch metadata from %s: %s", url, e)
471+
raise
433472

434473

435474
remote_tp_size = len(res.keys())
@@ -460,13 +499,18 @@ async def _nixl_handshake(self, host: str, port: int):
460499
)
461500

462501
# Register Remote agent.
463-
self.add_remote_agent(metadata, p_remote_rank)
464-
setup_agent_time = time.perf_counter()
465-
466-
logger.debug("NIXL handshake: get metadata took: %s",
467-
time.perf_counter() - start_time)
468-
logger.debug("NIXL handshake: add agent took: %s",
469-
setup_agent_time - (time.perf_counter() - start_time))
502+
logger.debug("About to register remote agent for engine %s",
503+
metadata.engine_id)
504+
pre_register = time.perf_counter()
505+
self.add_remote_agent(metadata, remote_tp_rank=p_remote_rank)
506+
agent_time = time.perf_counter()
507+
logger.debug("Finished registering remote agent for engine %s",
508+
metadata.engine_id)
509+
510+
logger.debug("NIXL handshake: get metadata took: %s",
511+
pre_register - start_time)
512+
logger.debug("NIXL handshake: add agent took: %s",
513+
agent_time - pre_register)
470514
else:
471515
# If metadata_bytes is None, it means the remote agent
472516
# is not using NIXL, so we can skip the handshake.
@@ -755,6 +799,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
755799
"Rank %s, get_finished: %s requests done sending "
756800
"and %s requests done recving", self.tp_rank,
757801
len(done_sending), len(done_recving))
802+
758803

759804
if self.world_size == 1:
760805
return done_sending, done_recving
@@ -843,39 +888,128 @@ def _pop_done_transfers(
843888
xfer_state)
844889
return done_req_ids
845890

891+
def _process_completed_handshakes(self):
892+
"""Process completed handshakes and mark remote agents as ready."""
893+
894+
# debug: log current state
895+
if self._handshake_threads:
896+
logger.debug("Processing handshakes: %d active threads, %d pending",
897+
len(self._handshake_threads),
898+
sum(len(reqs) for reqs in self._pending_requests.values()))
899+
900+
completed_engines = []
901+
for engine_id, thread in list(self._handshake_threads.items()):
902+
logger.debug("Checking handshake thread for engine %s: alive=%s",
903+
engine_id, thread.is_alive())
904+
905+
# check for timeout (threads running > 30 seconds)
906+
thread_age = time.time() - getattr(thread, '_start_time', time.time())
907+
if thread.is_alive() and thread_age > 30.0:
908+
logger.warning("Handshake thread for %s running %.1fs (hung?)",
909+
engine_id, thread_age)
910+
911+
if not thread.is_alive():
912+
logger.debug("Handshake completed for engine %s", engine_id)
913+
completed_engines.append(engine_id)
914+
915+
success = self._handshake_results.get(engine_id, False)
916+
logger.debug("Handshake result for engine %s: success=%s",
917+
engine_id, success)
918+
if not success:
919+
logger.warning("Handshake failed for engine %s", engine_id)
920+
continue
921+
922+
logger.debug("Handshake succeeded for engine %s", engine_id)
923+
if engine_id in self._pending_requests:
924+
pending_reqs = self._pending_requests[engine_id]
925+
logger.debug(
926+
"Handshake completed for %s, clearing %d pending requests "
927+
"(will retry naturally on next start_load_kv)",
928+
engine_id, len(pending_reqs))
929+
930+
# clear pending requests - they'll be retried naturally
931+
# by the event loop on the next start_load_kv() call
932+
del self._pending_requests[engine_id]
933+
934+
for engine_id in completed_engines:
935+
logger.debug("Cleaning up handshake thread for engine %s",
936+
engine_id)
937+
del self._handshake_threads[engine_id]
938+
if engine_id in self._handshake_results:
939+
del self._handshake_results[engine_id]
940+
941+
def _is_request_pending_handshake(self, req_id: str) -> bool:
942+
"""Check if request is still pending handshake completion."""
943+
for engine_requests in self._pending_requests.values():
944+
for pending_req_id, _ in engine_requests:
945+
if pending_req_id == req_id:
946+
return True
947+
return False
948+
846949
def start_load_kv(self, metadata: NixlConnectorMetadata):
847950
"""
848951
Start loading by triggering non-blocking nixl_xfer.
849952
We check for these trnxs to complete in each step().
850953
"""
954+
logger.debug("start_load_kv called with %d requests", len(metadata.requests))
851955
for req_id, meta in metadata.requests.items():
956+
if (req_id in self._recving_transfers or
957+
self._is_request_pending_handshake(req_id)):
958+
logger.debug(
959+
"Request %s already being processed, skipping", req_id)
960+
continue
961+
852962
logger.debug(
853963
"start_load_kv for request %s from remote engine %s. "
854964
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
855965
meta.remote_engine_id, len(meta.local_block_ids),
856966
len(meta.remote_block_ids))
967+
968+
969+
if meta.remote_engine_id not in self._remote_agents:
970+
logger.debug(
971+
"Remote engine %s not registered for request %s, "
972+
"starting handshake and deferring transfer",
973+
meta.remote_engine_id, req_id)
974+
975+
if meta.remote_engine_id not in self._handshake_threads:
976+
logger.debug(
977+
"Starting handshake thread for remote engine %s",
978+
meta.remote_engine_id)
979+
self._run_handshake_in_thread(
980+
meta.remote_engine_id, meta.remote_host,
981+
meta.remote_port)
982+
else:
983+
logger.debug(
984+
"Handshake thread already exists for remote engine %s",
985+
meta.remote_engine_id)
986+
987+
if meta.remote_engine_id not in self._pending_requests:
988+
self._pending_requests[meta.remote_engine_id] = []
989+
self._pending_requests[meta.remote_engine_id].append(
990+
(req_id, meta))
991+
992+
logger.debug(
993+
"Request %s marked as pending handshake for engine %s",
994+
req_id, meta.remote_engine_id)
995+
continue
996+
997+
logger.debug("Remote agent available for %s, calling _read_blocks",
998+
meta.remote_engine_id)
857999
self._read_blocks(
8581000
request_id=req_id,
8591001
dst_engine_id=meta.remote_engine_id,
8601002
local_block_ids=meta.local_block_ids,
8611003
remote_block_ids=meta.remote_block_ids,
862-
remote_host=meta.remote_host,
863-
remote_port=meta.remote_port,
8641004
)
8651005

8661006
def _read_blocks(
8671007
self,
8681008
local_block_ids: list[int],
8691009
remote_block_ids: list[int],
870-
remote_host: str,
871-
remote_port: int,
8721010
dst_engine_id: str,
8731011
request_id: str,
8741012
):
875-
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
876-
if dst_engine_id not in self._remote_agents:
877-
asyncio.run(self._nixl_handshake(remote_host, remote_port))
878-
8791013
# NOTE(rob): having the staging blocks be on the READER side is
8801014
# not going to work well (since we will have to call rearrange tensors).
8811015
# after we detect the txn is complete (which means we cannot make the

vllm/v1/core/sched/scheduler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,6 @@ def schedule(self) -> SchedulerOutput:
331331
if is_ready:
332332
request.status = RequestStatus.WAITING
333333
else:
334-
logger.debug(
335-
"%s is still in WAITING_FOR_REMOTE_KVS state.",
336-
request.request_id)
337334
self.waiting.popleft()
338335
skipped_waiting_requests.appendleft(request)
339336
continue

0 commit comments

Comments
 (0)