|
3 | 3 | import math
|
4 | 4 | import threading
|
5 | 5 | import time
|
| 6 | +from collections import defaultdict |
6 | 7 | from collections.abc import Iterator
|
7 | 8 | from concurrent.futures import ThreadPoolExecutor
|
8 | 9 | from dataclasses import dataclass
|
@@ -268,9 +269,11 @@ def request_finished(
|
268 | 269 | # we just transfer any data that was computed on the prefill node
|
269 | 270 | # note: there might be an issue with this; check it if there are any unexpected results
|
270 | 271 | computed_block_ids = block_ids
|
271 | | - # If prompt < block_size, no xfer so free blocks immediately. |
272 | | - |
273 | | - return False, dict( |
| 272 | + delay_free_blocks = len(computed_block_ids) > 0 |
| 273 | + if delay_free_blocks: |
| 274 | + logger.info("Delaying free of %d blocks for request %s", |
| 275 | + len(computed_block_ids), request.request_id) |
| 276 | + return delay_free_blocks, dict( |
274 | 277 | do_remote_prefill=True,
|
275 | 278 | do_remote_decode=False,
|
276 | 279 | remote_block_ids=computed_block_ids,
|
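For context on this change: the prefill side now keys the delayed free on whether any computed blocks exist, so a prompt shorter than one block (no remote blocks, nothing to transfer) frees its blocks immediately. A minimal standalone sketch of that decision, with illustrative names rather than the connector's actual signature:

```python
def prefill_request_finished(computed_block_ids: list[int]) -> tuple[bool, dict]:
    # Hold the KV blocks only while the decode node still has data to pull;
    # an empty block list (prompt shorter than one block) frees them at once.
    delay_free_blocks = len(computed_block_ids) > 0
    kv_transfer_params = dict(
        do_remote_prefill=True,
        do_remote_decode=False,
        remote_block_ids=computed_block_ids,
    )
    return delay_free_blocks, kv_transfer_params
```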
@@ -334,6 +337,8 @@ def __init__(self, vllm_config: VllmConfig):
|
334 | 337 | self.init_llm_datadist()
|
335 | 338 | self.finished_reqs: set[str] = set()
|
336 | 339 | self.soc_info = NPUSocInfo()
|
| 340 | + self.done_receiving_counts: defaultdict[str, |
| 341 | + set[int]] = defaultdict(set) |
337 | 342 |
|
338 | 343 | def listen_for_agent_metadata_req(self, event: threading.Event):
|
339 | 344 | assert self.local_agent_metadata is not None
|
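The new done_receiving_counts mapping records, per request id, which decode TP ranks have confirmed that their transfer finished. A small sketch of its shape, assuming the same types as in the diff:

```python
from collections import defaultdict

# request_id -> set of decode TP ranks that have sent a finished signal
done_receiving_counts: defaultdict[str, set[int]] = defaultdict(set)

done_receiving_counts["req-0"].add(0)  # decode rank 0 reported in
done_receiving_counts["req-0"].add(1)  # decode rank 1 reported in
# the request only counts as finished once the set reaches decode_tp_size
```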
@@ -368,16 +373,41 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
|
368 | 373 | )
|
369 | 374 | elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
370 | 375 | finished_req_id = decode_msg[0]
|
| 376 | + decode_tp_rank = decode_msg[1] |
| 377 | + decode_tp_size = decode_msg[2] |
371 | 378 | with self.thread_lock:
|
372 | | - logger.debug( |
373 | | - f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" |
374 | | - ) |
375 | | - self.finished_reqs.add(finished_req_id) |
| 379 | + if self._increment_task_count(finished_req_id, |
| 380 | + decode_tp_rank, |
| 381 | + decode_tp_size): |
| 382 | + logger.debug( |
| 383 | + f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" |
| 384 | + ) |
| 385 | + self.finished_reqs.add(finished_req_id) |
| 386 | + sock.send_multipart( |
| 387 | + (identity, b"", b"receiving decode finished")) |
376 | 388 | else:
|
377 | 389 | raise RuntimeError(
|
378 | 390 | f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
379 | 391 | )
|
380 | 392 |
|
| 393 | + def _increment_task_count(self, request_id: str, tp_rank: int, |
| 394 | + decode_tp_size: int): |
| 395 | + if request_id not in self.done_receiving_counts: |
| 396 | + self.done_receiving_counts[request_id] = set() |
| 397 | + if tp_rank in self.done_receiving_counts[request_id]: |
| 398 | + logger.warning( |
| 399 | + f"Received duplicate done signal for request {request_id} " |
| 400 | + f"from tp rank {tp_rank}. Ignoring.") |
| 401 | + return False |
| 402 | + self.done_receiving_counts[request_id].add(tp_rank) |
| 403 | + if len(self.done_receiving_counts[request_id]) == decode_tp_size: |
| 404 | + self.done_receiving_counts.pop(request_id) |
| 405 | + logger.info("All transfers completed for request: " |
| 406 | + f"{request_id}. Total ranks: " |
| 407 | + f"{decode_tp_size}.") |
| 408 | + return True |
| 409 | + return False |
| 410 | + |
381 | 411 | def init_llm_datadist(self):
|
382 | 412 | assert self.local_agent_metadata is not None
|
383 | 413 | llm_config = LLMConfig()
|
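To illustrate the barrier semantics of _increment_task_count, here is a self-contained simulation (not the connector code itself) of three decode ranks reporting in, including a duplicate signal:

```python
from collections import defaultdict

done: defaultdict[str, set[int]] = defaultdict(set)

def increment(request_id: str, tp_rank: int, decode_tp_size: int) -> bool:
    """Return True only when the last missing decode rank reports in."""
    if tp_rank in done[request_id]:
        return False              # duplicate signal from the same rank: ignore
    done[request_id].add(tp_rank)
    if len(done[request_id]) == decode_tp_size:
        done.pop(request_id)      # all ranks reported; drop the bookkeeping
        return True
    return False

assert increment("req-0", 0, 3) is False
assert increment("req-0", 0, 3) is False  # duplicate from rank 0
assert increment("req-0", 1, 3) is False
assert increment("req-0", 2, 3) is True   # last rank completes the set
```

Only the call that completes the set returns True, which is what gates adding the request to finished_reqs and sending the acknowledgement back to the decode worker.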
@@ -722,18 +752,21 @@ def connect_to_remote_agent(self, host: str, port: int) -> int:
|
722 | 752 | cluster_id = self.add_remote_agent(metadata)
|
723 | 753 | return cluster_id
|
724 | 754 |
|
725 | | - def send_finsih_to_remote(self, host: str, port: int, request_id): |
| 755 | + def send_finish_to_remote(self, host: str, port: int, request_id): |
726 | 756 | url = f"tcp://{host}:{port}"
|
727 | 757 | logger.debug(f"Sending finished to remote: {url}")
|
728 | 758 | msg_encoder = msgspec.msgpack.Encoder()
|
729 | | - msg_send = msg_encoder.encode( |
730 | | - [LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) |
| 759 | + msg_send = msg_encoder.encode([ |
| 760 | + LLMDataDistCMgrEvent.ReqForFinished, |
| 761 | + [request_id, self.tp_rank, self.tp_size] |
| 762 | + ]) |
731 | 763 | with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
732 | 764 | try:
|
733 | 765 | sock.send(msg_send)
|
734 | 766 | logger.debug(
|
735 | 767 | f"Request id {request_id} finished message send to remote {url}"
|
736 | 768 | )
|
| 769 | + _ = sock.recv() |
737 | 770 | except Exception as e:
|
738 | 771 | logger.error(
|
739 | 772 | f"Failed to send reqest_id {request_id} to prefill: {e}")
|
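The decode side now identifies itself in the finished message and blocks on the reply, turning the notification into a synchronous per-rank handshake. A minimal sketch of the sending end, assuming pyzmq and msgspec as used in this file; the event constant here is an illustrative stand-in for LLMDataDistCMgrEvent.ReqForFinished:

```python
import msgspec
import zmq

REQ_FOR_FINISHED = 1  # stand-in for LLMDataDistCMgrEvent.ReqForFinished

def notify_prefill(host: str, port: int, request_id: str,
                   tp_rank: int, tp_size: int) -> None:
    """Tell the prefill worker this decode rank has finished pulling blocks."""
    payload = msgspec.msgpack.Encoder().encode(
        [REQ_FOR_FINISHED, [request_id, tp_rank, tp_size]])
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.REQ)
    try:
        sock.connect(f"tcp://{host}:{port}")
        sock.send(payload)
        _ = sock.recv()  # block until the prefill side has recorded the signal
    finally:
        sock.close()
```

The added sock.recv() in the diff pairs with the new sock.send_multipart(...) reply on the listener side, completing the REQ/REP exchange.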
@@ -803,8 +836,7 @@ def _read_blocks(
|
803 | 836 | raise RuntimeError(
|
804 | 837 | "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
805 | 838 | )
|
806 | | - self.send_finsih_to_remote(remote_ip, remote_port + tp_offset, |
807 | | - request_id) |
| 839 | + self.send_finish_to_remote(remote_ip, remote_port, request_id) |
808 | 840 | with self.thread_lock:
|
809 | 841 | self.finished_reqs.add(request_id)
|
810 | 842 |
|
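Note that _read_blocks now sends the finish signal to remote_port without the per-rank tp_offset. Since the message itself carries the decode tp_rank and tp_size, the prefill side appears to attribute and count signals by rank rather than by the port they arrive on, so the rank-specific port is no longer needed.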