Skip to content

Commit 11fd02a

Browse files
committed
actually use handshake timeout; simplify route
Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent e9a5fd8 commit 11fd02a

File tree

4 files changed

+61
-40
lines changed

4 files changed

+61
-40
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 33 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,8 @@
1212
from dataclasses import dataclass
1313
from typing import TYPE_CHECKING, Any, Optional
1414
from urllib.error import HTTPError, URLError
15-
from urllib.request import Request, urlopen
15+
from urllib.request import Request as URLRequest
16+
from urllib.request import urlopen
1617

1718
import torch
1819

@@ -398,7 +399,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
398399

399400
# Background handshake threads for remote engines
400401
self._executor = ThreadPoolExecutor(
401-
max_workers=4, thread_name_prefix="nixl-handshake")
402+
max_workers=1, thread_name_prefix="nixl-handshake")
402403
# Thread results for handshake completion tracking
403404
self._handshake_futures: dict[str, Future] = {}
404405
self._pending_requests: dict[str, list[tuple[str, ReqMeta]]] = {}
@@ -491,13 +492,17 @@ def _nixl_handshake(self, host: str, port: int):
491492

492493
# Use the new endpoint scheme to filter by dp_rank and tp_rank
493494
# Default to dp_rank 0 and use current tp_rank for optimal filtering
494-
url = build_uri("http", host, port, path=f"get_kv_connector_metadata/0/{self.tp_rank}")
495+
url = build_uri("http",
496+
host,
497+
port,
498+
path=f"get_kv_connector_metadata/0/{self.tp_rank}")
495499
logger.debug("Querying metadata on path: %s", url)
496500

