@@ -79,15 +79,16 @@ class ReqMeta:
79
79
class NixlConnectorMetadata (KVConnectorMetadata ):
80
80
81
81
def __init__ (self ):
82
- self .requests : dict [ReqId , ReqMeta ] = {}
82
+ self .reqs_to_recv : dict [ReqId , ReqMeta ] = {}
83
+ self .reqs_to_send : dict [ReqId , float ] = {}
83
84
84
85
def add_new_req (
85
86
self ,
86
87
request_id : ReqId ,
87
88
local_block_ids : list [int ],
88
89
kv_transfer_params : dict [str , Any ],
89
90
):
90
- self .requests [request_id ] = ReqMeta (
91
+ self .reqs_to_recv [request_id ] = ReqMeta (
91
92
local_block_ids = local_block_ids ,
92
93
remote_block_ids = kv_transfer_params ["remote_block_ids" ],
93
94
remote_engine_id = kv_transfer_params ["remote_engine_id" ],
@@ -194,10 +195,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
194
195
vllm_config .parallel_config .tensor_parallel_size )
195
196
logger .info ("Initializing NIXL Scheduler %s" , engine_id )
196
197
197
- # Requests that need to start recv.
198
+ # Requests that need to start recv/send .
198
199
# New requests are added by update_state_after_alloc in
199
200
# the scheduler. Used to make metadata passed to Worker.
200
201
self ._reqs_need_recv : dict [ReqId , tuple [Request , list [int ]]] = {}
202
+ # Reqs to send and their expiration time
203
+ self ._reqs_need_send : dict [ReqId , float ] = {}
201
204
202
205
def get_num_new_matched_tokens (
203
206
self , request : "Request" ,
@@ -284,6 +287,9 @@ def build_connector_meta(
284
287
# Clear the list once workers start the transfers
285
288
self ._reqs_need_recv .clear ()
286
289
290
+ meta .reqs_to_send = self ._reqs_need_send
291
+ self ._reqs_need_send = {}
292
+
287
293
return meta
288
294
289
295
def request_finished (
@@ -325,6 +331,11 @@ def request_finished(
325
331
# If prompt < block_size, no xfer so free blocks immediately.
326
332
delay_free_blocks = len (computed_block_ids ) > 0
327
333
334
+ if delay_free_blocks :
335
+ # Prefill request on remote. It will be read from D upon completion
336
+ self ._reqs_need_send [request .request_id ] = time .perf_counter (
337
+ ) + envs .VLLM_NIXL_ABORT_REQUEST_TIMEOUT
338
+
328
339
return delay_free_blocks , dict (
329
340
do_remote_prefill = True ,
330
341
do_remote_decode = False ,
@@ -394,6 +405,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
394
405
# In progress transfers.
395
406
# [req_id -> list[handle]]
396
407
self ._recving_transfers = defaultdict [ReqId , list [Transfer ]](list )
408
+ # Track the expiration time of requests that are waiting to be sent.
409
+ self ._reqs_to_send : dict [ReqId , float ] = {}
397
410
398
411
# Complete transfer tracker. Used by the rank 0 to track finished
399
412
# transactions on ranks 1 to N-1.
@@ -826,6 +839,16 @@ def get_finished(self) -> tuple[set[str], set[str]]:
826
839
"and %s requests done recving" , self .tp_rank ,
827
840
len (done_sending ), len (done_recving ))
828
841
842
+ # Handle timeout to avoid stranding blocks on remote.
843
+ now = time .perf_counter ()
844
+ while self ._reqs_to_send :
845
+ req_id , expires = next (iter (self ._reqs_to_send .items ()))
846
+ # Dicts preserve insertion order, so the oldest requests come first and we can exit early.
847
+ if now < expires :
848
+ break
849
+ del self ._reqs_to_send [req_id ]
850
+ done_sending .add (req_id )
851
+
829
852
if self .world_size == 1 :
830
853
return done_sending , done_recving
831
854
@@ -857,7 +880,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
857
880
858
881
all_done_sending : set [str ] = set ()
859
882
for req_id in list (self ._done_sending_count .keys ()):
860
- if self ._done_sending_count [req_id ] = = self .world_size :
883
+ if self ._done_sending_count [req_id ] > = self .world_size :
861
884
del self ._done_sending_count [req_id ]
862
885
all_done_sending .add (req_id )
863
886
@@ -887,6 +910,7 @@ def _get_new_notifs(self) -> set[str]:
887
910
tp_ratio ):
888
911
notified_req_ids .add (req_id )
889
912
del self .consumer_notification_counts_by_req [req_id ]
913
+ del self ._reqs_to_send [req_id ]
890
914
return notified_req_ids
891
915
892
916
def _pop_done_transfers (
@@ -921,7 +945,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
921
945
Start loading by triggering non-blocking nixl_xfer.
922
946
We check for these transfers to complete in each step().
923
947
"""
924
- for req_id , meta in metadata .requests .items ():
948
+ for req_id , meta in metadata .reqs_to_recv .items ():
925
949
remote_engine_id = meta .remote_engine_id
926
950
logger .debug (
927
951
"start_load_kv for request %s from remote engine %s. "
@@ -943,6 +967,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
943
967
while not self ._ready_requests .empty ():
944
968
self ._read_blocks_for_req (* self ._ready_requests .get_nowait ())
945
969
970
+ # Add to requests that are waiting to be read and track expiration.
971
+ self ._reqs_to_send .update (metadata .reqs_to_send )
972
+
946
973
def _read_blocks_for_req (self , req_id : str , meta : ReqMeta ):
947
974
logger .debug (
948
975
"Remote agent %s available, calling _read_blocks for req %s" ,
0 commit comments