
Commit 9d72daf

[V1][Perf] Simpler request output queues (vllm-project#15156)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
Co-authored-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>

1 parent 6dd55af · commit 9d72daf

File tree

tests/v1/engine/test_output_processor.py
vllm/v1/engine/async_llm.py
vllm/v1/engine/output_processor.py

3 files changed: +146 −25 lines

tests/v1/engine/test_output_processor.py

Lines changed: 88 additions & 1 deletion

@@ -11,11 +11,13 @@
                                    STOP_STRINGS,
                                    DummyOutputProcessorTestVectors,
                                    MockEngineCore)
+from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import PromptLogprobs, SampleLogprobs
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.output_processor import (OutputProcessor,
+                                             RequestOutputCollector)
 from vllm.v1.metrics.stats import IterationStats
 
 
@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):
 
     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
+
+
+@pytest.mark.asyncio
+async def test_request_output_collector():
+    NUM_REQS = 3
+    TEXT = "a"
+
+    def make_outputs() -> list[RequestOutput]:
+        return [
+            RequestOutput(
+                request_id="my-request-id",
+                prompt=None,
+                prompt_token_ids=[1, 2, 3],
+                prompt_logprobs=None,
+                outputs=[
+                    CompletionOutput(
+                        index=0,
+                        text=TEXT,
+                        token_ids=[idx],
+                        cumulative_logprob=(idx + 1 * 1.0),
+                        logprobs=[{
+                            "a": idx,
+                            "b": idx
+                        }],
+                        finish_reason="length" if
+                        (idx == NUM_REQS - 1) else None,
+                    )
+                ],
+                finished=(idx == NUM_REQS - 1),
+            ) for idx in range(NUM_REQS)
+        ]
+
+    collector = RequestOutputCollector(RequestOutputKind.DELTA)
+
+    # CASE 1: Put then get.
+    outputs = make_outputs()
+    collector.put(outputs[0])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+    assert output.outputs[0].text == "a"
+    assert output.outputs[0].token_ids == [0]
+
+    # CASE 2: 2 puts then get.
+    num_to_put = 2
+    outputs = make_outputs()
+    for i in range(num_to_put):
+        collector.put(outputs[i])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+
+    assert not output.finished
+    # Text, token_ids, and logprobs should get merged.
+    assert output.outputs[0].text == TEXT * num_to_put
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
+                            list(range(num_to_put))):
+        assert tok_0 == tok_1
+    assert len(output.outputs[0].logprobs) == num_to_put
+
+    # Cumulative logprobs should be the last one.
+    cumulative_logprob_expected = 1.0 * num_to_put
+    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
+
+    # CASE 3: Put all 3 (including a finished).
+    num_to_put = 3
+    outputs = make_outputs()
+    for i in range(num_to_put):
+        collector.put(outputs[i])
+    output = await collector.get()
+    assert not collector.ready.is_set()
+    assert collector.output is None
+
+    assert output.finished
+    assert output.outputs[0].finish_reason == "length"
+    # Text, token_ids, and logprobs should get merged.
+    assert output.outputs[0].text == TEXT * num_to_put
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
+                            list(range(num_to_put))):
+        assert tok_0 == tok_1
+    assert len(output.outputs[0].logprobs) == num_to_put
+
+    # Cumulative logprobs should be the last one.
+    cumulative_logprob_expected = 1.0 * num_to_put
+    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
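
For readers skimming the assertions above: the merge behavior they check (delta text concatenated, token_ids and logprobs extended, cumulative_logprob and finish_reason taken from the latest output) happens inside RequestOutput.add, which this PR reuses but does not modify. The snippet below is a hypothetical, self-contained stand-in that mirrors only those asserted semantics; it is not the vLLM implementation.

# Hypothetical stand-in mirroring only what the test above asserts; the real
# merging lives in RequestOutput.add (vllm.outputs), which this diff reuses.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DeltaOutput:
    text: str
    token_ids: list[int]
    logprobs: list[dict]
    cumulative_logprob: float
    finished: bool = False
    finish_reason: Optional[str] = None


def merge_delta(acc: DeltaOutput, nxt: DeltaOutput) -> DeltaOutput:
    # Text, token_ids, and logprobs are concatenated/extended.
    acc.text += nxt.text
    acc.token_ids.extend(nxt.token_ids)
    acc.logprobs.extend(nxt.logprobs)
    # Cumulative logprob, finished, and finish_reason follow the latest output.
    acc.cumulative_logprob = nxt.cumulative_logprob
    acc.finished = nxt.finished
    acc.finish_reason = nxt.finish_reason
    return acc


first = DeltaOutput("a", [0], [{"a": 0}], 1.0)
second = DeltaOutput("a", [1], [{"a": 1}], 2.0, finished=True, finish_reason="length")
merged = merge_delta(first, second)
assert merged.text == "aa" and merged.token_ids == [0, 1]
assert merged.cumulative_logprob == 2.0 and merged.finish_reason == "length"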

vllm/v1/engine/async_llm.py

Lines changed: 14 additions & 20 deletions

@@ -21,14 +21,15 @@
 from vllm.outputs import RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv, kill_process_tree
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
-from vllm.v1.engine.output_processor import OutputProcessor
+from vllm.v1.engine.output_processor import (OutputProcessor,
+                                             RequestOutputCollector)
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
@@ -176,11 +177,14 @@ async def add_request(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> asyncio.Queue[RequestOutput]:
+    ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""
 
-        # Create a new output queue for the request.
-        queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
+        assert isinstance(params, SamplingParams), \
+            "Pooling is not supported in V1"
+
+        # Create a new output collector for the request.
+        queue = RequestOutputCollector(output_kind=params.output_kind)
 
         # Convert Input --> Request.
         request = self.processor.process_inputs(request_id, prompt, params,
@@ -189,25 +193,23 @@ async def add_request(
                                                 prompt_adapter_request,
                                                 priority)
 
-        n = params.n if isinstance(params, SamplingParams) else 1
-
-        if n == 1:
+        if params.n == 1:
             await self._add_request(request, None, 0, queue)
             return queue
 
         # Fan out child requests (for n>1).
         parent_request = ParentRequest(request_id, params)
-        for idx in range(n):
+        for idx in range(params.n):
             request_id, params = parent_request.get_child_info(idx)
-            child_request = request if idx == n - 1 else copy(request)
+            child_request = request if idx == params.n - 1 else copy(request)
             child_request.request_id = request_id
             child_request.sampling_params = params
             await self._add_request(child_request, parent_request, idx, queue)
         return queue
 
     async def _add_request(self, request: EngineCoreRequest,
                            parent_req: Optional[ParentRequest], index: int,
-                           queue: asyncio.Queue[RequestOutput]):
+                           queue: RequestOutputCollector):
 
         # Add the request to OutputProcessor (this process).
         self.output_processor.add_request(request, parent_req, index, queue)
@@ -272,15 +274,7 @@ async def generate(
             while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if not q.empty() else await q.get()
-
-                # Coalesce any additional queued outputs
-                while not q.empty():
-                    next_out = q.get_nowait()
-                    if sampling_params.output_kind == RequestOutputKind.DELTA:
-                        out.add(next_out)
-                    else:
-                        out = next_out
+                out = q.get_nowait() or await q.get()
 
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
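
The last hunk above is the consumer-side payoff: because RequestOutputCollector.put() already coalesces outputs, generate() no longer needs its own drain-and-merge loop, and a non-blocking read falling back to an awaited read covers both the "output already buffered" and "wait for the producer" cases. Below is a rough sketch of that loop shape, assuming only the get()/get_nowait() interface from this PR; collector and handle are placeholders, not AsyncLLM code.

# Sketch of the simplified consumer loop; `collector` is assumed to expose the
# get()/get_nowait() methods added in this PR, and `handle` is a placeholder.
async def consume(collector, handle) -> None:
    finished = False
    while not finished:
        # Fast path: take an already-buffered (possibly merged) output without
        # awaiting; otherwise suspend until the producer signals readiness.
        out = collector.get_nowait() or await collector.get()
        finished = out.finished
        handle(out)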

vllm/v1/engine/output_processor.py

Lines changed: 44 additions & 4 deletions

@@ -17,6 +17,46 @@
                                     RequestStateStats)
 
 
+class RequestOutputCollector:
+    """
+    Collects streamed RequestOutputs per individual request,
+    for hand-off to the consuming asyncio generate task.
+
+    When streaming deltas, RequestOutputs are merged if the
+    producer gets ahead of the consumer.
+    """
+
+    def __init__(self, output_kind: RequestOutputKind):
+        self.aggregate = output_kind == RequestOutputKind.DELTA
+        self.output: Optional[RequestOutput] = None
+        self.ready = asyncio.Event()
+
+    def put(self, output: RequestOutput) -> None:
+        if self.output is None:
+            self.output = output
+            self.ready.set()
+        elif self.aggregate:
+            # Coalesce the outputs in delta case.
+            self.output.add(output)
+        else:
+            # Just replace latest in non-delta case.
+            self.output = output
+
+    async def get(self) -> RequestOutput:
+        while (output := self.output) is None:
+            await self.ready.wait()
+        self.output = None
+        self.ready.clear()
+        return output
+
+    def get_nowait(self) -> Optional[RequestOutput]:
+        output = self.output
+        if output is not None:
+            self.output = None
+            self.ready.clear()
+        return output
+
+
 @dataclass
 class OutputProcessorOutput:
 
@@ -39,7 +79,7 @@ def __init__(
         detokenizer: IncrementalDetokenizer,
         max_tokens_param: Optional[int],
         arrival_time: float,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ):
         self.request_id = request_id
@@ -66,7 +106,7 @@ def from_new_request(
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest],
         request_index: int,
-        queue: Optional[asyncio.Queue[RequestOutput]],
+        queue: Optional[RequestOutputCollector],
         log_stats: bool,
     ) -> "RequestState":
         if not request.sampling_params.detokenize:
@@ -217,7 +257,7 @@ def add_request(
         request: EngineCoreRequest,
         parent_req: Optional[ParentRequest] = None,
         request_index: int = 0,
-        queue: Optional[asyncio.Queue[RequestOutput]] = None,
+        queue: Optional[RequestOutputCollector] = None,
     ) -> None:
         request_id = request.request_id
         if request_id in self.request_states:
@@ -300,7 +340,7 @@ def process_outputs(
                     new_token_ids, finish_reason, stop_reason):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
-                    req_state.queue.put_nowait(request_output)
+                    req_state.queue.put(request_output)
                 else:
                     # LLMEngine: return list of RequestOutputs.
                     request_outputs.append(request_output)
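
As the class docstring above says, outputs are merged when the producer gets ahead of the consumer. Below is a minimal, runnable illustration of that single-slot-plus-Event structure; MiniCollector is a simplified stand-in that merges plain strings rather than RequestOutput objects and is not the vLLM class.

import asyncio
from typing import Optional


class MiniCollector:
    """Single-slot collector: put() merges, get() awaits readiness."""

    def __init__(self) -> None:
        self.output: Optional[str] = None
        self.ready = asyncio.Event()

    def put(self, output: str) -> None:
        if self.output is None:
            self.output = output
            self.ready.set()
        else:
            # Producer got ahead of the consumer: merge in place.
            self.output += output

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        return output


async def main() -> None:
    collector = MiniCollector()
    # Two puts before the consumer runs are coalesced into one get().
    collector.put("hello ")
    collector.put("world")
    assert await collector.get() == "hello world"


asyncio.run(main())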
