@@ -490,23 +490,14 @@ def _nixl_handshake(self, host: str, port: int):
490
490
start_time = time .perf_counter ()
491
491
logger .debug ("Starting NIXL handshake with %s:%s" , host , port )
492
492
493
- # Use the new endpoint scheme to filter by dp_rank and tp_rank
494
- # Default to dp_rank 0 and use current tp_rank for optimal filtering
495
- url = build_uri ("http" ,
496
- host ,
497
- port ,
498
- path = f"get_kv_connector_metadata/0/{ self .tp_rank } " )
499
- logger .debug ("Querying metadata on path: %s" , url )
493
+ url = build_uri ("http" , host , port , path = "get_kv_connector_metadata" )
500
494
501
495
try :
502
496
req = URLRequest (url )
503
- logger .debug ("About to send HTTP request to %s" , url )
504
497
with urlopen (req ,
505
498
timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
506
- logger .debug ("Received HTTP response from %s" , url )
507
499
response_data = response .read ().decode ('utf-8' )
508
500
res = json .loads (response_data )
509
- logger .debug ("NIXL handshake response: %s" , res )
510
501
except (URLError , HTTPError ) as e :
511
502
logger .error ("Failed to fetch metadata from %s: %s" , url , e )
512
503
raise
@@ -516,65 +507,50 @@ def _nixl_handshake(self, host: str, port: int):
516
507
"Remote server returned None metadata, skipping handshake" )
517
508
raise RuntimeError ("Remote server returned None metadata" )
518
509
519
- # With filtered response from new endpoint, we get:
520
- # {dp_rank: {tp_rank: metadata}}
521
- # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank,
522
- # extract directly.
523
- if "0" in res and str (self .tp_rank ) in res ["0" ]:
524
- tp_data = res ["0" ][str (self .tp_rank )]
525
- metadata_bytes = tp_data .get ("agent_metadata" , None )
526
- # use current tp_rank for filtered response
527
- p_remote_rank = self .tp_rank
528
- else :
529
- # Fallback to unfiltered endpoint for heterogeneous TP cases
530
- url_fallback = build_uri ("http" ,
531
- host ,
532
- port ,
533
- path = "get_kv_connector_metadata" )
534
- logger .debug ("Using fallback unfiltered endpoint: %s" ,
535
- url_fallback )
536
- req = URLRequest (url_fallback )
537
- with urlopen (req ,
538
- timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
539
- response_data = response .read ().decode ('utf-8' )
540
- res = json .loads (response_data )
541
-
542
- dp_data = res .get ("0" , {})
543
- remote_tp_size = len (dp_data .keys ()) if dp_data else 1
544
-
545
- # Handle heterogeneous TP mapping
546
- tp_ratio = self ._tp_size [self .engine_id ] // remote_tp_size
547
- p_remote_rank = self .tp_rank // tp_ratio
548
- tp_data = dp_data .get (str (p_remote_rank ), {})
549
- metadata_bytes = tp_data .get ("agent_metadata" , None )
550
-
551
- if metadata_bytes is not None :
552
- # Reconstruct NixlAgentMetadata from JSON response
553
- # agent_metadata is base64-encoded binary data, not msgpack
554
- tp_data .pop ("agent_metadata" , None )
555
- metadata = NixlAgentMetadata (
556
- agent_metadata = base64 .b64decode (metadata_bytes ), ** tp_data )
557
-
558
- # Register Remote agent.
559
- logger .debug ("About to register remote agent for engine %s" ,
560
- metadata .engine_id )
561
- pre_register = time .perf_counter ()
562
- self .add_remote_agent (metadata , remote_tp_rank = p_remote_rank )
563
- agent_time = time .perf_counter ()
564
- logger .debug ("Finished registering remote agent for engine %s" ,
565
- metadata .engine_id )
566
-
567
- logger .debug ("NIXL handshake: get metadata took: %s" ,
568
- pre_register - start_time )
569
- logger .debug ("NIXL handshake: add agent took: %s" ,
570
- agent_time - pre_register )
571
- else :
572
- # If metadata_bytes is None, it means the remote agent
573
- # is not using NIXL, so we can skip the handshake.
574
- logger .warning (
575
- "Received None metadata from %s:%s, skipping NIXL handshake" ,
576
- host , port )
577
- raise RuntimeError ("Remote server does not support NIXL" )
510
+ # Get dp_rank 0 data (standard for disaggregated prefill-decode)
511
+ dp_data = res .get ("0" , {})
512
+ if not dp_data :
513
+ raise RuntimeError ("No metadata found for dp_rank 0" )
514
+
515
+ remote_tp_size = len (dp_data .keys ())
516
+ rank0_data = dp_data .get ("0" , {})
517
+ if not rank0_data :
518
+ raise RuntimeError ("No metadata found for remote rank 0" )
519
+
520
+ metadata_bytes = rank0_data .get ("agent_metadata" , None )
521
+ if metadata_bytes is None :
522
+ raise RuntimeError ("No agent metadata found for remote rank 0" )
523
+
524
+ rank0_data_copy = rank0_data .copy ()
525
+ rank0_data_copy .pop ("agent_metadata" , None )
526
+ rank0_metadata = NixlAgentMetadata (
527
+ agent_metadata = base64 .b64decode (metadata_bytes ), ** rank0_data_copy )
528
+
529
+ pre_register = time .perf_counter ()
530
+ self .add_remote_agent (rank0_metadata , remote_tp_rank = 0 )
531
+ tp_ratio = self ._tp_size [self .engine_id ] // remote_tp_size
532
+ p_remote_rank = self .tp_rank // tp_ratio
533
+
534
+ if p_remote_rank > 0 :
535
+ p_rank_data = dp_data .get (str (p_remote_rank ), {})
536
+ if p_rank_data :
537
+ p_metadata_bytes = p_rank_data .get ("agent_metadata" , None )
538
+ if p_metadata_bytes :
539
+ p_rank_data_copy = p_rank_data .copy ()
540
+ p_rank_data_copy .pop ("agent_metadata" , None )
541
+ p_metadata = NixlAgentMetadata (
542
+ agent_metadata = base64 .b64decode (p_metadata_bytes ),
543
+ ** p_rank_data_copy )
544
+ self .add_remote_agent (p_metadata , remote_tp_rank = p_remote_rank )
545
+
546
+ agent_time = time .perf_counter ()
547
+
548
+ logger .debug ("Finished registering remote agent for engine %s" ,
549
+ rank0_metadata .engine_id )
550
+ logger .debug ("NIXL handshake: get metadata took: %s" ,
551
+ pre_register - start_time )
552
+ logger .debug ("NIXL handshake: add agent took: %s" ,
553
+ agent_time - pre_register )
578
554
579
555
logger .debug ("NIXL handshake method completed for %s:%s" , host , port )
580
556
0 commit comments