Commit dba3835

flip protocol; fix scheduling order bug

Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent efd655b commit dba3835

7 files changed: +66 -29 lines
vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -1515,7 +1515,7 @@ class CacheConfig:
 
     transfer_handshake_metadata: Optional[dict[int, dict[
         int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False)
-    """Metadata for the KV connector handshake."""
+    """Metadata for the KV connector handshake. Structure: dp_rank -> tp_rank -> metadata"""
 
     def compute_hash(self) -> str:
         """

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 27 additions & 16 deletions
@@ -489,8 +489,9 @@ def _nixl_handshake(self, host: str, port: int):
         start_time = time.perf_counter()
         logger.debug("Starting NIXL handshake with %s:%s", host, port)
 
-        # TODO: make the scheme dynamic, and/or implement https on both sides.
-        url = build_uri("http", host, port, path="get_kv_connector_metadata")
+        # Use the new endpoint scheme to filter by dp_rank and tp_rank
+        # Default to dp_rank 0 and use current tp_rank for optimal filtering
+        url = build_uri("http", host, port, path=f"get_kv_connector_metadata/0/{self.tp_rank}")
         logger.debug("Querying metadata on path: %s", url)
 
         try:

@@ -509,20 +510,29 @@ def _nixl_handshake(self, host: str, port: int):
                 logger.warning("Remote server returned None metadata, skipping handshake")
                 raise RuntimeError("Remote server returned None metadata")
 
-            remote_tp_size = len(res.keys())
-            # Default case is that the remote TP size is 1, so we can
-            # directly access the metadata.
-            tp_data = res.get(str(self.tp_rank), {}).get("0", {})
-            metadata_bytes = tp_data.get("agent_metadata", None)
-
-            # Handshake only with the other TP remote the current local rank will
-            # pull from. With homogeneous TP it happens to be the same rank_i.
-            tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
-            p_remote_rank = self.tp_rank // tp_ratio
-            if p_remote_rank > 0:
-                metadata_bytes = res.get(str(p_remote_rank),
-                                         {}).get("0",
-                                                 {}).get("agent_metadata", None)
+            # With filtered response from new endpoint, we get: {dp_rank: {tp_rank: metadata}}
+            # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, extract directly
+            if "0" in res and str(self.tp_rank) in res["0"]:
+                tp_data = res["0"][str(self.tp_rank)]
+                metadata_bytes = tp_data.get("agent_metadata", None)
+                p_remote_rank = self.tp_rank  # Use current tp_rank for filtered response
+            else:
+                # Fallback to unfiltered endpoint for heterogeneous TP cases
+                url_fallback = build_uri("http", host, port, path="get_kv_connector_metadata")
+                logger.debug("Using fallback unfiltered endpoint: %s", url_fallback)
+                req = Request(url_fallback)
+                with urlopen(req, timeout=5.0) as response:
+                    response_data = response.read().decode('utf-8')
+                    res = json.loads(response_data)
+
+                dp_data = res.get("0", {})
+                remote_tp_size = len(dp_data.keys()) if dp_data else 1
+
+                # Handle heterogeneous TP mapping
+                tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
+                p_remote_rank = self.tp_rank // tp_ratio
+                tp_data = dp_data.get(str(p_remote_rank), {})
+                metadata_bytes = tp_data.get("agent_metadata", None)
 
             if metadata_bytes is not None:
                 # Reconstruct NixlAgentMetadata from JSON response

@@ -962,6 +972,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         Start loading by triggering non-blocking nixl_xfer.
         We check for these trnxs to complete in each step().
         """
+
         for req_id, meta in metadata.requests.items():
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "

vllm/entrypoints/openai/api_server.py

Lines changed: 22 additions & 1 deletion
@@ -869,8 +869,29 @@ async def show_server_info(raw_request: Request):
     return JSONResponse(content=server_info)
 
 @router.get("/get_kv_connector_metadata")
-async def get_kv_connector_metadata(raw_request: Request):
+@router.get("/get_kv_connector_metadata/{dp_rank}")
+@router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}")
+async def get_kv_connector_metadata(raw_request: Request, dp_rank: int = None, tp_rank: int = None):
     kv_connector_metadata = raw_request.app.state.vllm_config.cache_config.transfer_handshake_metadata
+
+    if kv_connector_metadata is None:
+        return JSONResponse(content=None)
+
+    # Filter by dp_rank if specified
+    if dp_rank is not None:
+        if dp_rank not in kv_connector_metadata:
+            return JSONResponse(content={})
+        dp_data = kv_connector_metadata[dp_rank]
+
+        # Filter by tp_rank if also specified
+        if tp_rank is not None:
+            if tp_rank not in dp_data:
+                return JSONResponse(content={})
+            return JSONResponse(content={dp_rank: {tp_rank: dp_data[tp_rank]}})
+        else:
+            return JSONResponse(content={dp_rank: dp_data})
+
+    # Return all metadata if no filtering
     return JSONResponse(content=kv_connector_metadata)
 
 @router.post("/reset_prefix_cache")
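A quick way to exercise the new filtered route is a plain urllib call like the sketch below; the host, port, and rank values are placeholders, and the expected response shape ({dp_rank: {tp_rank: metadata}}) follows the handler above. Note that JSON serialization turns the integer rank keys into strings, which is why the connector indexes the response with str keys.

# Sketch: query the filtered metadata endpoint (host/port/ranks are placeholders).
import json
from urllib.request import urlopen

host, port = "localhost", 8000   # assumed serving address
dp_rank, tp_rank = 0, 0          # ranks to filter on

url = f"http://{host}:{port}/get_kv_connector_metadata/{dp_rank}/{tp_rank}"
with urlopen(url, timeout=5.0) as response:
    res = json.loads(response.read().decode("utf-8"))

# When the ranks exist, res looks like {"0": {"0": {...metadata...}}}.
print(res)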

vllm/v1/core/sched/scheduler.py

Lines changed: 6 additions & 5 deletions
@@ -719,11 +719,7 @@ def update_from_output(
         for request in self.running:
             req_id = request.request_id
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
-            if num_tokens_scheduled == 0:
-                # The request was not scheduled in this step.
-                new_running.append(request)
-                continue
-
+
             # Check if this request is pending handshake and needs to reschedule
             if (pending_handshake_req_ids
                     and req_id in pending_handshake_req_ids):

@@ -734,6 +730,11 @@ def update_from_output(
                 num_tokens_to_reschedule -= request.num_computed_tokens
                 new_running.append(request)
                 continue
+
+            if num_tokens_scheduled == 0:
+                # The request was not scheduled in this step.
+                new_running.append(request)
+                continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
             generated_token_ids = sampled_token_ids[req_index]

vllm/v1/engine/core.py

Lines changed: 5 additions & 1 deletion
@@ -458,7 +458,11 @@ def _perform_handshake(
         content = {}
         for worker_dict in self.transfer_handshake_metadata:
             if worker_dict is not None:
-                content.update(worker_dict)
+                # Deep merge the nested dictionaries instead of overwriting
+                for dp_rank, tp_dict in worker_dict.items():
+                    if dp_rank not in content:
+                        content[dp_rank] = {}
+                    content[dp_rank].update(tp_dict)
         handshake_message["transfer_handshake_metadata"] = content
 
         handshake_socket.send(msgspec.msgpack.encode(handshake_message))
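Why the deep merge matters: with the flipped layout every worker reports under the same outer dp_rank key, so a plain dict.update() keeps only the last worker's tp_rank entry. A self-contained sketch with made-up payloads:

# Sketch: shallow update vs. deep merge of {dp_rank: {tp_rank: ...}} dicts.
worker_dicts = [
    {0: {0: "meta-tp0"}},   # worker at dp_rank 0, tp_rank 0 (illustrative)
    {0: {1: "meta-tp1"}},   # worker at dp_rank 0, tp_rank 1 (illustrative)
]

shallow = {}
for d in worker_dicts:
    shallow.update(d)
# shallow == {0: {1: "meta-tp1"}}  -- the tp_rank 0 entry was overwritten

content = {}
for d in worker_dicts:
    for dp_rank, tp_dict in d.items():
        content.setdefault(dp_rank, {}).update(tp_dict)
# content == {0: {0: "meta-tp0", 1: "meta-tp1"}}  -- both ranks preserved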

vllm/v1/utils.py

Lines changed: 4 additions & 4 deletions
@@ -543,10 +543,10 @@ def wait_for_engine_startup(
                          eng_index, txfer_metadata)
             if cache_config.transfer_handshake_metadata is None:
                 cache_config.transfer_handshake_metadata = defaultdict(dict)
-            for tp_rank, dp_dict in txfer_metadata.items():
-                for dp_rank, metadata in dp_dict.items():
-                    cache_config.transfer_handshake_metadata[tp_rank][
-                        dp_rank] = metadata
+            for dp_rank, tp_dict in txfer_metadata.items():
+                for tp_rank, metadata in tp_dict.items():
+                    cache_config.transfer_handshake_metadata[dp_rank][
+                        tp_rank] = metadata
 
             start_pending[0 if local else 1] -= 1
             engine.state = CoreEngineState.READY

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -252,7 +252,7 @@ def get_kv_connector_handshake_metadata(self) -> Optional[dict]:
 
         tp_rank = get_tp_group().rank_in_group
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
-        return {tp_rank: {dp_rank: msgspec.to_builtins(metadata)}}
+        return {dp_rank: {tp_rank: msgspec.to_builtins(metadata)}}
 
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return self.model_runner.get_kv_cache_spec()