497501
try:
498-
req = Request(url)
502+
req = URLRequest(url)
499503
logger.debug("About to send HTTP request to %s", url)
500-
with urlopen(req, timeout=5.0) as response:
504+
with urlopen(req,
505+
timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
501506
logger.debug("Received HTTP response from %s", url)
502507
response_data = response.read().decode('utf-8')
503508
res = json.loads(response_data)
@@ -507,27 +512,36 @@ def _nixl_handshake(self, host: str, port: int):
507512
raise
508513

509514
if res is None:
510-
logger.warning("Remote server returned None metadata, skipping handshake")
515+
logger.warning(
516+
"Remote server returned None metadata, skipping handshake")
511517
raise RuntimeError("Remote server returned None metadata")
512518

513-
# With filtered response from new endpoint, we get: {dp_rank: {tp_rank: metadata}}
514-
# Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, extract directly
519+
# With filtered response from new endpoint, we get:
520+
# {dp_rank: {tp_rank: metadata}}
521+
# Since we filtered by dp_rank=0 and tp_rank=self.tp_rank,
522+
# extract directly.
515523
if "0" in res and str(self.tp_rank) in res["0"]:
516524
tp_data = res["0"][str(self.tp_rank)]
517525
metadata_bytes = tp_data.get("agent_metadata", None)
518-
p_remote_rank = self.tp_rank # Use current tp_rank for filtered response
526+
# use current tp_rank for filtered response
527+
p_remote_rank = self.tp_rank
519528
else:
520529
# Fallback to unfiltered endpoint for heterogeneous TP cases
521-
url_fallback = build_uri("http", host, port, path="get_kv_connector_metadata")
522-
logger.debug("Using fallback unfiltered endpoint: %s", url_fallback)
523-
req = Request(url_fallback)
524-
with urlopen(req, timeout=5.0) as response:
530+
url_fallback = build_uri("http",
531+
host,
532+
port,
533+
path="get_kv_connector_metadata")
534+
logger.debug("Using fallback unfiltered endpoint: %s",
535+
url_fallback)
536+
req = URLRequest(url_fallback)
537+
with urlopen(req,
538+
timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
525539
response_data = response.read().decode('utf-8')
526540
res = json.loads(response_data)
527-
541+
528542
dp_data = res.get("0", {})
529543
remote_tp_size = len(dp_data.keys()) if dp_data else 1
530-
544+
531545
# Handle heterogeneous TP mapping
532546
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
533547
p_remote_rank = self.tp_rank // tp_ratio
@@ -952,8 +966,8 @@ def _process_ready_requests(self):
952966
while True:
953967
try:
954968
req_id, meta = self._ready_requests.get_nowait()
955-
logger.debug("Processing ready request %s for engine %s",
956-
req_id, meta.remote_engine_id)
969+
logger.debug("Processing ready request %s for engine %s",
970+
req_id, meta.remote_engine_id)
957971
self._read_blocks(
958972
request_id=req_id,
959973
dst_engine_id=meta.remote_engine_id,
@@ -963,7 +977,7 @@ def _process_ready_requests(self):
963977
processed_count += 1
964978
except queue.Empty:
965979
break
966-
980+
967981
if processed_count > 0:
968982
logger.debug("Processed %d ready requests", processed_count)
969983

@@ -972,7 +986,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
972986
Start loading by triggering non-blocking nixl_xfer.
973987
We check for these trnxs to complete in each step().
974988
"""
975-
989+
976990
for req_id, meta in metadata.requests.items():
977991
logger.debug(
978992
"start_load_kv for request %s from remote engine %s. "

vllm/entrypoints/openai/api_server.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -871,28 +871,29 @@ async def show_server_info(raw_request: Request):
871871
@router.get("/get_kv_connector_metadata")
872872
@router.get("/get_kv_connector_metadata/{dp_rank}")
873873
@router.get("/get_kv_connector_metadata/{dp_rank}/{tp_rank}")
874-
async def get_kv_connector_metadata(raw_request: Request, dp_rank: int = None, tp_rank: int = None):
875-
kv_connector_metadata = raw_request.app.state.vllm_config.cache_config.transfer_handshake_metadata
876-
877-
if kv_connector_metadata is None:
878-
return JSONResponse(content=None)
879-
880-
# Filter by dp_rank if specified
874+
async def get_kv_connector_metadata(raw_request: Request,
875+
dp_rank: Optional[int] = None,
876+
tp_rank: Optional[int] = None):
877+
kv_meta: Optional[dict[str, dict[str, dict[str, Any]]]] = (
878+
raw_request.app.state.vllm_config.cache_config.
879+
transfer_handshake_metadata)
880+
881+
if kv_meta is None:
882+
return None
883+
881884
if dp_rank is not None:
882-
if dp_rank not in kv_connector_metadata:
883-
return JSONResponse(content={})
884-
dp_data = kv_connector_metadata[dp_rank]
885-
886-
# Filter by tp_rank if also specified
885+
if dp_rank not in kv_meta:
886+
return {}
887+
dp_data = kv_meta[dp_rank]
888+
887889
if tp_rank is not None:
888890
if tp_rank not in dp_data:
889-
return JSONResponse(content={})
890-
return JSONResponse(content={dp_rank: {tp_rank: dp_data[tp_rank]}})
891+
return {}
892+
return {dp_rank: {tp_rank: dp_data[tp_rank]}}
891893
else:
892-
return JSONResponse(content={dp_rank: dp_data})
893-
894-
# Return all metadata if no filtering
895-
return JSONResponse(content=kv_connector_metadata)
894+
return {dp_rank: dp_data}
895+
896+
return kv_meta
896897

897898
@router.post("/reset_prefix_cache")
898899
async def reset_prefix_cache(raw_request: Request):

vllm/envs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
123123
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
124124
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
125+
VLLM_NIXL_HANDSHAKE_TIMEOUT: float = 2.0
125126
VLLM_ALL2ALL_BACKEND: str = "naive"
126127
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
127128
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
@@ -840,6 +841,11 @@ def get_vllm_port() -> Optional[int]:
840841
"VLLM_NIXL_SIDE_CHANNEL_PORT":
841842
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),
842843

844+
# Timeout in seconds for NIXL HTTP handshake requests.
845+
# Default is 2 seconds
846+
"VLLM_NIXL_HANDSHAKE_TIMEOUT":
847+
lambda: float(os.getenv("VLLM_NIXL_HANDSHAKE_TIMEOUT", "2.0")),
848+
843849
# all2all backend for vllm's expert parallel communication
844850
# Available options:
845851
# - "naive": naive all2all implementation using all-reduce

vllm/v1/engine/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,12 +453,12 @@ def _perform_handshake(
453453
if hasattr(self, 'transfer_handshake_metadata'
454454
) and self.transfer_handshake_metadata:
455455
# self.transfer_handshake_metadata is list of dicts from workers
456-
# Each dict already has structure {tp_rank: {dp_rank: metadata}}
456+
# Each dict already has structure {dp_rank: {tp_rank: metadata}}
457457
# Merge all worker dicts into a single dict
458-
content = {}
458+
content: dict[str, dict[str, dict[str, Any]]] = {}
459459
for worker_dict in self.transfer_handshake_metadata:
460460
if worker_dict is not None:
461-
# Deep merge the nested dictionaries instead of overwriting
461+
# Deep merge nested dictionaries instead of overwrite
462462
for dp_rank, tp_dict in worker_dict.items():
463463
if dp_rank not in content:
464464
content[dp_rank] = {}

0 commit comments

Comments
 (0)