Commit a4113b0

[Platform] Add custom default max tokens (#18557)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
1 parent 7e1665b commit a4113b0

5 files changed: +59 additions, -60 deletions

vllm/entrypoints/openai/protocol.py

Lines changed: 10 additions & 49 deletions
@@ -229,7 +229,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     logit_bias: Optional[dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
-    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
         deprecated=
@@ -433,23 +432,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
         }

     def to_beam_search_params(
-        self,
-        default_max_tokens: int,
-        default_sampling_params: Optional[dict] = None
-    ) -> BeamSearchParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
+            self, max_tokens: int,
+            default_sampling_params: dict) -> BeamSearchParams:

-        if default_sampling_params is None:
-            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
@@ -465,21 +451,10 @@ def to_beam_search_params(

     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
-        default_sampling_params: Optional[dict] = None,
+        default_sampling_params: dict,
     ) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
-        max_tokens = self.max_completion_tokens or self.max_tokens
-
-        if default_sampling_params is None:
-            default_sampling_params = {}
-
-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)

         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
@@ -898,22 +873,15 @@ class CompletionRequest(OpenAIBaseModel):
         }

     def to_beam_search_params(
-        self,
-        default_max_tokens: int,
-        default_sampling_params: Optional[dict] = None
+        self,
+        max_tokens: int,
+        default_sampling_params: Optional[dict] = None,
     ) -> BeamSearchParams:
-        max_tokens = self.max_tokens

         if default_sampling_params is None:
             default_sampling_params = {}
         n = self.n if self.n is not None else 1

-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get("temperature", 1.0)

@@ -928,21 +896,14 @@ def to_beam_search_params(

     def to_sampling_params(
         self,
-        default_max_tokens: int,
+        max_tokens: int,
         logits_processor_pattern: Optional[str],
         default_sampling_params: Optional[dict] = None,
     ) -> SamplingParams:
-        max_tokens = self.max_tokens

         if default_sampling_params is None:
             default_sampling_params = {}

-        # Use minimum of context window, user request & server limit.
-        max_tokens = min(
-            val for val in (default_max_tokens, max_tokens,
-                            default_sampling_params.get("max_tokens", None))
-            if val is not None)
-
         # Default parameters
         if (repetition_penalty := self.repetition_penalty) is None:
             repetition_penalty = default_sampling_params.get(
@@ -1813,7 +1774,7 @@ def to_sampling_params(
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+
         max_tokens = default_max_tokens

         if default_sampling_params is None:
@@ -2029,7 +1990,7 @@ def to_sampling_params(
             self,
             default_max_tokens: int,
             default_sampling_params: Optional[dict] = None) -> SamplingParams:
-        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+
        max_tokens = default_max_tokens

         if default_sampling_params is None:
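
Note: with this change, ChatCompletionRequest.to_beam_search_params and to_sampling_params no longer derive the token budget from default_max_tokens internally; the caller passes a precomputed max_tokens and a plain dict for default_sampling_params. A minimal sketch of the new call pattern, assuming a vLLM checkout at this commit (the model name, message, and the literal 512 budget are placeholders):

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Hello"}],
)

# The serving layer now computes the effective budget up front
# (see get_max_tokens in vllm/entrypoints/utils.py below) and passes it in.
max_tokens = 512

params = request.to_sampling_params(
    max_tokens,  # precomputed cap; no longer derived inside the method
    None,        # logits_processor_pattern
    {},          # default_sampling_params is now a required dict
)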

vllm/entrypoints/openai/serving_chat.py

Lines changed: 13 additions & 5 deletions
@@ -34,6 +34,7 @@
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -233,15 +234,22 @@ async def create_chat_completion(
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
-                default_max_tokens = self.max_model_len - len(
-                    engine_prompt["prompt_token_ids"])
+
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=len(engine_prompt["prompt_token_ids"]),
+                    default_sampling_params=self.default_sampling_params)
+
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)

                 self._log_inputs(request_id,

vllm/entrypoints/openai/serving_completion.py

Lines changed: 12 additions & 4 deletions
@@ -33,6 +33,7 @@
     is_text_tokens_prompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
 from vllm.logger import init_logger
@@ -160,15 +161,22 @@ async def create_completion(
                     input_length = len(engine_prompt["prompt_token_ids"])
                 else:
                     assert_never(engine_prompt)
-                default_max_tokens = self.max_model_len - input_length
+
+                if self.default_sampling_params is None:
+                    self.default_sampling_params = {}
+
+                max_tokens = get_max_tokens(
+                    max_model_len=self.max_model_len,
+                    request=request,
+                    input_length=input_length,
+                    default_sampling_params=self.default_sampling_params)

                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens, self.default_sampling_params)
+                        max_tokens, self.default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
-                        default_max_tokens,
-                        self.model_config.logits_processor_pattern,
+                        max_tokens, self.model_config.logits_processor_pattern,
                         self.default_sampling_params)

                 request_id_item = f"{request_id}-{i}"

vllm/entrypoints/utils.py

Lines changed: 20 additions & 2 deletions
@@ -5,13 +5,17 @@
 import asyncio
 import functools
 import os
-from typing import Any, Optional
+import sys
+from typing import Any, Optional, Union

 from fastapi import Request
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.background import BackgroundTask, BackgroundTasks

+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest)
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -181,7 +185,6 @@ def _validate_truncation_size(

 def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
                                               subcommand_name: list[str]):
-    import sys

     # Only handle --help=<keyword> for the current subcommand.
     # Since subparser_init() runs for all subcommands during CLI setup,
@@ -242,3 +245,18 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser,
             print(f"\nNo group or parameter matching '{search_keyword}'")
             print("Tip: use `--help=listgroup` to view all groups.")
             sys.exit(1)
+
+
+def get_max_tokens(max_model_len: int, request: Union[ChatCompletionRequest,
+                                                      CompletionRequest],
+                   input_length: int, default_sampling_params: dict) -> int:
+
+    max_tokens = getattr(request, "max_completion_tokens",
+                         None) or request.max_tokens
+    default_max_tokens = max_model_len - input_length
+    max_output_tokens = current_platform.get_max_output_tokens(input_length)
+
+    return min(val
+               for val in (default_max_tokens, max_tokens, max_output_tokens,
+                           default_sampling_params.get("max_tokens"))
+               if val is not None)
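
The new get_max_tokens helper picks the smallest applicable limit among the room left in the context window, the user's requested max_tokens (or max_completion_tokens), the platform's output cap, and any server-wide default, skipping limits that are None. A standalone sketch of that selection with made-up numbers (the simplified function and values are illustrative, not part of the diff):

from typing import Optional


def pick_max_tokens(max_model_len: int, input_length: int,
                    requested: Optional[int], platform_cap: Optional[int],
                    server_default: Optional[int]) -> int:
    # Room left in the context window after the prompt.
    default_max_tokens = max_model_len - input_length
    # Same min-over-non-None selection as get_max_tokens above.
    return min(val
               for val in (default_max_tokens, requested, platform_cap,
                           server_default)
               if val is not None)


# 4096-token window, 1000-token prompt, user asks for 8192,
# platform caps output at 2048, no server default -> 2048 wins.
print(pick_max_tokens(4096, 1000, 8192, 2048, None))  # 2048
# With no user request or platform cap, the remaining context wins.
print(pick_max_tokens(4096, 1000, None, None, None))  # 3096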

vllm/platforms/interface.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,7 @@
 import os
 import platform
 import random
+import sys
 from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union
@@ -164,6 +165,9 @@ def is_neuron(self) -> bool:
     def is_out_of_tree(self) -> bool:
         return self._enum == PlatformEnum.OOT

+    def get_max_output_tokens(self, prompt_len: int) -> int:
+        return sys.maxsize
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of [torch.cuda.is_available][]."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
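
Platform.get_max_output_tokens defaults to sys.maxsize, so existing platforms are unconstrained; a platform can override it to impose its own ceiling, which get_max_tokens then folds into the minimum above. A hedged sketch of such an override for a hypothetical out-of-tree platform (the class, its attributes, and the 2048/4096 limits are invented for illustration):

from vllm.platforms.interface import Platform, PlatformEnum


class MyAcceleratorPlatform(Platform):
    _enum = PlatformEnum.OOT
    device_name = "my_accelerator"
    device_type = "my_accelerator"

    def get_max_output_tokens(self, prompt_len: int) -> int:
        # Example policy: never emit more than 2048 tokens and shrink
        # the budget further as the prompt approaches 4096 tokens.
        return max(1, min(2048, 4096 - prompt_len))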
