
Commit 039f7e3

[https://nvbugspro.nvidia.com/bug/5243740][fix] deduce default max_tokens for trtllm-serve (#4265)
* Deduce default max_tokens for trtllm-serve
* Improve executor_config.max_seq_len assignment in TRT workflow
* Enhance error message
* Add deduced max_tokens test

Signed-off-by: Pengyun Lin <81065165+LinPoly@users.noreply.github.com>
1 parent 0d7269e commit 039f7e3

File tree: 4 files changed (+49, -4 lines)


tensorrt_llm/executor/worker.py

Lines changed: 24 additions & 1 deletion
@@ -394,11 +394,34 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
         )

         assert request.id is not None
+
+        def _deduce_max_tokens(request: GenerationRequest,
+                               executor_config: tllm.ExecutorConfig) -> int:
+            if request.sampling_params.max_tokens:
+                return request.sampling_params.max_tokens
+            # deduce max_tokens when it's not set by user
+            query_token_len = len(
+                request.query_token_ids) if request.query_token_ids else 0
+            cp_size = 1 if (not hasattr(executor_config, "mapping")
+                            or executor_config.mapping.cp_size
+                            is None) else executor_config.mapping.cp_size
+            if not hasattr(executor_config, "max_seq_len"):
+                raise RuntimeError(
+                    "max_tokens for sampling is not set and cannot be deduced")
+            splited_prompt_len = int(len(prompt_token_ids) / cp_size)
+            default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
+            if default_max_tokens < 0:
+                raise ValueError(
+                    f"Deduced max_tokens {default_max_tokens} is less than 0, because "
+                    f"prompt length {splited_prompt_len} plus query length {query_token_len} "
+                    f"is larger than max_seq_len {executor_config.max_seq_len}")
+            return default_max_tokens
+
         try:
             executor_request = tllm.Request(
                 client_id=request.id,
                 input_token_ids=prompt_token_ids,
-                max_tokens=request.sampling_params.max_tokens,
+                max_tokens=_deduce_max_tokens(request, self._executor_config),
                 streaming=request.streaming,
                 sampling_config=request.sampling_params._get_sampling_config(),
                 end_id=-1 if request.sampling_params.ignore_eos else

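When the client leaves max_tokens unset, the worker now derives it from executor_config.max_seq_len minus the context-parallel share of the prompt and the query length. A minimal standalone sketch of that arithmetic, assuming plain lists and integers rather than the real GenerationRequest/ExecutorConfig objects (the function name and values below are illustrative only):

def deduce_default_max_tokens(prompt_token_ids, query_token_ids, max_seq_len, cp_size=1):
    # Mirrors the deduction above: budget = max_seq_len - prompt/cp_size - query.
    query_token_len = len(query_token_ids) if query_token_ids else 0
    split_prompt_len = int(len(prompt_token_ids) / cp_size)
    default_max_tokens = max_seq_len - split_prompt_len - query_token_len
    if default_max_tokens < 0:
        raise ValueError(
            f"Deduced max_tokens {default_max_tokens} is less than 0: prompt "
            f"({split_prompt_len}) plus query ({query_token_len}) exceeds "
            f"max_seq_len ({max_seq_len})")
    return default_max_tokens

# A 100-token prompt with no query tokens and max_seq_len=2048 leaves 1948 tokens for generation.
print(deduce_default_max_tokens(list(range(100)), None, max_seq_len=2048))  # 1948
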
tensorrt_llm/llmapi/llm.py

Lines changed: 9 additions & 1 deletion
@@ -486,7 +486,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,

         if (not self.args.enable_chunked_prefill) and (
                 prompt_len / self.args.parallel_config.cp_size + query_len +
-                sampling_params.max_tokens > max_seq_len):
+                (sampling_params.max_tokens or 0) > max_seq_len):
             raise ValueError(
                 f"The sum of prompt length ({prompt_len/self.args.parallel_config.cp_size}) and query length ({query_len}) max_tokens ({sampling_params.max_tokens}) should not exceed "
                 f"max_seq_len ({build_config.max_seq_len})")

@@ -542,6 +542,14 @@ def _build_model(self):
             max_batch_size=max_batch_size,
             max_num_tokens=max_num_tokens,
             gather_generation_logits=self.args.gather_generation_logits)
+        if self.args.backend is None:
+            # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
+            if max_seq_len is not None:
+                executor_config.max_seq_len = max_seq_len
+            else:
+                engine_config = EngineConfig.from_json_file(self._engine_dir /
+                                                            "config.json")
+                executor_config.max_seq_len = engine_config.build_config.max_seq_len
         if self.args.kv_cache_config is not None:
             executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
                 self.args.kv_cache_config)

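The guard in _check_arguments now treats an unset max_tokens as 0, so a request without it only needs its prompt and query to fit within max_seq_len; the _build_model change then populates executor_config.max_seq_len in the TRT workflow (from the build config, or from the engine's config.json) so the worker-side deduction has a value to work from. A rough sketch of the guard, assuming plain integers instead of the real self.args objects (check_fits is a hypothetical name):

def check_fits(prompt_len, query_len, max_tokens, cp_size, max_seq_len,
               enable_chunked_prefill=False):
    # An unset max_tokens counts as 0 here; the actual generation budget is deduced later.
    if (not enable_chunked_prefill) and (
            prompt_len / cp_size + query_len + (max_tokens or 0) > max_seq_len):
        raise ValueError(
            f"prompt ({prompt_len / cp_size}) + query ({query_len}) + "
            f"max_tokens ({max_tokens}) exceeds max_seq_len ({max_seq_len})")

check_fits(prompt_len=1000, query_len=0, max_tokens=None, cp_size=1, max_seq_len=2048)   # passes
# check_fits(prompt_len=1000, query_len=0, max_tokens=1500, cp_size=1, max_seq_len=2048)  # would raise
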
tensorrt_llm/serve/openai_protocol.py

Lines changed: 2 additions & 2 deletions
@@ -156,7 +156,7 @@ class CompletionRequest(OpenAIBaseModel):
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = Field(default=None)

@@ -426,7 +426,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
     top_logprobs: Optional[int] = 0
-    max_completion_tokens: int = Field(default=16,
+    max_completion_tokens: int = Field(default=None,
                                        validation_alias='max_tokens')
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0

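With both defaults changed to None, an OpenAI-compatible client can omit max_tokens (completions) or max_completion_tokens (chat) and let trtllm-serve deduce the budget from the engine's max_seq_len. A hypothetical client call against a locally running trtllm-serve instance (base URL, API key, and model name below are placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")
resp = client.chat.completions.create(
    model="my-served-model",  # placeholder: whatever model trtllm-serve was started with
    messages=[{"role": "user", "content": "what is 1+1?"}],
    temperature=0.0,
    # no max_tokens / max_completion_tokens: the server deduces the default
)
print(resp.choices[0].message.content)
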
tests/unittest/llmapi/apps/_test_openai_chat.py

Lines changed: 14 additions & 0 deletions
@@ -130,6 +130,18 @@ def test_single_chat_session(client: openai.OpenAI, model_name: str):
     )
     assert legacy.choices[0].message.content \
         == chat_completion.choices[0].message.content
+    # test deduced max_tokens
+    chat_completion = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        temperature=0.0,
+        logprobs=False,
+    )
+    assert chat_completion.id is not None
+    assert len(chat_completion.choices) == 1
+    message = chat_completion.choices[0].message
+    assert message.content is not None
+    assert message.role == "assistant"


 def test_single_chat_session_with_logprobs(client: openai.OpenAI,

@@ -458,6 +470,7 @@ def test_custom_role(client: openai.OpenAI, model_name: str):
             "content": "what is 1+1?",
         }],  # type: ignore
         temperature=0.0,
+        max_completion_tokens=16,
         seed=0)

     resp2 = client.chat.completions.create(

@@ -470,6 +483,7 @@ def test_custom_role(client: openai.OpenAI, model_name: str):
             }]
         }],  # type: ignore
         temperature=0.0,
+        max_completion_tokens=16,
         seed=0)

     content1 = resp1.choices[0].message.content
