
Commit c1633da

added stream_n to v1/async_llm, created streaming_params (#1)

* added stream_n to v1/async_llm, created streaming_params
* Updated OpenAI compatible API to work with StreamingParams

Signed-off-by: Rohin Garg <rohin@character.ai>
1 parent 2a05d6e commit c1633da

File tree: 11 files changed, +171 −9 lines

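The diff excerpt below covers only some of the 11 changed files; the new vllm/streaming_params.py module itself is not shown. Based on how it is used in the other files (StreamingParams(), StreamingParams(stream_n=3), StreamingParams(stream_n=None), and reads of streaming_params.stream_n), a minimal sketch of what it might contain follows. The default value of 1 and the normalization of None to 1 are assumptions, not confirmed by this commit page.

# Hypothetical sketch of vllm/streaming_params.py (not shown in this excerpt).
# Assumes stream_n defaults to 1 and that None is normalized to 1 so that
# comparisons like `buffer_token_count >= streaming_params.stream_n` are safe.
from dataclasses import dataclass
from typing import Optional


@dataclass
class StreamingParams:
    """Parameters controlling how many tokens are buffered per streamed chunk."""
    stream_n: Optional[int] = 1

    def __post_init__(self) -> None:
        if self.stream_n is None:
            self.stream_n = 1
        if self.stream_n < 1:
            raise ValueError("stream_n must be >= 1.")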

requirements/test.txt

Lines changed: 20 additions & 1 deletion
@@ -27,6 +27,10 @@ argcomplete==3.5.1
     # via datamodel-code-generator
 arrow==1.3.0
     # via isoduration
+async-timeout==5.0.1
+    # via
+    #   aiohttp
+    #   redis
 attrs==24.2.0
     # via
     #   aiohttp
@@ -126,6 +130,11 @@ encodec==0.1.1
     # via vocos
 evaluate==0.4.3
     # via lm-eval
+exceptiongroup==1.3.0
+    # via
+    #   anyio
+    #   hypothesis
+    #   pytest
 fastparquet==2024.11.0
     # via genai-perf
 fastrlock==0.8.2
@@ -683,8 +692,13 @@ tokenizers==0.21.1
     # via
     #   -r requirements/test.in
     #   transformers
+toml==0.10.2
+    # via datamodel-code-generator
 tomli==2.2.1
-    # via schemathesis
+    # via
+    #   black
+    #   pytest
+    #   schemathesis
 tomli-w==1.2.0
     # via schemathesis
 torch==2.7.0+cu128
@@ -756,12 +770,17 @@ types-python-dateutil==2.9.0.20241206
     # via arrow
 typing-extensions==4.12.2
     # via
+    #   anyio
+    #   black
+    #   exceptiongroup
     #   huggingface-hub
     #   librosa
     #   mistral-common
+    #   multidict
     #   pqdm
     #   pydantic
     #   pydantic-core
+    #   rich
     #   torch
     #   typer
 tzdata==2024.2

tests/v1/engine/test_async_llm.py

Lines changed: 11 additions & 2 deletions
@@ -14,6 +14,7 @@
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
+from vllm.streaming_params import StreamingParams
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.metrics.loggers import LoggingStatLogger

@@ -62,9 +63,13 @@ async def generate(engine: AsyncLLM,
                                      seed=33,
                                      n=n,
                                      prompt_logprobs=prompt_logprobs)
+
+    streaming_params = StreamingParams(stream_n=3)
+
     async for out in engine.generate(request_id=request_id,
                                      prompt=prompt,
-                                     sampling_params=sampling_params):
+                                     sampling_params=sampling_params,
+                                     streaming_params=streaming_params):

         num_tokens = sum(len(output.token_ids) for output in out.outputs)
         if output_kind == RequestOutputKind.DELTA:
@@ -209,11 +214,15 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
                                      temperature=1.0,
                                      seed=33,
                                      n=n)
+
+    streaming_params = StreamingParams(stream_n=3)
+
     outputs = [
         out
         async for out in engine.generate(request_id="request-33",
                                          prompt=prompt,
-                                         sampling_params=sampling_params)
+                                         sampling_params=sampling_params,
+                                         streaming_params=streaming_params)
     ]

     # Assert only the last output has the finished flag set

vllm/engine/async_llm_engine.py

Lines changed: 17 additions & 1 deletion
@@ -33,6 +33,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
+from vllm.streaming_params import StreamingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, deprecate_kwargs, weak_bind
@@ -972,6 +973,7 @@ async def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
+        streaming_params: StreamingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -1045,6 +1047,8 @@ async def generate(
         >>> ...
         """
         try:
+            buffer: Optional[RequestOutput] = None  # buffer of output tokens
+            buffer_token_count = 0
             async for output in await self.add_request(
                 request_id,
                 prompt,
@@ -1054,7 +1058,19 @@ async def generate(
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
             ):
-                yield LLMEngine.validate_output(output, RequestOutput)
+                output = LLMEngine.validate_output(output, RequestOutput)
+                if buffer is None:
+                    buffer = output
+                else:
+                    buffer.add(output, aggregate=True)
+
+                buffer_token_count += sum(
+                    len(o.token_ids) for o in output.outputs)
+                if buffer_token_count >= streaming_params.stream_n:
+                    yield buffer
+                    buffer = None
+                    buffer_token_count = 0
+
         except asyncio.CancelledError:
             await self.abort(request_id)
             raise
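With the buffering above, callers of the updated AsyncLLMEngine.generate receive an aggregated RequestOutput only once at least stream_n new tokens have accumulated. A rough usage sketch follows; the engine arguments, model name, prompt, and request id are illustrative placeholders, not part of this commit.

# Rough usage sketch of the modified generate() signature; model name and
# engine setup are placeholders chosen for illustration only.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.streaming_params import StreamingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    sampling_params = SamplingParams(max_tokens=64)
    # Emit a chunk only after at least 4 new tokens have been buffered.
    streaming_params = StreamingParams(stream_n=4)

    async for output in engine.generate(prompt="Hello, my name is",
                                        sampling_params=sampling_params,
                                        streaming_params=streaming_params,
                                        request_id="demo-0"):
        print(len(output.outputs[0].token_ids), output.outputs[0].text)


asyncio.run(main())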

vllm/engine/multiprocessing/client.py

Lines changed: 28 additions & 3 deletions
@@ -46,6 +46,7 @@
 from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.streaming_params import StreamingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import Device, deprecate_kwargs

@@ -445,6 +446,7 @@ def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
+        streaming_params: StreamingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -460,6 +462,7 @@ def generate(
         *,
         inputs: PromptType,
         sampling_params: SamplingParams,
+        streaming_params: StreamingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -476,6 +479,7 @@ def generate(
         self,
         prompt: Optional[PromptType] = None,
         sampling_params: Optional[SamplingParams] = None,
+        streaming_params: Optional[StreamingParams] = None,
         request_id: Optional[str] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -509,8 +513,9 @@ def generate(
                 and request_id is not None)

         return self._process_request(prompt, sampling_params, request_id,
-                                     lora_request, trace_headers,
-                                     prompt_adapter_request, priority)
+                                     streaming_params, lora_request,
+                                     trace_headers, prompt_adapter_request,
+                                     priority)

     @overload
     def encode(
@@ -590,6 +595,7 @@ async def _process_request(
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
         request_id: str,
+        streaming_params: Optional[StreamingParams] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -660,14 +666,33 @@ async def _process_request(
         # queue after pulling them from the zmq socket.
         finished = False
         try:
+            buffer = None  # buffer of output tokens
+            buffer_token_count = 0
             while not finished:
                 request_output = await queue.get()

                 if isinstance(request_output, BaseException):
                     raise request_output

                 finished = request_output.finished
-                yield request_output
+                if buffer is None:
+                    buffer = request_output
+                else:
+                    buffer.add(request_output, aggregate=True)
+
+                if isinstance(request_output, RequestOutput):
+                    buffer_token_count += sum(
+                        len(o.token_ids) for o in request_output.outputs)
+                else:
+                    buffer_token_count += 1
+                if streaming_params is None or \
+                        buffer_token_count >= streaming_params.stream_n or \
+                        finished:
+
+                    yield buffer
+                    buffer = None
+                    buffer_token_count = 0
+
         finally:
             # Request was canceled by the client.
             if not finished and not self.errored:
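The client-side buffering above flushes on three conditions: no streaming_params was supplied, the buffered token count reached stream_n, or the request finished. The standalone sketch below replays that flush decision with plain integers standing in for per-step token counts; flush_points is a hypothetical helper written for illustration of the control flow only, not code from this commit.

# Illustration of the flush conditions used in the client buffering above;
# per-step token counts are plain ints here, not real RequestOutputs.
from typing import Iterator, List, Optional


def flush_points(step_token_counts: List[int],
                 stream_n: Optional[int]) -> Iterator[List[int]]:
    """Yield the groups of per-step token counts that would be emitted together."""
    buffer: List[int] = []
    buffer_token_count = 0
    for i, count in enumerate(step_token_counts):
        finished = (i == len(step_token_counts) - 1)
        buffer.append(count)
        buffer_token_count += count
        if stream_n is None or buffer_token_count >= stream_n or finished:
            yield buffer
            buffer = []
            buffer_token_count = 0


# With stream_n=3 and one token per engine step, chunks of 3 tokens are emitted,
# plus a final (possibly smaller) chunk when the request finishes.
print(list(flush_points([1] * 7, stream_n=3)))     # [[1, 1, 1], [1, 1, 1], [1]]
print(list(flush_points([1] * 7, stream_n=None)))  # every step flushed immediately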

vllm/engine/protocol.py

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
+from vllm.streaming_params import StreamingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import Device, collect_from_async_generator, random_uuid

@@ -51,6 +52,7 @@ def generate(
         self,
         prompt: PromptType,
         sampling_params: SamplingParams,
+        streaming_params: StreamingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
@@ -126,7 +128,7 @@ async def beam_search(
             task = asyncio.create_task(
                 collect_from_async_generator(
                     self.generate(individual_prompt, beam_search_params,
-                                  request_id_item)))
+                                  StreamingParams(), request_id_item)))
             tasks.append(task)

         output = await asyncio.gather(*tasks)

vllm/entrypoints/openai/protocol.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,7 @@
 from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                   RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
+from vllm.streaming_params import StreamingParams
 from vllm.utils import random_uuid, resolve_obj_by_qualname

 logger = init_logger(__name__)
@@ -151,6 +152,7 @@ class ResponseFormat(OpenAIBaseModel):
 class StreamOptions(OpenAIBaseModel):
     include_usage: Optional[bool] = True
     continuous_usage_stats: Optional[bool] = False
+    stream_n: Optional[int] = 1


 class FunctionDefinition(OpenAIBaseModel):
@@ -540,6 +542,13 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias)

+    def to_streaming_params(self, ) -> StreamingParams:
+        stream_n = None
+        if self.stream_options is not None and \
+                self.stream_options.stream_n is not None:
+            stream_n = self.stream_options.stream_n
+        return StreamingParams(stream_n=stream_n)
+
     def _get_guided_json_from_tool(
             self) -> Optional[Union[str, dict, BaseModel]]:
         # user has chosen to not use any tool
@@ -973,6 +982,13 @@ def to_sampling_params(
             logit_bias=self.logit_bias,
             allowed_token_ids=self.allowed_token_ids)

+    def to_streaming_params(self, ) -> StreamingParams:
+        stream_n = None
+        if self.stream_options is not None and \
+                self.stream_options.stream_n is not None:
+            stream_n = self.stream_options.stream_n
+        return StreamingParams(stream_n=stream_n)
+
     @model_validator(mode="before")
     @classmethod
     def check_guided_decoding_count(cls, data):
@@ -1725,6 +1741,11 @@ def to_sampling_params(
             if self.stream \
             else RequestOutputKind.FINAL_ONLY)

+    def to_streaming_params(
+        self,
+    ) -> StreamingParams:  # stream_options not defined in transcription request
+        return StreamingParams(stream_n=None)
+
     @model_validator(mode="before")
     @classmethod
     def validate_stream_options(cls, data):
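With stream_n exposed on StreamOptions, an OpenAI-compatible client can ask the server to batch streamed chunks. A hedged example against a locally running vLLM server follows; the base URL, model name, and payload values are placeholders, not taken from this commit.

# Hypothetical request to a locally running vLLM OpenAI-compatible server;
# the base URL and model name are placeholders.
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "facebook/opt-125m",
        "messages": [{"role": "user", "content": "Write a haiku about rain."}],
        "stream": True,
        # New in this commit: ask the server to emit a streamed chunk only
        # after roughly every 4 generated tokens instead of after every token.
        "stream_options": {"include_usage": False, "stream_n": 4},
    },
    stream=True,
)

for line in response.iter_lines():
    if line:
        print(line.decode("utf-8"))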

vllm/entrypoints/openai/serving_chat.py

Lines changed: 3 additions & 0 deletions
@@ -221,6 +221,8 @@ async def create_chat_completion(
                     self.model_config.logits_processor_pattern,
                     self.default_sampling_params)

+                streaming_params = request.to_streaming_params()
+
                 self._log_inputs(request_id,
                                  request_prompts[i],
                                  params=sampling_params,
@@ -240,6 +242,7 @@ async def create_chat_completion(
                 generator = self.engine_client.generate(
                     engine_prompt,
                     sampling_params,
+                    streaming_params,
                     request_id,
                     lora_request=lora_request,
                     trace_headers=trace_headers,

vllm/entrypoints/openai/serving_completion.py

Lines changed: 2 additions & 0 deletions
@@ -142,6 +142,7 @@ async def create_completion(
                     self.default_sampling_params)

                 request_id_item = f"{request_id}-{i}"
+                streaming_params = request.to_streaming_params()

                 self._log_inputs(request_id_item,
                                  request_prompts[i],
@@ -162,6 +163,7 @@ async def create_completion(
                 generator = self.engine_client.generate(
                     engine_prompt,
                     sampling_params,
+                    streaming_params,
                     request_id_item,
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,

vllm/entrypoints/openai/serving_transcription.py

Lines changed: 2 additions & 0 deletions
@@ -282,6 +282,7 @@ async def create_transcription(
             default_max_tokens = self.model_config.max_model_len
             sampling_params = request.to_sampling_params(
                 default_max_tokens, self.default_sampling_params)
+            streaming_params = request.to_streaming_params()

             self._log_inputs(
                 request_id,
@@ -293,6 +294,7 @@ async def create_transcription(
             result_generator = self.engine_client.generate(
                 prompt,
                 sampling_params,
+                streaming_params,
                 request_id,
             )
         except ValueError as e:
