Review notes (plain text ahead of the first "diff --git" header; `git apply`
and `patch` skip it):
- Line breaks and indentation were rebuilt from a whitespace-collapsed copy of
  this patch. Context indentation is inferred from the Python structure; use
  `git apply --ignore-whitespace` if a hunk fails to match.
- Amended relative to the collapsed copy: docstrings/comments on the new enum,
  `start_load_kv` and `send_finsih_to_remote`; the done-callback is attached at
  submit time; `send_finsih_to_remote` guards socket setup as well as the send.
  Hunk headers were recomputed; the post-image blob id on the `index` line was
  not.

diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index 34543cc05c..6f54cb53fe 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -4,7 +4,9 @@
 import threading
 import time
 from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Optional, Tuple
 
 import llm_datadist  # type: ignore
@@ -38,6 +40,20 @@
 }
 
 
+class LLMDataDistCMgrEvent(Enum):
+    """Request kinds exchanged between connector workers over ZMQ.
+
+    The member is sent as the first element of each msgpack-encoded request
+    and selects the branch taken by ``listen_for_agent_metadata_req``.
+    """
+
+    # Payload: the sender's agent metadata. When it carries a ``cluster_id``
+    # the listener replies and calls ``add_remote_agent``.
+    ReqForMetadata = 0
+    # Payload: ``[request_id]``. The listener adds the id to ``finished_reqs``.
+    ReqForFinished = 1
+
+
 class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
     super_pod_id: str
     server_id: str
@@ -298,6 +314,8 @@ def __init__(self, vllm_config: VllmConfig):
         self.local_agent_metadata: Optional[
             LLMDataDistCMgrAgentMetadata] = None
         self.vllm_config = vllm_config
+        self.executor = ThreadPoolExecutor(1)
+        self.thread_lock = threading.Lock()
 
         self.llm_datadist_role = None
         self.llm_datadist_remote_role = None
@@ -343,17 +361,30 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
             event.set()
             while True:
                 identity, _, msg = sock.recv_multipart()
-                decode_msg = msg_decoder.decode(msg)
-                if "cluster_id" in decode_msg:
-                    decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
-                    logger.info(
-                        f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
-                    )
-                    sock.send_multipart((identity, b"", msg_to_send))
-                    self.add_remote_agent(decode_msg)
+                event_msg, decode_msg = msg_decoder.decode(msg)
+                event_msg = LLMDataDistCMgrEvent(event_msg)
+                if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
+                    if "cluster_id" in decode_msg:
+                        decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
+                        logger.info(
+                            f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
+                        )
+                        sock.send_multipart((identity, b"", msg_to_send))
+                        self.add_remote_agent(decode_msg)
+                    else:
+                        logger.warning(
+                            f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
+                        )
+                elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
+                    finished_req_id = decode_msg[0]
+                    with self.thread_lock:
+                        logger.debug(
+                            f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
+                        )
+                        self.finished_reqs.add(finished_req_id)
                 else:
-                    logger.warning(
-                        f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
+                    raise RuntimeError(
+                        f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
                     )
 
     def init_llm_datadist(self):
@@ -517,13 +548,32 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
         self.ready_event.wait()
 
     def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
+        """Queue one asynchronous KV block pull per request in *metadata*.
+
+        Every pull runs ``_read_blocks`` on ``self.executor``; this method
+        returns without waiting for the transfers to complete.
+        """
+
+        def handle_exception(future):
+            # Without this callback an exception raised by ``_read_blocks``
+            # would stay unreported inside the Future.
+            error = future.exception()
+            if error:
+                logger.error(f"KV transfer task failed: {error}")
+
         for req_id, meta in metadata.requests.items():
             logger.debug(f"Start to transmit {req_id}")
-            self._read_blocks(meta.local_block_ids,
-                              meta.remote_block_ids, meta.remote_host,
-                              int(meta.remote_port), meta.engine_id, req_id,
-                              meta.remote_tp_size)
-            self.finished_reqs.add(req_id)
+            future = self.executor.submit(
+                self._read_blocks,
+                local_block_ids=meta.local_block_ids,
+                remote_block_ids=meta.remote_block_ids,
+                remote_ip=meta.remote_host,
+                remote_port=int(meta.remote_port),
+                remote_engine_id=meta.engine_id,
+                request_id=req_id,
+                remote_tp_size=meta.remote_tp_size,
+            )
+            future.add_done_callback(handle_exception)
 
     def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
         assert self.local_agent_metadata is not None
@@ -673,7 +723,8 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
         url = f"tcp://{host}:{port}"
         logger.debug(f"Querying metadata from url: {url}")
         msg_encoder = msgspec.msgpack.Encoder()
-        msg_send = msg_encoder.encode(self.local_agent_metadata)
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
         with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
             logger.info("Try request remote metadata from socket......")
             sock.send(msg_send)
@@ -685,6 +736,32 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
             cluster_id = self.add_remote_agent(metadata)
         return cluster_id
 
+    def send_finsih_to_remote(self, host: str, port: int,
+                              request_id: str) -> None:
+        """Notify the worker at ``host:port`` that *request_id* has finished.
+
+        Encodes ``[ReqForFinished, [request_id]]`` with msgpack and sends it
+        through a ``zmq_ctx(zmq.REQ, url)`` socket without waiting for a
+        reply. Failures are logged rather than raised, so the caller can
+        still record the request as finished locally.
+        """
+        url = f"tcp://{host}:{port}"
+        logger.debug(f"Sending finished to remote: {url}")
+        msg_encoder = msgspec.msgpack.Encoder()
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
+        # The try block also covers socket setup, so a connection error is
+        # handled the same way as a failed send.
+        try:
+            with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
+                sock.send(msg_send)
+                logger.debug(
+                    f"Request id {request_id} finished message send to remote {url}"
+                )
+        except Exception as e:
+            logger.error(
+                f"Failed to send reqest_id {request_id} to prefill: {e}")
+
     def _read_blocks(
         self,
         local_block_ids: list[int],
@@ -750,14 +827,19 @@ def _read_blocks(
             raise RuntimeError(
                 "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
             )
+        self.send_finsih_to_remote(remote_ip, remote_port + tp_offset,
+                                   request_id)
+        with self.thread_lock:
+            self.finished_reqs.add(request_id)
 
     def get_finished(
         self, finished_req_ids: set[str]
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """Get the finished recving and sending requuests."""
         import copy
-        req_ids_to_ret = copy.deepcopy(self.finished_reqs)
-        self.finished_reqs.clear()
+        with self.thread_lock:
+            req_ids_to_ret = copy.deepcopy(self.finished_reqs)
+            self.finished_reqs.clear()
         if self.llm_datadist_role == LLMRole.PROMPT:
             return req_ids_to_ret, None
         else: