Skip to content

Commit 6eb01a5

Browse files
committed
revert back to working het TP logic
Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent 11fd02a commit 6eb01a5

File tree

1 file changed

+45
-69
lines changed

1 file changed

+45
-69
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 45 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -490,23 +490,14 @@ def _nixl_handshake(self, host: str, port: int):
490490
start_time = time.perf_counter()
491491
logger.debug("Starting NIXL handshake with %s:%s", host, port)
492492

493-
# Use the new endpoint scheme to filter by dp_rank and tp_rank
494-
# Default to dp_rank 0 and use current tp_rank for optimal filtering
495-
url = build_uri("http",
496-
host,
497-
port,
498-
path=f"get_kv_connector_metadata/0/{self.tp_rank}")
499-
logger.debug("Querying metadata on path: %s", url)
493+
url = build_uri("http", host, port, path="get_kv_connector_metadata")
500494

501495
try:
502496
req = URLRequest(url)
503-
logger.debug("About to send HTTP request to %s", url)
504497
with urlopen(req,
505498
timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
506-
logger.debug("Received HTTP response from %s", url)
507499
response_data = response.read().decode('utf-8')
508500
res = json.loads(response_data)
509-
logger.debug("NIXL handshake response: %s", res)
510501
except (URLError, HTTPError) as e:
511502
logger.error("Failed to fetch metadata from %s: %s", url, e)
512503
raise
@@ -516,65 +507,50 @@ def _nixl_handshake(self, host: str, port: int):
516507
"Remote server returned None metadata, skipping handshake")
517508
raise RuntimeError("Remote server returned None metadata")
518509

519-
# With filtered response from new endpoint, we get:
520-
# {dp_rank: {tp_rank: metadata}}
521-
# Since we filtered by dp_rank=0 and tp_rank=self.tp_rank,
522-
# extract directly.
523-
if "0" in res and str(self.tp_rank) in res["0"]:
524-
tp_data = res["0"][str(self.tp_rank)]
525-
metadata_bytes = tp_data.get("agent_metadata", None)
526-
# use current tp_rank for filtered response
527-
p_remote_rank = self.tp_rank
528-
else:
529-
# Fallback to unfiltered endpoint for heterogeneous TP cases
530-
url_fallback = build_uri("http",
531-
host,
532-
port,
533-
path="get_kv_connector_metadata")
534-
logger.debug("Using fallback unfiltered endpoint: %s",
535-
url_fallback)
536-
req = URLRequest(url_fallback)
537-
with urlopen(req,
538-
timeout=envs.VLLM_NIXL_HANDSHAKE_TIMEOUT) as response:
539-
response_data = response.read().decode('utf-8')
540-
res = json.loads(response_data)
541-
542-
dp_data = res.get("0", {})
543-
remote_tp_size = len(dp_data.keys()) if dp_data else 1
544-
545-
# Handle heterogeneous TP mapping
546-
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
547-
p_remote_rank = self.tp_rank // tp_ratio
548-
tp_data = dp_data.get(str(p_remote_rank), {})
549-
metadata_bytes = tp_data.get("agent_metadata", None)
550-
551-
if metadata_bytes is not None:
552-
# Reconstruct NixlAgentMetadata from JSON response
553-
# agent_metadata is base64-encoded binary data, not msgpack
554-
tp_data.pop("agent_metadata", None)
555-
metadata = NixlAgentMetadata(
556-
agent_metadata=base64.b64decode(metadata_bytes), **tp_data)
557-
558-
# Register Remote agent.
559-
logger.debug("About to register remote agent for engine %s",
560-
metadata.engine_id)
561-
pre_register = time.perf_counter()
562-
self.add_remote_agent(metadata, remote_tp_rank=p_remote_rank)
563-
agent_time = time.perf_counter()
564-
logger.debug("Finished registering remote agent for engine %s",
565-
metadata.engine_id)
566-
567-
logger.debug("NIXL handshake: get metadata took: %s",
568-
pre_register - start_time)
569-
logger.debug("NIXL handshake: add agent took: %s",
570-
agent_time - pre_register)
571-
else:
572-
# If metadata_bytes is None, it means the remote agent
573-
# is not using NIXL, so we can skip the handshake.
574-
logger.warning(
575-
"Received None metadata from %s:%s, skipping NIXL handshake",
576-
host, port)
577-
raise RuntimeError("Remote server does not support NIXL")
510+
# Get dp_rank 0 data (standard for disaggregated prefill-decode)
511+
dp_data = res.get("0", {})
512+
if not dp_data:
513+
raise RuntimeError("No metadata found for dp_rank 0")
514+
515+
remote_tp_size = len(dp_data.keys())
516+
rank0_data = dp_data.get("0", {})
517+
if not rank0_data:
518+
raise RuntimeError("No metadata found for remote rank 0")
519+
520+
metadata_bytes = rank0_data.get("agent_metadata", None)
521+
if metadata_bytes is None:
522+
raise RuntimeError("No agent metadata found for remote rank 0")
523+
524+
rank0_data_copy = rank0_data.copy()
525+
rank0_data_copy.pop("agent_metadata", None)
526+
rank0_metadata = NixlAgentMetadata(
527+
agent_metadata=base64.b64decode(metadata_bytes), **rank0_data_copy)
528+
529+
pre_register = time.perf_counter()
530+
self.add_remote_agent(rank0_metadata, remote_tp_rank=0)
531+
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
532+
p_remote_rank = self.tp_rank // tp_ratio
533+
534+
if p_remote_rank > 0:
535+
p_rank_data = dp_data.get(str(p_remote_rank), {})
536+
if p_rank_data:
537+
p_metadata_bytes = p_rank_data.get("agent_metadata", None)
538+
if p_metadata_bytes:
539+
p_rank_data_copy = p_rank_data.copy()
540+
p_rank_data_copy.pop("agent_metadata", None)
541+
p_metadata = NixlAgentMetadata(
542+
agent_metadata=base64.b64decode(p_metadata_bytes),
543+
**p_rank_data_copy)
544+
self.add_remote_agent(p_metadata, remote_tp_rank=p_remote_rank)
545+
546+
agent_time = time.perf_counter()
547+
548+
logger.debug("Finished registering remote agent for engine %s",
549+
rank0_metadata.engine_id)
550+
logger.debug("NIXL handshake: get metadata took: %s",
551+
pre_register - start_time)
552+
logger.debug("NIXL handshake: add agent took: %s",
553+
agent_time - pre_register)
578554

579555
logger.debug("NIXL handshake method completed for %s:%s", host, port)
580556

0 commit comments

Comments
 (0)