
Commit ffd1d9a

[0.9.1][Perf] Launch load kv task asynchronously with thread pool. (#1612)
### What this PR does / why we need it?

The current `LLMDataDistCMgrConnector` implementation connects and pulls KV synchronously, which can add latency to decode tasks when the consumer node keeps receiving tasks pushed from the remote producer node. omni_infer launches llmdatadist's `pull_kv` method in a separate thread, which better overlaps pulling the KV cache with the model run and outperforms the synchronous path. This PR brings the same asynchronous approach to vllm-ascend, launching the `link`, `pull_kv`, and `request_finished` tasks asynchronously as well.

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
1 parent 04e6169 commit ffd1d9a
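
The core pattern the change adopts — handing the blocking KV pull to a single-worker thread pool so it overlaps with the model run, and surfacing failures through a done-callback — can be shown in isolation. A minimal sketch; the names `pull_kv` and `log_exception` are illustrative stand-ins, not the connector's real API:

```python
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)

def pull_kv(request_id: str) -> None:
    # Stand-in for the blocking llm_datadist pull_blocks call.
    time.sleep(0.1)
    print(f"pulled KV for {request_id}")

def log_exception(future):
    # Runs once the task completes; future.exception() does not block here.
    if future.exception():
        print(f"KV transfer task failed: {future.exception()}")

future = executor.submit(pull_kv, "req-0")
future.add_done_callback(log_exception)
# The model run would proceed here while the pull overlaps in the background.
executor.shutdown(wait=True)
```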

1 file changed: +76 −18 lines changed

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 76 additions & 18 deletions
@@ -4,7 +4,9 @@
 import threading
 import time
 from collections.abc import Iterator
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Optional, Tuple
 
 import llm_datadist  # type: ignore
@@ -38,6 +40,11 @@
 }
 
 
+class LLMDataDistCMgrEvent(Enum):
+    ReqForMetadata = 0
+    ReqForFinished = 1
+
+
 class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
     super_pod_id: str
     server_id: str
@@ -298,6 +305,8 @@ def __init__(self, vllm_config: VllmConfig):
         self.local_agent_metadata: Optional[
             LLMDataDistCMgrAgentMetadata] = None
         self.vllm_config = vllm_config
+        self.executor = ThreadPoolExecutor(1)
+        self.thread_lock = threading.Lock()
 
         self.llm_datadist_role = None
         self.llm_datadist_remote_role = None
@@ -343,17 +352,30 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
         event.set()
         while True:
             identity, _, msg = sock.recv_multipart()
-            decode_msg = msg_decoder.decode(msg)
-            if "cluster_id" in decode_msg:
-                decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
-                logger.info(
-                    f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
-                )
-                sock.send_multipart((identity, b"", msg_to_send))
-                self.add_remote_agent(decode_msg)
+            event_msg, decode_msg = msg_decoder.decode(msg)
+            event_msg = LLMDataDistCMgrEvent(event_msg)
+            if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
+                if "cluster_id" in decode_msg:
+                    decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
+                    logger.info(
+                        f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
+                    )
+                    sock.send_multipart((identity, b"", msg_to_send))
+                    self.add_remote_agent(decode_msg)
+                else:
+                    logger.warning(
+                        f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
+                    )
+            elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
+                finished_req_id = decode_msg[0]
+                with self.thread_lock:
+                    logger.debug(
+                        f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
+                    )
+                    self.finished_reqs.add(finished_req_id)
             else:
-                logger.warning(
-                    f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
+                raise RuntimeError(
+                    f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
                 )
 
     def init_llm_datadist(self):
@@ -517,13 +539,27 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
         self.ready_event.wait()
 
     def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
+        futures = []
         for req_id, meta in metadata.requests.items():
             logger.debug(f"Start to transmit {req_id}")
-            self._read_blocks(meta.local_block_ids,
-                              meta.remote_block_ids, meta.remote_host,
-                              int(meta.remote_port), meta.engine_id, req_id,
-                              meta.remote_tp_size)
-            self.finished_reqs.add(req_id)
+            future = self.executor.submit(
+                self._read_blocks,
+                local_block_ids=meta.local_block_ids,
+                remote_block_ids=meta.remote_block_ids,
+                remote_ip=meta.remote_host,
+                remote_port=int(meta.remote_port),
+                remote_engine_id=meta.engine_id,
+                request_id=req_id,
+                remote_tp_size=meta.remote_tp_size,
+            )
+            futures.append(future)
+
+        def handle_exception(future):
+            if future.exception():
+                logger.error(f"KV transfer task failed: {future.exception()}")
+
+        for future in futures:
+            future.add_done_callback(handle_exception)
 
     def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
         assert self.local_agent_metadata is not None
@@ -673,7 +709,8 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
         url = f"tcp://{host}:{port}"
         logger.debug(f"Querying metadata from url: {url}")
         msg_encoder = msgspec.msgpack.Encoder()
-        msg_send = msg_encoder.encode(self.local_agent_metadata)
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
         with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
             logger.info("Try request remote metadata from socket......")
             sock.send(msg_send)
@@ -685,6 +722,22 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
             cluster_id = self.add_remote_agent(metadata)
         return cluster_id
 
+    def send_finsih_to_remote(self, host: str, port: int, request_id):
+        url = f"tcp://{host}:{port}"
+        logger.debug(f"Sending finished to remote: {url}")
+        msg_encoder = msgspec.msgpack.Encoder()
+        msg_send = msg_encoder.encode(
+            [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
+        with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
+            try:
+                sock.send(msg_send)
+                logger.debug(
+                    f"Request id {request_id} finished message send to remote {url}"
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to send reqest_id {request_id} to prefill: {e}")
+
     def _read_blocks(
         self,
         local_block_ids: list[int],
@@ -750,14 +803,19 @@ def _read_blocks(
             raise RuntimeError(
                 "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
             )
+        self.send_finsih_to_remote(remote_ip, remote_port + tp_offset,
+                                   request_id)
+        with self.thread_lock:
+            self.finished_reqs.add(request_id)
 
     def get_finished(
         self, finished_req_ids: set[str]
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """Get the finished recving and sending requuests."""
         import copy
-        req_ids_to_ret = copy.deepcopy(self.finished_reqs)
-        self.finished_reqs.clear()
+        with self.thread_lock:
+            req_ids_to_ret = copy.deepcopy(self.finished_reqs)
+            self.finished_reqs.clear()
         if self.llm_datadist_role == LLMRole.PROMPT:
            return req_ids_to_ret, None
         else:
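
Taken together, the diff replaces raw-metadata messages with an event-tagged frame: every ZMQ request now carries `[event, payload]`, so a single listener socket can serve both the metadata handshake and finished-request notifications, and `finished_reqs` is guarded by `thread_lock` because the listener thread, the executor worker, and `get_finished` all touch it. A minimal round-trip sketch of that framing; the enum mirrors the diff, while the payload and surrounding code are purely illustrative:

```python
from enum import Enum

import msgspec

class LLMDataDistCMgrEvent(Enum):
    ReqForMetadata = 0
    ReqForFinished = 1

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder()

# Sender side: tag the payload with the event kind, as connect_to_remote_agent
# and send_finsih_to_remote do in the diff.
msg = encoder.encode([LLMDataDistCMgrEvent.ReqForFinished, ["req-0"]])

# Listener side: unpack the tag first, then dispatch on it, as the reworked
# listen_for_agent_metadata_req does.
event_msg, payload = decoder.decode(msg)
event_msg = LLMDataDistCMgrEvent(event_msg)
if event_msg == LLMDataDistCMgrEvent.ReqForFinished:
    finished_req_id = payload[0]
    print(f"request {finished_req_id} finished")
```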
