# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import base64
import json
import math
import threading
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from urllib.error import HTTPError, URLError
from urllib.parse import urljoin
from urllib.request import Request, urlopen

import torch

from vllm import envs
@@ -379,6 +383,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
379
383
# [req_id -> list[handle]]
380
384
self ._recving_transfers = defaultdict [str , list [Transfer ]](list )
381
385
386
+
387
+ # Pending requests waiting for handshake completion
388
+ # [engine_id -> list[(req_id, meta)]]
389
+ self ._pending_requests : dict [str , list [tuple [str , ReqMeta ]]] = {}
390
+
391
+
382
392
# Complete transfer tracker. Used by the rank 0 to track finished
383
393
# transactions on ranks 1 to N-1.
384
394
# [req_id -> count]
@@ -387,8 +397,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
387
397
self ._done_sending_count : defaultdict [str ,
388
398
int ] = defaultdict (lambda : 0 )
389
399
390
- # Background thread for establishing new connections.
391
- self ._nixl_handshake_listener_t : Optional [threading .Thread ] = None
400
+ # Background handshake threads for remote engines
401
+ self ._handshake_threads : dict [str , threading .Thread ] = {}
402
+
403
+ # Thread results for handshake completion tracking
404
+ self ._handshake_results : dict [str , bool ] = {}
392
405
393
406
self .vllm_config = vllm_config
394
407
self .block_size = vllm_config .cache_config .block_size
@@ -417,19 +430,45 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
417
430
# finish reading before safely freeing the blocks.
418
431
self .consumer_notification_counts_by_req = defaultdict [str , int ](int )
419
432
420
- async def _nixl_handshake (self , host : str , port : int ):
433
+ def _run_handshake_in_thread (self , engine_id : str , host : str , port : int ):
434
+ """Run handshake in background thread."""
435
+
436
+ def handshake_worker ():
437
+ logger .debug ("Starting handshake worker for engine %s" , engine_id )
438
+ try :
439
+ self ._nixl_handshake (host , port )
440
+ self ._handshake_results [engine_id ] = True
441
+ logger .debug ("Handshake succeeded for engine %s" , engine_id )
442
+ except Exception as e :
443
+ self ._handshake_results [engine_id ] = False
444
+ logger .warning ("Handshake failed for engine %s: %s" , engine_id , e )
445
+ finally :
446
+ logger .debug ("Handshake worker finished for engine %s" , engine_id )
447
+
448
+ thread = threading .Thread (target = handshake_worker , daemon = True )
449
+ thread ._start_time = time .time () # track when thread started
450
+ self ._handshake_threads [engine_id ] = thread
451
+ thread .start ()
452
+ return thread
453
+
454
+ def _nixl_handshake (self , host : str , port : int ):
421
455
"""Do a NIXL handshake with a remote instance."""
422
456
423
457
start_time = time .perf_counter ()
424
458
425
- url = build_uri (host , port , path = "get_kv_connector_metadata" )
459
+ # TODO: make the scheme dynamic, and/or implement https on both sides.
460
+ url = build_uri ("http" , host , port , path = "get_kv_connector_metadata" )
426
461
logger .debug ("Querying metadata on path: %s" , url )
427
462
428
- timeout = aiohttp .ClientTimeout (total = 30.0 )
429
- async with aiohttp .ClientSession (timeout = timeout ) as session :
430
- async with session .get (url ) as response :
431
- res = await response .json ()
463
+ try :
464
+ req = Request (url )
465
+ with urlopen (req , timeout = 5.0 ) as response :
466
+ response_data = response .read ().decode ('utf-8' )
467
+ res = json .loads (response_data )
432
468
logger .debug ("NIXL handshake response: %s" , res )
469
+ except (URLError , HTTPError ) as e :
470
+ logger .error ("Failed to fetch metadata from %s: %s" , url , e )
471
+ raise
433
472
434
473
435
474
remote_tp_size = len (res .keys ())
@@ -460,13 +499,18 @@ async def _nixl_handshake(self, host: str, port: int):
460
499
)
461
500
462
501
# Register Remote agent.
463
- self .add_remote_agent (metadata , p_remote_rank )
464
- setup_agent_time = time .perf_counter ()
465
-
466
- logger .debug ("NIXL handshake: get metadata took: %s" ,
467
- time .perf_counter () - start_time )
468
- logger .debug ("NIXL handshake: add agent took: %s" ,
469
- setup_agent_time - (time .perf_counter () - start_time ))
502
+ logger .debug ("About to register remote agent for engine %s" ,
503
+ metadata .engine_id )
504
+ pre_register = time .perf_counter ()
505
+ self .add_remote_agent (metadata , remote_tp_rank = p_remote_rank )
506
+ agent_time = time .perf_counter ()
507
+ logger .debug ("Finished registering remote agent for engine %s" ,
508
+ metadata .engine_id )
509
+
510
+ logger .debug ("NIXL handshake: get metadata took: %s" ,
511
+ pre_register - start_time )
512
+ logger .debug ("NIXL handshake: add agent took: %s" ,
513
+ agent_time - pre_register )
470
514
else :
471
515
# If metadata_bytes is None, it means the remote agent
472
516
# is not using NIXL, so we can skip the handshake.
@@ -755,6 +799,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
755
799
"Rank %s, get_finished: %s requests done sending "
756
800
"and %s requests done recving" , self .tp_rank ,
757
801
len (done_sending ), len (done_recving ))
802
+
758
803
759
804
if self .world_size == 1 :
760
805
return done_sending , done_recving
@@ -843,39 +888,128 @@ def _pop_done_transfers(
843
888
xfer_state )
844
889
return done_req_ids
845
890
891
+ def _process_completed_handshakes (self ):
892
+ """Process completed handshakes and mark remote agents as ready."""
893
+
894
+ # debug: log current state
895
+ if self ._handshake_threads :
896
+ logger .debug ("Processing handshakes: %d active threads, %d pending" ,
897
+ len (self ._handshake_threads ),
898
+ sum (len (reqs ) for reqs in self ._pending_requests .values ()))
899
+
900
+ completed_engines = []
901
+ for engine_id , thread in list (self ._handshake_threads .items ()):
902
+ logger .debug ("Checking handshake thread for engine %s: alive=%s" ,
903
+ engine_id , thread .is_alive ())
904
+
905
+ # check for timeout (threads running > 30 seconds)
906
+ thread_age = time .time () - getattr (thread , '_start_time' , time .time ())
907
+ if thread .is_alive () and thread_age > 30.0 :
908
+ logger .warning ("Handshake thread for %s running %.1fs (hung?)" ,
909
+ engine_id , thread_age )
910
+
911
+ if not thread .is_alive ():
912
+ logger .debug ("Handshake completed for engine %s" , engine_id )
913
+ completed_engines .append (engine_id )
914
+
915
+ success = self ._handshake_results .get (engine_id , False )
916
+ logger .debug ("Handshake result for engine %s: success=%s" ,
917
+ engine_id , success )
918
+ if not success :
919
+ logger .warning ("Handshake failed for engine %s" , engine_id )
920
+ continue
921
+
922
+ logger .debug ("Handshake succeeded for engine %s" , engine_id )
923
+ if engine_id in self ._pending_requests :
924
+ pending_reqs = self ._pending_requests [engine_id ]
925
+ logger .debug (
926
+ "Handshake completed for %s, clearing %d pending requests "
927
+ "(will retry naturally on next start_load_kv)" ,
928
+ engine_id , len (pending_reqs ))
929
+
930
+ # clear pending requests - they'll be retried naturally
931
+ # by the event loop on the next start_load_kv() call
932
+ del self ._pending_requests [engine_id ]
933
+
934
+ for engine_id in completed_engines :
935
+ logger .debug ("Cleaning up handshake thread for engine %s" ,
936
+ engine_id )
937
+ del self ._handshake_threads [engine_id ]
938
+ if engine_id in self ._handshake_results :
939
+ del self ._handshake_results [engine_id ]
940
+
941
+ def _is_request_pending_handshake (self , req_id : str ) -> bool :
942
+ """Check if request is still pending handshake completion."""
943
+ for engine_requests in self ._pending_requests .values ():
944
+ for pending_req_id , _ in engine_requests :
945
+ if pending_req_id == req_id :
946
+ return True
947
+ return False
948
+
846
949
def start_load_kv (self , metadata : NixlConnectorMetadata ):
847
950
"""
848
951
Start loading by triggering non-blocking nixl_xfer.
849
952
We check for these trnxs to complete in each step().
850
953
"""
954
+ logger .debug ("start_load_kv called with %d requests" , len (metadata .requests ))
851
955
for req_id , meta in metadata .requests .items ():
956
+ if (req_id in self ._recving_transfers or
957
+ self ._is_request_pending_handshake (req_id )):
958
+ logger .debug (
959
+ "Request %s already being processed, skipping" , req_id )
960
+ continue
961
+
852
962
logger .debug (
853
963
"start_load_kv for request %s from remote engine %s. "
854
964
"Num local_block_ids: %s. Num remote_block_ids: %s. " , req_id ,
855
965
meta .remote_engine_id , len (meta .local_block_ids ),
856
966
len (meta .remote_block_ids ))
967
+
968
+
969
+ if meta .remote_engine_id not in self ._remote_agents :
970
+ logger .debug (
971
+ "Remote engine %s not registered for request %s, "
972
+ "starting handshake and deferring transfer" ,
973
+ meta .remote_engine_id , req_id )
974
+
975
+ if meta .remote_engine_id not in self ._handshake_threads :
976
+ logger .debug (
977
+ "Starting handshake thread for remote engine %s" ,
978
+ meta .remote_engine_id )
979
+ self ._run_handshake_in_thread (
980
+ meta .remote_engine_id , meta .remote_host ,
981
+ meta .remote_port )
982
+ else :
983
+ logger .debug (
984
+ "Handshake thread already exists for remote engine %s" ,
985
+ meta .remote_engine_id )
986
+
987
+ if meta .remote_engine_id not in self ._pending_requests :
988
+ self ._pending_requests [meta .remote_engine_id ] = []
989
+ self ._pending_requests [meta .remote_engine_id ].append (
990
+ (req_id , meta ))
991
+
992
+ logger .debug (
993
+ "Request %s marked as pending handshake for engine %s" ,
994
+ req_id , meta .remote_engine_id )
995
+ continue
996
+
997
+ logger .debug ("Remote agent available for %s, calling _read_blocks" ,
998
+ meta .remote_engine_id )
857
999
self ._read_blocks (
858
1000
request_id = req_id ,
859
1001
dst_engine_id = meta .remote_engine_id ,
860
1002
local_block_ids = meta .local_block_ids ,
861
1003
remote_block_ids = meta .remote_block_ids ,
862
- remote_host = meta .remote_host ,
863
- remote_port = meta .remote_port ,
864
1004
)
865
1005
866
1006
def _read_blocks (
867
1007
self ,
868
1008
local_block_ids : list [int ],
869
1009
remote_block_ids : list [int ],
870
- remote_host : str ,
871
- remote_port : int ,
872
1010
dst_engine_id : str ,
873
1011
request_id : str ,
874
1012
):
875
- # NOTE(rob): this takes ~2s. We need to get this off the hotpath.
876
- if dst_engine_id not in self ._remote_agents :
877
- asyncio .run (self ._nixl_handshake (remote_host , remote_port ))
878
-
879
1013
# NOTE(rob): having the staging blocks be on the READER side is
880
1014
# not going to work well (since we will have to call rearrange tensors).
881
1015
# after we detect the txn is complete (which means we cannot make the
0 commit comments