Skip to content

Commit 4fcef49

Browse files
sdavidbd and David Ben-David authored
[V1] [KVConnector] Fix MultiprocExecutor worker output aggregation (#21048)
Signed-off-by: David Ben-David <davidb@pliops.com> Co-authored-by: David Ben-David <davidb@pliops.com>
1 parent 8a4e5c5 commit 4fcef49

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import threading
4+
from collections import defaultdict
5+
from concurrent.futures import Future
6+
from typing import Optional
7+
8+
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
9+
from vllm.v1.outputs import ModelRunnerOutput
10+
11+
12+
class DummyMultiprocExecutor(MultiprocExecutor):
    """Lightweight ``MultiprocExecutor`` stand-in for aggregation tests.

    ``MultiprocExecutor.__init__`` is intentionally not invoked; only the
    attributes assigned below are populated on the instance.
    """

    def __init__(self, output_rank, world_size):
        self.output_rank = output_rank
        self.world_size = world_size

        def _initial_count() -> int:
            # Read at lookup time, so a request id that has not been seen
            # yet starts at the executor's current world size.
            return self.world_size

        self._send_remaining_count = defaultdict[str, int](_initial_count)
        self._recv_remaining_count = defaultdict[str, int](_initial_count)

        # No I/O thread pool is created for this dummy executor.
        self.io_thread_pool = None
        self.shutdown_event = threading.Event()
24+
25+
26+
class DummyModelRunnerOutput(ModelRunnerOutput):
    """``ModelRunnerOutput`` stand-in carrying only the two finished sets.

    The parent initializer is not called; ``finished_sending`` and
    ``finished_recving`` are the only attributes set here.
    """

    def __init__(self,
                 finished_sending: Optional[set[str]] = None,
                 finished_recving: Optional[set[str]] = None):
        self.finished_recving = finished_recving
        self.finished_sending = finished_sending
33+
34+
35+
def test_aggregate_workers_output():
    """Exercise ``_aggregate_workers_output`` over three successive calls.

    A single executor (``world_size=2``, ``output_rank=0``) is reused for
    every round, so the rounds are order-dependent. The expected values
    encode that a request id appears in the aggregated output only on the
    round in which the second of the two workers reports it, and that the
    returned object is always the rank-0 worker's output.
    """
    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)

    rounds = [
        # ((rank-0 sending, recving), (rank-1 sending, recving)),
        # expected aggregated sending, expected aggregated recving
        ((({'req1'}, {'req2'}), (None, None)), None, None),
        (((None, None), ({'req1'}, None)), {'req1'}, None),
        (((None, None), ({'req1'}, {'req2'})), None, {'req2'}),
    ]

    for reports, want_sending, want_recving in rounds:
        worker_outputs = [
            DummyModelRunnerOutput(finished_sending=sending,
                                   finished_recving=recving)
            for sending, recving in reports
        ]

        merged = executor._aggregate_workers_output(worker_outputs)

        # The aggregate is the rank-0 output object itself, updated in place.
        assert merged is worker_outputs[0]

        if want_sending is None:
            assert merged.finished_sending is None
        else:
            assert merged.finished_sending == want_sending

        if want_recving is None:
            assert merged.finished_recving is None
        else:
            assert merged.finished_recving == want_recving
70+
71+
72+
def test_async_aggregate_workers_output():
    """Exercise ``_async_aggregate_workers_output`` over three rounds.

    Mirrors the synchronous scenario sequence, but each worker output is
    delivered through a ``Future``. The aggregate future is requested
    before any worker future is resolved, and must be done as soon as all
    worker futures have results. One executor (``world_size=2``,
    ``output_rank=0``) is shared across rounds, so round order matters.
    """
    executor = DummyMultiprocExecutor(output_rank=0, world_size=2)

    rounds = [
        # ((rank-0 sending, recving), (rank-1 sending, recving)),
        # expected aggregated sending, expected aggregated recving
        ((({'req1'}, {'req2'}), (None, None)), None, None),
        (((None, None), ({'req1'}, None)), {'req1'}, None),
        (((None, None), ({'req1'}, {'req2'})), None, {'req2'}),
    ]

    for reports, want_sending, want_recving in rounds:
        pending: list[Future[DummyModelRunnerOutput]] = [
            Future() for _ in reports
        ]
        result_future = executor._async_aggregate_workers_output(pending)

        worker_outputs = [
            DummyModelRunnerOutput(finished_sending=sending,
                                   finished_recving=recving)
            for sending, recving in reports
        ]
        # Resolve the worker futures in rank order.
        for worker_future, worker_output in zip(pending, worker_outputs):
            worker_future.set_result(worker_output)

        assert result_future.done()
        merged = result_future.result()

        # The aggregate is the rank-0 output object itself.
        assert merged is worker_outputs[0]

        if want_sending is None:
            assert merged.finished_sending is None
        else:
            assert merged.finished_sending == want_sending

        if want_recving is None:
            assert merged.finished_recving is None
        else:
            assert merged.finished_recving == want_recving

vllm/v1/executor/multiproc_executor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,8 @@ def update_finished_set(req_ids: Optional[set[str]],
273273
output = outputs[self.output_rank]
274274

275275
# set the aggregated finished_sending / finished_recving
276-
if finished_sending:
277-
output.finished_sending = finished_sending
278-
if finished_recving:
279-
output.finished_recving = finished_recving
276+
output.finished_sending = finished_sending if finished_sending else None
277+
output.finished_recving = finished_recving if finished_recving else None
280278

281279
return output
282280

0 commit comments

Comments
 (0)