
Commit cc876d0

[KVConnector] Aggregate finished requests on the scheduler (#19555)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
1 parent fdfd409 commit cc876d0

File tree

5 files changed: +139, -110 lines

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 3 additions & 1 deletion
@@ -190,7 +190,9 @@ def get_finished(
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """
         Notifies worker-side connector ids of requests that have
-        finished generating tokens.
+        finished generating tokens on the worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
 
         Returns:
             ids of requests that have finished asynchronous transfer
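
To make the new contract concrete, here is a minimal, hypothetical sketch (ToyWorkerConnector and its fields are invented for illustration; the get_finished signature taking finished_req_ids is assumed from how the model runner used to call it, see the gpu_model_runner.py diff below). Each worker-side connector now reports only what finished locally; no cross-rank coordination happens inside the connector:

from typing import Optional


class ToyWorkerConnector:
    """Illustration only: per-worker bookkeeping, no cross-rank traffic."""

    def __init__(self) -> None:
        # Request ids whose KV transfers completed on *this* worker only.
        self._local_done_sending: set[str] = set()
        self._local_done_recving: set[str] = set()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        # finished_req_ids: requests the scheduler just finished (unused here).
        # Drain local state; the scheduler-side executor waits until every
        # worker has reported a request before releasing it.
        done_sending, self._local_done_sending = self._local_done_sending, set()
        done_recving, self._local_done_recving = self._local_done_recving, set()
        return (done_sending or None), (done_recving or None)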

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 4 additions & 61 deletions
@@ -408,14 +408,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
 
-        # Complete transfer tracker. Used by the rank 0 to track finished
-        # transactions on ranks 1 to N-1.
-        # [req_id -> count]
-        self._done_recving_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-        self._done_sending_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
         # Background thread for initializing new NIXL handshakes.
@@ -830,15 +822,9 @@ def add_remote_agent(self,
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
-        Get requests that are done sending or recving.
-
-        In TP>1 setup, each rank exchanges KVs with its counterpart
-        ranks independently. get_finished() runs in a worker creates
-        the done_sending and done_recving sets that are sent to the
-        scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
-        are done before adding to finished, Ranks 1 to N-1 communicate
-        to Rank 0 once their transaction is done + Rank 0 returns
-        finished sets to Scheduler only once all ranks are done.
+        Get requests that are done sending or recving on this specific worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
         """
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -858,50 +844,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             del self._reqs_to_send[req_id]
             done_sending.add(req_id)
 
-        if self.world_size == 1:
-            return done_sending, done_recving
-
-        # Rank 0: get finished from all other ranks.
-        if self.tp_rank == 0:
-            for req_id in done_sending:
-                self._done_sending_count[req_id] += 1
-            for req_id in done_recving:
-                self._done_recving_count[req_id] += 1
-
-            # Keep track of how many other ranks have finished.
-            other_ranks_finished_ids: list[str] = []
-            for i in range(1, self.world_size):
-                other_ranks_finished_ids.extend(
-                    self.tp_group.recv_object(src=i))
-            for req_id in other_ranks_finished_ids:
-                if (req_id in self._done_recving_count
-                        or req_id in self._recving_transfers):
-                    self._done_recving_count[req_id] += 1
-                else:
-                    self._done_sending_count[req_id] += 1
-
-            # Return ids that finished on all ranks to the scheduler.
-            all_done_recving: set[str] = set()
-            for req_id in list(self._done_recving_count.keys()):
-                if self._done_recving_count[req_id] == self.world_size:
-                    del self._done_recving_count[req_id]
-                    all_done_recving.add(req_id)
-
-            all_done_sending: set[str] = set()
-            for req_id in list(self._done_sending_count.keys()):
-                if self._done_sending_count[req_id] >= self.world_size:
-                    del self._done_sending_count[req_id]
-                    all_done_sending.add(req_id)
-
-            return all_done_sending, all_done_recving
-
-        # Ranks 1 to N-1: send finished ids to Rank 0.
-        else:
-            finished_req_ids = list(done_recving.union(done_sending))
-            self.tp_group.send_object(finished_req_ids, dst=0)
-
-            # Unused as only Rank 0 results are sent to scheduler.
-            return done_sending, done_recving
+        return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
         """

vllm/v1/executor/multiproc_executor.py

Lines changed: 105 additions & 5 deletions
@@ -9,7 +9,8 @@
 import time
 import traceback
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
+from collections import defaultdict
+from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -111,10 +112,19 @@ def _init_executor(self) -> None:
         if self.max_concurrent_batches > 1:
             # Note: must use only 1 IO thread to keep dequeue sequence
             # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
             self.io_thread_pool = ThreadPoolExecutor(
                 max_workers=1, thread_name_prefix="mp_exec_io")
 
         self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+
+        # Complete transfer tracker. Used by to track finished requests
+        # [req_id -> n_finished_workers]
+        self._recv_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
+        self._send_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -155,13 +165,29 @@ def execute_model(
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        (output, ) = self.collective_rpc(
+        non_block = self.max_concurrent_batches > 1
+
+        if not self.has_connector:
+            # get output only from a single worker (output_rank)
+            (output, ) = self.collective_rpc(
+                "execute_model",
+                args=(scheduler_output, ),
+                unique_reply_rank=self.output_rank,
+                non_block=non_block,
+                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+            return output
+
+        # get output from all workers
+        outputs = self.collective_rpc(
             "execute_model",
             args=(scheduler_output, ),
-            unique_reply_rank=self.output_rank,
-            non_block=self.max_concurrent_batches > 1,
+            non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-        return output
+
+        # aggregate all workers output to a single output
+        if non_block:
+            return self._async_aggregate_workers_output(outputs)
+        return self._aggregate_workers_output(outputs)
 
     def collective_rpc(self,
                        method: Union[str, Callable],
@@ -220,6 +246,80 @@ def get_response(w: WorkerProcHandle,
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
 
+    def _aggregate_workers_output(
+            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            # update finished_sending
+            for req_id in output.finished_sending or []:
+                new_count = self._send_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_sending.add(req_id)
+                    del self._send_remaining_count[req_id]
+                else:
+                    self._send_remaining_count[req_id] = new_count
+
+            # update finished_recving
+            for req_id in output.finished_recving or []:
+                new_count = self._recv_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_recving.add(req_id)
+                    del self._recv_remaining_count[req_id]
+                else:
+                    self._recv_remaining_count[req_id] = new_count
+
+        # select output of the worker specified by output_rank
+        output = outputs[self.output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        if finished_sending:
+            output.finished_sending = finished_sending
+        if finished_recving:
+            output.finished_recving = finished_recving
+
+        return output
+
+    def _async_aggregate_workers_output(
+            self, output_futures: list[Future[ModelRunnerOutput]]
+    ) -> (Future[ModelRunnerOutput]):
+        """Takes a list of futures and returns a single future which resolves
+        to the respective list of outputs."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self._aggregate_workers_output(
+                            cast(list[ModelRunnerOutput], outputs)))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
+
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 40 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import copy
 import gc
 import time
 import weakref
@@ -1234,8 +1233,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1270,8 +1267,6 @@ def _pool(
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
         )
 
     @torch.inference_mode()
@@ -1282,11 +1277,12 @@ def execute_model(
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
+            if has_kv_transfer_group():
+                with set_forward_context(None, self.vllm_config):
+                    self.maybe_setup_kv_connector(scheduler_output)
 
-            return self.kv_connector_no_forward(scheduler_output)
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1379,8 +1375,6 @@ def execute_model(
             )
 
         self.maybe_wait_for_kv_save()
-        finished_sending, finished_recving = (
-            self.get_finished_kv_transfers(scheduler_output))
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
@@ -1406,8 +1400,7 @@ def execute_model(
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1560,8 +1553,6 @@ def execute_model(
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
         )
 
@@ -1686,22 +1677,6 @@ def propose_draft_token_ids(
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
-    def kv_connector_no_forward(
-            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
-        # KV send/recv even if no work to do.
-        with set_forward_context(None, self.vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
-        if not finished_sending and not finished_recving:
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
-        return output
-
     @staticmethod
     def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
         # Update KVConnector with the KVConnector metadata forward().
@@ -1723,15 +1698,6 @@ def maybe_wait_for_kv_save() -> None:
         if has_kv_transfer_group():
             get_kv_transfer_group().wait_for_save()
 
-    @staticmethod
-    def get_finished_kv_transfers(
-        scheduler_output: "SchedulerOutput",
-    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
-        if has_kv_transfer_group():
-            return get_kv_transfer_group().get_finished(
-                scheduler_output.finished_req_ids)
-        return None, None
-
     def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
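
With get_finished_kv_transfers removed, the model runner no longer attaches finished_sending / finished_recving to its output; the fifth changed file (not shown in this view) presumably surfaces them at the worker layer instead. For reference, the removed helper boiled down to the sketch below; the import path is the one gpu_model_runner.py already uses, but treating this as what the worker layer now does is an assumption:

from typing import Optional

# Assumed import path: the same helpers gpu_model_runner.py already imports.
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)


def get_finished_kv_transfers(
    finished_req_ids: set[str],
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    # Ask the worker-local KV transfer group which requests have finished
    # sending / receiving on this worker; (None, None) if no connector is set.
    if has_kv_transfer_group():
        return get_kv_transfer_group().get_finished(finished_req_ids)
    return None, None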
