From ea2f5f60c6e987efbed4f874b2a5ed7fa37de26e Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 3 Jul 2025 15:53:27 +0800 Subject: [PATCH 1/4] bring asynchronous kv cache pulling philosophy into vllm-ascend Signed-off-by: ganyi --- .../llmdatadist_c_mgr_connector.py | 88 +++++++++++++++---- 1 file changed, 69 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 34543cc05c..477866089f 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -6,11 +6,13 @@ from collections.abc import Iterator from dataclasses import dataclass from typing import Any, Optional, Tuple +from enum import Enum import llm_datadist # type: ignore import msgspec import torch import zmq +from concurrent.futures import ThreadPoolExecutor from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, LLMException, LLMRole) from vllm.config import KVTransferConfig, VllmConfig @@ -37,6 +39,10 @@ torch.int32: llm_datadist.DataType.DT_INT32 } +class LLMDataDistCMgrEvent(Enum): + ReqForMetadata = 0 + ReqForFinished = 1 + class LLMDataDistCMgrAgentMetadata(msgspec.Struct): super_pod_id: str @@ -298,6 +304,9 @@ def __init__(self, vllm_config: VllmConfig): self.local_agent_metadata: Optional[ LLMDataDistCMgrAgentMetadata] = None self.vllm_config = vllm_config + self.executor = ThreadPoolExecutor(1) + self.thread_lock = threading.Lock() + self.llm_datadist_role = None self.llm_datadist_remote_role = None @@ -343,18 +352,27 @@ def listen_for_agent_metadata_req(self, event: threading.Event): event.set() while True: identity, _, msg = sock.recv_multipart() - decode_msg = msg_decoder.decode(msg) - if "cluster_id" in decode_msg: - decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) - logger.info( - f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" - ) - 
sock.send_multipart((identity, b"", msg_to_send)) - self.add_remote_agent(decode_msg) + event_msg, decode_msg = msg_decoder.decode(msg) + if event_msg == LLMDataDistCMgrEvent.ReqForMetadata: + if "cluster_id" in decode_msg: + decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) + logger.info( + f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}" + ) + sock.send_multipart((identity, b"", msg_to_send)) + self.add_remote_agent(decode_msg) + else: + logger.warning( + f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" + ) + elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: + finished_req_id = decode_msg[0] + with self.thread_lock: + logger.debug(f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished") + self.finished_reqs.add(finished_req_id) else: - logger.warning( - f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}" - ) + raise RuntimeError(f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !") + def init_llm_datadist(self): assert self.local_agent_metadata is not None @@ -517,13 +535,27 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]): self.ready_event.wait() def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata): + futures = [] for req_id, meta in metadata.requests.items(): logger.debug(f"Start to transmit {req_id}") - self._read_blocks(meta.local_block_ids, - meta.remote_block_ids, meta.remote_host, - int(meta.remote_port), meta.engine_id, req_id, - meta.remote_tp_size) - self.finished_reqs.add(req_id) + future = self.executor.submit( + self._read_blocks, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_ip=meta.remote_host, + remote_port=int(meta.remote_port), + remote_engine_id=meta.engine_id, + request_id=req_id, + remote_tp_size=meta.remote_tp_size, + ) + futures.append(future) + + def handle_exception(future): + if 
future.exception(): + logger.error(f"KV transfer task failed: {future.exception()}") + + for future in futures: + future.add_done_callback(handle_exception) def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: assert self.local_agent_metadata is not None @@ -673,7 +705,7 @@ def connect_to_remote_agent(self, host: str, port: int) -> int: url = f"tcp://{host}:{port}" logger.debug(f"Querying metadata from url: {url}") msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode(self.local_agent_metadata) + msg_send = msg_encoder.encode([LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata]) with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] logger.info("Try request remote metadata from socket......") sock.send(msg_send) @@ -685,6 +717,20 @@ def connect_to_remote_agent(self, host: str, port: int) -> int: cluster_id = self.add_remote_agent(metadata) return cluster_id + + def send_finsih_to_remote(self, host: str, port: int, request_id): + url = f"tcp://{host}:{port}" + logger.debug(f"Sending finished to remote: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode([LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) + with zmq_ctx(zmq.REQ, url) as sock: + try: + sock.send(msg_send) + logger.debug(f"Request id {request_id} finished message send to remote {url}") + except Exception as e: + logger.error(f"Failed to send reqest_id {request_id} to prefill: {e}") + + def _read_blocks( self, local_block_ids: list[int], @@ -750,14 +796,18 @@ def _read_blocks( raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) + self.send_finsih_to_remote(remote_ip, remote_port + tp_offset, request_id) + with self.thread_lock: + self.finished_reqs.add(request_id) def get_finished( self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and 
sending requuests.""" import copy - req_ids_to_ret = copy.deepcopy(self.finished_reqs) - self.finished_reqs.clear() + with self.thread_lock: + req_ids_to_ret = copy.deepcopy(self.finished_reqs) + self.finished_reqs.clear() if self.llm_datadist_role == LLMRole.PROMPT: return req_ids_to_ret, None else: From 82230a8eb731b86354d2815f52f61c41708a7b03 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 3 Jul 2025 15:56:31 +0800 Subject: [PATCH 2/4] fix lint Signed-off-by: ganyi --- .../llmdatadist_c_mgr_connector.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 477866089f..711cab7bef 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -39,6 +39,7 @@ torch.int32: llm_datadist.DataType.DT_INT32 } + class LLMDataDistCMgrEvent(Enum): ReqForMetadata = 0 ReqForFinished = 1 @@ -307,7 +308,6 @@ def __init__(self, vllm_config: VllmConfig): self.executor = ThreadPoolExecutor(1) self.thread_lock = threading.Lock() - self.llm_datadist_role = None self.llm_datadist_remote_role = None if self.kv_transfer_config.kv_role == "kv_producer": @@ -368,11 +368,14 @@ def listen_for_agent_metadata_req(self, event: threading.Event): elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: finished_req_id = decode_msg[0] with self.thread_lock: - logger.debug(f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished") + logger.debug( + f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" + ) self.finished_reqs.add(finished_req_id) else: - raise RuntimeError(f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !") - + raise RuntimeError( + f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !" 
+ ) def init_llm_datadist(self): assert self.local_agent_metadata is not None @@ -705,7 +708,8 @@ def connect_to_remote_agent(self, host: str, port: int) -> int: url = f"tcp://{host}:{port}" logger.debug(f"Querying metadata from url: {url}") msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode([LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata]) + msg_send = msg_encoder.encode( + [LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata]) with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] logger.info("Try request remote metadata from socket......") sock.send(msg_send) @@ -717,19 +721,21 @@ def connect_to_remote_agent(self, host: str, port: int) -> int: cluster_id = self.add_remote_agent(metadata) return cluster_id - def send_finsih_to_remote(self, host: str, port: int, request_id): url = f"tcp://{host}:{port}" logger.debug(f"Sending finished to remote: {url}") msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode([LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) + msg_send = msg_encoder.encode( + [LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) with zmq_ctx(zmq.REQ, url) as sock: try: sock.send(msg_send) - logger.debug(f"Request id {request_id} finished message send to remote {url}") + logger.debug( + f"Request id {request_id} finished message send to remote {url}" + ) except Exception as e: - logger.error(f"Failed to send reqest_id {request_id} to prefill: {e}") - + logger.error( + f"Failed to send reqest_id {request_id} to prefill: {e}") def _read_blocks( self, @@ -796,7 +802,8 @@ def _read_blocks( raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) - self.send_finsih_to_remote(remote_ip, remote_port + tp_offset, request_id) + self.send_finsih_to_remote(remote_ip, remote_port + tp_offset, + request_id) with self.thread_lock: self.finished_reqs.add(request_id) From 
8c2864cdaaa2d737027f3098783a680d619b682a Mon Sep 17 00:00:00 2001 From: ganyi Date: Sat, 5 Jul 2025 20:02:54 +0800 Subject: [PATCH 3/4] async pullkv Signed-off-by: ganyi --- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 711cab7bef..0a77200c54 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -4,15 +4,15 @@ import threading import time from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Any, Optional, Tuple from enum import Enum +from typing import Any, Optional, Tuple import llm_datadist # type: ignore import msgspec import torch import zmq -from concurrent.futures import ThreadPoolExecutor from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, LLMException, LLMRole) from vllm.config import KVTransferConfig, VllmConfig @@ -353,6 +353,7 @@ def listen_for_agent_metadata_req(self, event: threading.Event): while True: identity, _, msg = sock.recv_multipart() event_msg, decode_msg = msg_decoder.decode(msg) + event_msg = LLMDataDistCMgrEvent(event_msg) if event_msg == LLMDataDistCMgrEvent.ReqForMetadata: if "cluster_id" in decode_msg: decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg) From 77aa4ddbfb533461f503cd9e6ee1480c3f9b1641 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 7 Jul 2025 16:13:27 +0800 Subject: [PATCH 4/4] fix mypy Signed-off-by: ganyi --- vllm_ascend/distributed/llmdatadist_c_mgr_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 0a77200c54..6f54cb53fe 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py 
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -728,7 +728,7 @@ def send_finsih_to_remote(self, host: str, port: int, request_id): msg_encoder = msgspec.msgpack.Encoder() msg_send = msg_encoder.encode( [LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) - with zmq_ctx(zmq.REQ, url) as sock: + with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] try: sock.send(msg_send) logger.debug(