
[0.9.1][Perf] Launch load kv task asynchronously with thread pool. #1612


Merged
Changes from all commits
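In short, start_load_kv no longer blocks on each _read_blocks call: transfers are submitted to a single-worker ThreadPoolExecutor, messages on the metadata socket are tagged with a new LLMDataDistCMgrEvent enum so the decode side can tell the prefill side when a pull has finished, and finished_reqs is guarded by a lock. A minimal sketch of the submit-plus-callback pattern used in the diff (names outside the standard library are hypothetical):

import logging
from concurrent.futures import ThreadPoolExecutor

logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(1)  # one worker: transfers leave the main thread but stay ordered

def pull_blocks(req_id: str) -> None:
    # stand-in for the blocking KV pull performed by _read_blocks
    ...

def log_failure(future):
    # done-callbacks surface exceptions that the pool would otherwise swallow
    if future.exception():
        logger.error("KV transfer task failed: %s", future.exception())

futures = [executor.submit(pull_blocks, req_id) for req_id in ("req-0", "req-1")]
for fut in futures:
    fut.add_done_callback(log_failure)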
94 changes: 76 additions & 18 deletions vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -4,7 +4,9 @@
import threading
import time
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Tuple

import llm_datadist # type: ignore
@@ -38,6 +40,11 @@
}


class LLMDataDistCMgrEvent(Enum):
ReqForMetadata = 0
ReqForFinished = 1


class LLMDataDistCMgrAgentMetadata(msgspec.Struct):
super_pod_id: str
server_id: str
Expand Down Expand Up @@ -298,6 +305,8 @@ def __init__(self, vllm_config: VllmConfig):
self.local_agent_metadata: Optional[
LLMDataDistCMgrAgentMetadata] = None
self.vllm_config = vllm_config
self.executor = ThreadPoolExecutor(1)
self.thread_lock = threading.Lock()

self.llm_datadist_role = None
self.llm_datadist_remote_role = None
@@ -343,17 +352,30 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
event.set()
while True:
identity, _, msg = sock.recv_multipart()
event_msg, decode_msg = msg_decoder.decode(msg)
event_msg = LLMDataDistCMgrEvent(event_msg)
if event_msg == LLMDataDistCMgrEvent.ReqForMetadata:
if "cluster_id" in decode_msg:
decode_msg = LLMDataDistCMgrAgentMetadata(**decode_msg)
logger.info(
f"LLMDataDistCMgrConnectorWorker: Receive message from cluster {decode_msg.cluster_id}"
)
sock.send_multipart((identity, b"", msg_to_send))
self.add_remote_agent(decode_msg)
else:
logger.warning(
f"LLMDataDistCMgrConnectorWorker: receiving unrecognized data {decode_msg}"
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
with self.thread_lock:
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
            else:
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
                )

def init_llm_datadist(self):
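After this hunk, everything exchanged over the metadata socket is a two-element message: an LLMDataDistCMgrEvent value followed by the payload (the agent metadata for ReqForMetadata, a one-element list holding the request id for ReqForFinished). A rough sketch of the msgspec round trip, assuming only the enum defined above (the payload values are made up):

import msgspec
from enum import Enum

class LLMDataDistCMgrEvent(Enum):
    ReqForMetadata = 0
    ReqForFinished = 1

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder()

# sender: tag the payload with the event type
raw = encoder.encode([LLMDataDistCMgrEvent.ReqForFinished, ["req-42"]])

# receiver: the enum arrives as its integer value and is rebuilt explicitly,
# mirroring event_msg = LLMDataDistCMgrEvent(event_msg) in the listener
event_value, payload = decoder.decode(raw)
event = LLMDataDistCMgrEvent(event_value)
assert event is LLMDataDistCMgrEvent.ReqForFinished and payload == ["req-42"]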
@@ -517,13 +539,27 @@ def register_kv_caches(self, kv_caches: dict[str, Tuple[torch.Tensor]]):
self.ready_event.wait()

def start_load_kv(self, metadata: LLMDataDistCMgrConnectorMetadata):
futures = []
for req_id, meta in metadata.requests.items():
logger.debug(f"Start to transmit {req_id}")
future = self.executor.submit(
self._read_blocks,
local_block_ids=meta.local_block_ids,
remote_block_ids=meta.remote_block_ids,
remote_ip=meta.remote_host,
remote_port=int(meta.remote_port),
remote_engine_id=meta.engine_id,
request_id=req_id,
remote_tp_size=meta.remote_tp_size,
)
futures.append(future)

def handle_exception(future):
if future.exception():
logger.error(f"KV transfer task failed: {future.exception()}")

for future in futures:
future.add_done_callback(handle_exception)
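finished_reqs is now written from two threads — the executor worker that runs _read_blocks and the ZMQ listener that handles ReqForFinished — so every access goes through self.thread_lock. The same guard pattern in isolation (names are illustrative):

import threading

finished_reqs: set[str] = set()
thread_lock = threading.Lock()

def mark_finished(request_id: str) -> None:
    # called from the transfer worker or from the listener thread
    with thread_lock:
        finished_reqs.add(request_id)

def drain_finished() -> set[str]:
    # what get_finished does: copy and clear atomically under the lock
    with thread_lock:
        done = set(finished_reqs)
        finished_reqs.clear()
    return done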

def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
@@ -673,7 +709,8 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
url = f"tcp://{host}:{port}"
logger.debug(f"Querying metadata from url: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForMetadata, self.local_agent_metadata])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
logger.info("Try request remote metadata from socket......")
sock.send(msg_send)
@@ -685,6 +722,22 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
cluster_id = self.add_remote_agent(metadata)
return cluster_id

def send_finsih_to_remote(self, host: str, port: int, request_id):
url = f"tcp://{host}:{port}"
logger.debug(f"Sending finished to remote: {url}")
msg_encoder = msgspec.msgpack.Encoder()
msg_send = msg_encoder.encode(
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
try:
sock.send(msg_send)
logger.debug(
f"Request id {request_id} finished message send to remote {url}"
)
except Exception as e:
logger.error(
f"Failed to send reqest_id {request_id} to prefill: {e}")

def _read_blocks(
self,
local_block_ids: list[int],
@@ -750,14 +803,19 @@ def _read_blocks(
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finsih_to_remote(remote_ip, remote_port + tp_offset,
request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
with self.thread_lock:
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
return req_ids_to_ret, None
else: