4
4
import threading
5
5
import time
6
6
from collections .abc import Iterator
7
+ from concurrent .futures import ThreadPoolExecutor
7
8
from dataclasses import dataclass
9
+ from enum import Enum
8
10
from typing import Any , Optional , Tuple
9
11
10
12
import llm_datadist # type: ignore
38
40
}
39
41
40
42
43
class LLMDataDistCMgrEvent(Enum):
    """Event tag sent as the first element of an encoded agent message.

    Senders encode a two-element list ``[event, payload]``; the receiver
    rebuilds the member from the decoded integer value to pick a handler.
    """

    # Payload is the sender's agent metadata.
    ReqForMetadata = 0
    # Payload is a one-element list holding the finished request id.
    ReqForFinished = 1
46
+
47
+
41
48
class LLMDataDistCMgrAgentMetadata (msgspec .Struct ):
42
49
super_pod_id : str
43
50
server_id : str
@@ -298,6 +305,8 @@ def __init__(self, vllm_config: VllmConfig):
298
305
self .local_agent_metadata : Optional [
299
306
LLMDataDistCMgrAgentMetadata ] = None
300
307
self .vllm_config = vllm_config
308
+ self .executor = ThreadPoolExecutor (1 )
309
+ self .thread_lock = threading .Lock ()
301
310
302
311
self .llm_datadist_role = None
303
312
self .llm_datadist_remote_role = None
@@ -343,17 +352,30 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
343
352
event .set ()
344
353
while True :
345
354
identity , _ , msg = sock .recv_multipart ()
346
- decode_msg = msg_decoder .decode (msg )
347
- if "cluster_id" in decode_msg :
348
- decode_msg = LLMDataDistCMgrAgentMetadata (** decode_msg )
349
- logger .info (
350
- f"LLMDataDistCMgrConnectorWorker: Receive message from cluster { decode_msg .cluster_id } "
351
- )
352
- sock .send_multipart ((identity , b"" , msg_to_send ))
353
- self .add_remote_agent (decode_msg )
355
+ event_msg , decode_msg = msg_decoder .decode (msg )
356
+ event_msg = LLMDataDistCMgrEvent (event_msg )
357
+ if event_msg == LLMDataDistCMgrEvent .ReqForMetadata :
358
+ if "cluster_id" in decode_msg :
359
+ decode_msg = LLMDataDistCMgrAgentMetadata (** decode_msg )
360
+ logger .info (
361
+ f"LLMDataDistCMgrConnectorWorker: Receive message from cluster { decode_msg .cluster_id } "
362
+ )
363
+ sock .send_multipart ((identity , b"" , msg_to_send ))
364
+ self .add_remote_agent (decode_msg )
365
+ else :
366
+ logger .warning (
367
+ f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data { decode_msg } "
368
+ )
369
+ elif event_msg == LLMDataDistCMgrEvent .ReqForFinished :
370
+ finished_req_id = decode_msg [0 ]
371
+ with self .thread_lock :
372
+ logger .debug (
373
+ f"LLMDataDistCMgrConnectorWorker: Receiving request { finished_req_id } finished"
374
+ )
375
+ self .finished_reqs .add (finished_req_id )
354
376
else :
355
- logger . warning (
356
- f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data { decode_msg } "
377
+ raise RuntimeError (
378
+ f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event { event_msg } from remote ! "
357
379
)
358
380
359
381
def init_llm_datadist (self ):
@@ -517,13 +539,27 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
517
539
self .ready_event .wait ()
518
540
519
541
def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
    """Submit one asynchronous ``_read_blocks`` call per request in ``metadata``.

    Every entry of ``metadata.requests`` is handed to ``self.executor``;
    this method returns without waiting for the transfers to complete.
    A failed or cancelled transfer is only logged, never raised to the
    caller.

    Args:
        metadata: Connector metadata whose ``requests`` mapping goes from
            request id to the block ids and remote endpoint of that request.
    """

    def handle_exception(future):
        # Future.exception() raises CancelledError on a cancelled future,
        # so cancellation has to be ruled out before asking for the error.
        if future.cancelled():
            logger.warning("KV transfer task was cancelled")
            return
        exc = future.exception()
        if exc is not None:
            logger.error(f"KV transfer task failed: { exc } ")

    for req_id, meta in metadata.requests.items():
        logger.debug(f"Start to transmit { req_id } ")
        future = self.executor.submit(
            self._read_blocks,
            local_block_ids=meta.local_block_ids,
            remote_block_ids=meta.remote_block_ids,
            remote_ip=meta.remote_host,
            remote_port=int(meta.remote_port),
            remote_engine_id=meta.engine_id,
            request_id=req_id,
            remote_tp_size=meta.remote_tp_size,
        )
        # Attach the callback right away; if the future is already done the
        # callback simply runs immediately in this thread.
        future.add_done_callback(handle_exception)
527
563
528
564
def add_remote_agent (self , metadata : LLMDataDistCMgrAgentMetadata ) -> int :
529
565
assert self .local_agent_metadata is not None
@@ -673,7 +709,8 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
673
709
url = f"tcp://{ host } :{ port } "
674
710
logger .debug (f"Querying metadata from url: { url } " )
675
711
msg_encoder = msgspec .msgpack .Encoder ()
676
- msg_send = msg_encoder .encode (self .local_agent_metadata )
712
+ msg_send = msg_encoder .encode (
713
+ [LLMDataDistCMgrEvent .ReqForMetadata , self .local_agent_metadata ])
677
714
with zmq_ctx (zmq .REQ , url ) as sock : # type: ignore[attr-defined]
678
715
logger .info ("Try request remote metadata from socket......" )
679
716
sock .send (msg_send )
@@ -685,6 +722,22 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
685
722
cluster_id = self .add_remote_agent (metadata )
686
723
return cluster_id
687
724
725
def send_finsih_to_remote(self, host: str, port: int, request_id):
    """Tell the agent at ``host:port`` that ``request_id`` has finished.

    The message is the msgpack encoding of
    ``[LLMDataDistCMgrEvent.ReqForFinished, [request_id]]``, sent over a
    ZeroMQ REQ socket. Delivery is best effort: a send failure is logged
    and swallowed rather than raised.

    Args:
        host: Host name or IP address of the remote agent.
        port: TCP port of the remote agent.
        request_id: Identifier of the finished request.
    """
    # The endpoint must not contain whitespace: "tcp://host :port " is not
    # a valid ZeroMQ TCP address.
    url = f"tcp://{host}:{port}"
    logger.debug(f"Sending finished to remote: { url } ")
    msg_send = msgspec.msgpack.Encoder().encode(
        [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
    # NOTE(review): no reply is awaited on this REQ socket; presumably
    # zmq_ctx keeps it open long enough for delivery -- confirm.
    with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
        try:
            sock.send(msg_send)
            logger.debug(
                f"Request id { request_id } finished message send to remote { url } "
            )
        except Exception as e:
            logger.error(
                f"Failed to send reqest_id { request_id } to prefill: { e } ")
740
+
688
741
def _read_blocks (
689
742
self ,
690
743
local_block_ids : list [int ],
@@ -750,14 +803,19 @@ def _read_blocks(
750
803
raise RuntimeError (
751
804
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
752
805
)
806
+ self .send_finsih_to_remote (remote_ip , remote_port + tp_offset ,
807
+ request_id )
808
+ with self .thread_lock :
809
+ self .finished_reqs .add (request_id )
753
810
754
811
def get_finished (
755
812
self , finished_req_ids : set [str ]
756
813
) -> tuple [Optional [set [str ]], Optional [set [str ]]]:
757
814
"""Get the finished recving and sending requuests."""
758
815
import copy
759
- req_ids_to_ret = copy .deepcopy (self .finished_reqs )
760
- self .finished_reqs .clear ()
816
+ with self .thread_lock :
817
+ req_ids_to_ret = copy .deepcopy (self .finished_reqs )
818
+ self .finished_reqs .clear ()
761
819
if self .llm_datadist_role == LLMRole .PROMPT :
762
820
return req_ids_to_ret , None
763
821
else :
0 commit comments