12
12
from dataclasses import dataclass
13
13
from typing import TYPE_CHECKING , Any , Optional
14
14
from urllib .error import HTTPError , URLError
15
- from urllib .request import Request , urlopen
15
+ from urllib .request import Request as URLRequest
16
+ from urllib .request import urlopen
16
17
17
18
import torch
18
19
@@ -398,7 +399,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
398
399
399
400
# Background handshake threads for remote engines
400
401
self ._executor = ThreadPoolExecutor (
401
- max_workers = 4 , thread_name_prefix = "nixl-handshake" )
402
+ max_workers = 1 , thread_name_prefix = "nixl-handshake" )
402
403
# Thread results for handshake completion tracking
403
404
self ._handshake_futures : dict [str , Future ] = {}
404
405
self ._pending_requests : dict [str , list [tuple [str , ReqMeta ]]] = {}
@@ -491,13 +492,17 @@ def _nixl_handshake(self, host: str, port: int):
491
492
492
493
# Use the new endpoint scheme to filter by dp_rank and tp_rank
493
494
# Default to dp_rank 0 and use current tp_rank for optimal filtering
494
- url = build_uri ("http" , host , port , path = f"get_kv_connector_metadata/0/{ self .tp_rank } " )
495
+ url = build_uri ("http" ,
496
+ host ,
497
+ port ,
498
+ path = f"get_kv_connector_metadata/0/{ self .tp_rank } " )
495
499
logger .debug ("Querying metadata on path: %s" , url )
496
500
497
501
try :
498
- req = Request (url )
502
+ req = URLRequest (url )
499
503
logger .debug ("About to send HTTP request to %s" , url )
500
- with urlopen (req , timeout = 5.0 ) as response :
504
+ with urlopen (req ,
505
+ timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
501
506
logger .debug ("Received HTTP response from %s" , url )
502
507
response_data = response .read ().decode ('utf-8' )
503
508
res = json .loads (response_data )
@@ -507,27 +512,36 @@ def _nixl_handshake(self, host: str, port: int):
507
512
raise
508
513
509
514
if res is None :
510
- logger .warning ("Remote server returned None metadata, skipping handshake" )
515
+ logger .warning (
516
+ "Remote server returned None metadata, skipping handshake" )
511
517
raise RuntimeError ("Remote server returned None metadata" )
512
518
513
- # With filtered response from new endpoint, we get: {dp_rank: {tp_rank: metadata}}
514
- # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank, extract directly
519
+ # With filtered response from new endpoint, we get:
520
+ # {dp_rank: {tp_rank: metadata}}
521
+ # Since we filtered by dp_rank=0 and tp_rank=self.tp_rank,
522
+ # extract directly.
515
523
if "0" in res and str (self .tp_rank ) in res ["0" ]:
516
524
tp_data = res ["0" ][str (self .tp_rank )]
517
525
metadata_bytes = tp_data .get ("agent_metadata" , None )
518
- p_remote_rank = self .tp_rank # Use current tp_rank for filtered response
526
+ # use current tp_rank for filtered response
527
+ p_remote_rank = self .tp_rank
519
528
else :
520
529
# Fallback to unfiltered endpoint for heterogeneous TP cases
521
- url_fallback = build_uri ("http" , host , port , path = "get_kv_connector_metadata" )
522
- logger .debug ("Using fallback unfiltered endpoint: %s" , url_fallback )
523
- req = Request (url_fallback )
524
- with urlopen (req , timeout = 5.0 ) as response :
530
+ url_fallback = build_uri ("http" ,
531
+ host ,
532
+ port ,
533
+ path = "get_kv_connector_metadata" )
534
+ logger .debug ("Using fallback unfiltered endpoint: %s" ,
535
+ url_fallback )
536
+ req = URLRequest (url_fallback )
537
+ with urlopen (req ,
538
+ timeout = envs .VLLM_NIXL_HANDSHAKE_TIMEOUT ) as response :
525
539
response_data = response .read ().decode ('utf-8' )
526
540
res = json .loads (response_data )
527
-
541
+
528
542
dp_data = res .get ("0" , {})
529
543
remote_tp_size = len (dp_data .keys ()) if dp_data else 1
530
-
544
+
531
545
# Handle heterogeneous TP mapping
532
546
tp_ratio = self ._tp_size [self .engine_id ] // remote_tp_size
533
547
p_remote_rank = self .tp_rank // tp_ratio
@@ -952,8 +966,8 @@ def _process_ready_requests(self):
952
966
while True :
953
967
try :
954
968
req_id , meta = self ._ready_requests .get_nowait ()
955
- logger .debug ("Processing ready request %s for engine %s" ,
956
- req_id , meta .remote_engine_id )
969
+ logger .debug ("Processing ready request %s for engine %s" ,
970
+ req_id , meta .remote_engine_id )
957
971
self ._read_blocks (
958
972
request_id = req_id ,
959
973
dst_engine_id = meta .remote_engine_id ,
@@ -963,7 +977,7 @@ def _process_ready_requests(self):
963
977
processed_count += 1
964
978
except queue .Empty :
965
979
break
966
-
980
+
967
981
if processed_count > 0 :
968
982
logger .debug ("Processed %d ready requests" , processed_count )
969
983
@@ -972,7 +986,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
972
986
Start loading by triggering non-blocking nixl_xfer.
973
987
We check for these trnxs to complete in each step().
974
988
"""
975
-
989
+
976
990
for req_id , meta in metadata .requests .items ():
977
991
logger .debug (
978
992
"start_load_kv for request %s from remote engine %s. "
0 commit comments