Skip to content

Commit 6b188ed

Browse files
authored
[0.9.1][PD] Added support for delay-free blocks in prefill nodes (#1691)
### What this PR does / why we need it?

PD logic analysis: in the current implementation, the P-node immediately releases memory blocks after completing inference. Under high-concurrency scenarios, if the P-node's inference speed significantly outpaces the D-node's block-pulling operations, this leads to memory block contention/corruption issues.

Current solution: the D-node sends acknowledgment messages to the worker connector in the P-node's driver worker after completing data reception. The P-node maintains a counter to track these acknowledgments — memory blocks are only released after receiving confirmations from all D-node worker connectors involved in the KV cache transfer.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

---------

Signed-off-by: underfituu <hzhucong@163.com>
1 parent 9a5e650 commit 6b188ed

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

vllm_ascend/distributed/llmdatadist_c_mgr_connector.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import threading
55
import time
6+
from collections import defaultdict
67
from collections.abc import Iterator
78
from concurrent.futures import ThreadPoolExecutor
89
from dataclasses import dataclass
@@ -268,9 +269,11 @@ def request_finished(
268269
# we just transfer any data that computed from prefill node
269270
# note: there might be some issue on this, check it if there is any unexpected result
270271
computed_block_ids = block_ids
271-
# If prompt < block_size, no xfer so free blocks immediately.
272-
273-
return False, dict(
272+
delay_free_blocks = len(computed_block_ids) > 0
273+
if delay_free_blocks:
274+
logger.info("Delaying free of %d blocks for request %s",
275+
len(computed_block_ids), request.request_id)
276+
return delay_free_blocks, dict(
274277
do_remote_prefill=True,
275278
do_remote_decode=False,
276279
remote_block_ids=computed_block_ids,
@@ -334,6 +337,8 @@ def __init__(self, vllm_config: VllmConfig):
334337
self.init_llm_datadist()
335338
self.finished_reqs: set[str] = set()
336339
self.soc_info = NPUSocInfo()
340+
self.done_receiving_counts: defaultdict[str,
341+
set[int]] = defaultdict(set)
337342

338343
def listen_for_agent_metadata_req(self, event: threading.Event):
339344
assert self.local_agent_metadata is not None
@@ -368,16 +373,41 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
368373
)
369374
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
370375
finished_req_id = decode_msg[0]
376+
decode_tp_rank = decode_msg[1]
377+
decode_tp_size = decode_msg[2]
371378
with self.thread_lock:
372-
logger.debug(
373-
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
374-
)
375-
self.finished_reqs.add(finished_req_id)
379+
if self._increment_task_count(finished_req_id,
380+
decode_tp_rank,
381+
decode_tp_size):
382+
logger.debug(
383+
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
384+
)
385+
self.finished_reqs.add(finished_req_id)
386+
sock.send_multipart(
387+
(identity, b"", b"receiving decode finished"))
376388
else:
377389
raise RuntimeError(
378390
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
379391
)
380392

393+
def _increment_task_count(self, request_id: str, tp_rank: int,
394+
decode_tp_size: int):
395+
if request_id not in self.done_receiving_counts:
396+
self.done_receiving_counts[request_id] = set()
397+
if tp_rank in self.done_receiving_counts[request_id]:
398+
logger.warning(
399+
f"Received duplicate done signal for request {request_id} "
400+
f"from tp rank {tp_rank}. Ignoring.")
401+
return False
402+
self.done_receiving_counts[request_id].add(tp_rank)
403+
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
404+
self.done_receiving_counts.pop(request_id)
405+
logger.info("All transfers completed for request: "
406+
f"{request_id}. Total ranks: "
407+
f"{decode_tp_size}.")
408+
return True
409+
return False
410+
381411
def init_llm_datadist(self):
382412
assert self.local_agent_metadata is not None
383413
llm_config = LLMConfig()
@@ -722,18 +752,21 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
722752
cluster_id = self.add_remote_agent(metadata)
723753
return cluster_id
724754

725-
def send_finish_to_remote(self, host: str, port: int, request_id):
    """Notify the prefill node that this decode worker finished pulling KV.

    Sends a ``ReqForFinished`` event carrying ``(request_id, tp_rank,
    tp_size)`` so the prefill side can count acknowledgments from every
    decode TP rank before freeing the request's blocks, then blocks on the
    prefill side's reply confirming receipt.

    Args:
        host: Prefill-node host to contact.
        port: ZMQ port of the prefill worker connector.
        request_id: Identifier of the finished request.
    """
    url = f"tcp://{host}:{port}"
    logger.debug(f"Sending finished to remote: {url}")
    msg_encoder = msgspec.msgpack.Encoder()
    msg_send = msg_encoder.encode([
        LLMDataDistCMgrEvent.ReqForFinished,
        [request_id, self.tp_rank, self.tp_size]
    ])
    with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
        try:
            sock.send(msg_send)
            logger.debug(
                f"Request id {request_id} finished message send to remote {url}"
            )
            # Wait for the prefill side's ack reply; a REQ socket must
            # receive a reply before it can send again anyway.
            _ = sock.recv()
        except Exception as e:
            # Fix: original message misspelled "reqest_id".
            logger.error(
                f"Failed to send request_id {request_id} to prefill: {e}")
@@ -803,8 +836,7 @@ def _read_blocks(
803836
raise RuntimeError(
804837
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
805838
)
806-
self.send_finsih_to_remote(remote_ip, remote_port + tp_offset,
807-
request_id)
839+
self.send_finish_to_remote(remote_ip, remote_port, request_id)
808840
with self.thread_lock:
809841
self.finished_reqs.add(request_id)
810842

0 commit comments

Comments
 (0)