Skip to content

Commit 662d0f6

Browse files
committed
working checkpoint
Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent 1765197 commit 662d0f6

File tree

4 files changed

+32
-25
lines changed

4 files changed

+32
-25
lines changed

vllm/config.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4510,6 +4510,11 @@ def __post_init__(self):
45104510
if self.kv_events_config is not None:
45114511
# Hybrid KV cache manager is not compatible with KV events.
45124512
self.scheduler_config.disable_hybrid_kv_cache_manager = True
4513+
4514+
if (self.kv_transfer_config is not None
4515+
and self.kv_transfer_config.is_kv_transfer_instance):
4516+
from collections import defaultdict
4517+
self.cache_config.transfer_handshake_metadata = defaultdict(dict)
45134518

45144519
def update_sizes_for_sequence_parallelism(self,
45154520
possible_sizes: list) -> list:

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 24 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -457,16 +457,14 @@ def on_handshake_complete(fut: Future):
457457
# Remove from futures dict
458458
if engine_id in self._handshake_futures:
459459
del self._handshake_futures[engine_id]
460-
# The scheduler will retry them on the next cycle and
461-
# they'll be processed normally since the remote agent
462-
# is now registered.
463460
if engine_id in self._pending_requests:
464461
completed_reqs = self._pending_requests[engine_id]
465462
del self._pending_requests[engine_id]
463+
for req_id, meta in completed_reqs:
464+
self._ready_requests.put((req_id, meta))
466465
logger.debug(
467466
"Handshake completed for engine %s. "
468-
"Cleared %d requests from pending - " \
469-
"scheduler to retry",
467+
"Moved %d requests to ready queue for processing",
470468
engine_id, len(completed_reqs))
471469
except Exception as e:
472470
logger.warning("Handshake failed for engine %s: %s", engine_id,
@@ -507,6 +505,10 @@ def _nixl_handshake(self, host: str, port: int):
507505
logger.error("Failed to fetch metadata from %s: %s", url, e)
508506
raise
509507

508+
if res is None:
509+
logger.warning("Remote server returned None metadata, skipping handshake")
510+
raise RuntimeError("Remote server returned None metadata")
511+
510512
remote_tp_size = len(res.keys())
511513
# Default case is that the remote TP size is 1, so we can
512514
# directly access the metadata.
@@ -525,6 +527,7 @@ def _nixl_handshake(self, host: str, port: int):
525527
if metadata_bytes is not None:
526528
# Reconstruct NixlAgentMetadata from JSON response
527529
# agent_metadata is base64-encoded binary data, not msgpack
530+
tp_data.pop("agent_metadata", None)
528531
metadata = NixlAgentMetadata(
529532
agent_metadata=base64.b64decode(metadata_bytes), **tp_data)
530533

@@ -547,6 +550,7 @@ def _nixl_handshake(self, host: str, port: int):
547550
logger.warning(
548551
"Received None metadata from %s:%s, skipping NIXL handshake",
549552
host, port)
553+
raise RuntimeError("Remote server does not support NIXL")
550554

551555
logger.debug("NIXL handshake method completed for %s:%s", host, port)
552556

@@ -834,12 +838,6 @@ def get_finished(self) -> KVTransferFinishedResult:
834838
finished_recving=done_recving,
835839
pending_handshake=pending_handshake)
836840

837-
if not local_result.is_empty():
838-
logger.debug(
839-
"Rank %s, get_finished: %s requests done sending, "
840-
"%s requests done recving, %s pending handshake", self.tp_rank,
841-
len(done_sending), len(done_recving), len(pending_handshake))
842-
843841
if self.world_size == 1:
844842
return local_result
845843

@@ -939,26 +937,31 @@ def _pop_done_transfers(
939937
return done_req_ids
940938

941939
def _process_ready_requests(self):
942-
"""Process requests that are ready after handshake completion.
943-
944-
Note: With scheduler-based retry logic, this method is simplified
945-
as automatic retries are handled by the scheduler.
946-
"""
947-
# Clear any remaining items in the ready queue to prevent memory leaks
940+
"""Process requests that are ready after handshake completion."""
941+
processed_count = 0
948942
while True:
949943
try:
950-
self._ready_requests.get_nowait()
944+
req_id, meta = self._ready_requests.get_nowait()
945+
logger.debug("Processing ready request %s for engine %s",
946+
req_id, meta.remote_engine_id)
947+
self._read_blocks(
948+
request_id=req_id,
949+
dst_engine_id=meta.remote_engine_id,
950+
local_block_ids=meta.local_block_ids,
951+
remote_block_ids=meta.remote_block_ids,
952+
)
953+
processed_count += 1
951954
except queue.Empty:
952955
break
956+
957+
if processed_count > 0:
958+
logger.debug("Processed %d ready requests", processed_count)
953959

954960
def start_load_kv(self, metadata: NixlConnectorMetadata):
955961
"""
956962
Start loading by triggering non-blocking nixl_xfer.
957963
We check for these trnxs to complete in each step().
958964
"""
959-
logger.debug("start_load_kv called with %d requests",
960-
len(metadata.requests))
961-
962965
for req_id, meta in metadata.requests.items():
963966
logger.debug(
964967
"start_load_kv for request %s from remote engine %s. "

vllm/v1/engine/core.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -175,9 +175,7 @@ def _initialize_kv_caches(
175175
# Collect KV connector xfer metadata from workers
176176
# (after KV cache registration)
177177
transfer_handshake_metadata = (
178-
self.model_executor.get_kv_connector_handshake_metadata()
179-
if self.vllm_config.cache_config.transfer_handshake_metadata else
180-
None)
178+
self.model_executor.get_kv_connector_handshake_metadata())
181179

182180
elapsed = time.time() - start
183181
logger.info(("init engine (profile, create kv cache, "

vllm/v1/utils.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -541,7 +541,8 @@ def wait_for_engine_startup(
541541
logger.debug(
542542
"Received transfer handshake metadata from engine %s: %s",
543543
eng_index, txfer_metadata)
544-
# Merge the received metadata with existing cache config
544+
if cache_config.transfer_handshake_metadata is None:
545+
cache_config.transfer_handshake_metadata = defaultdict(dict)
545546
for tp_rank, dp_dict in txfer_metadata.items():
546547
for dp_rank, metadata in dp_dict.items():
547548
cache_config.transfer_handshake_metadata[tp_rank][

0 commit comments

Comments
 (0)