[Feature] limit thinking tokens #20859

Draft: wants to merge 7 commits into main
25 changes: 25 additions & 0 deletions vllm/config.py
@@ -4405,6 +4405,29 @@ def set_splitting_ops_for_v1(self):
]


class ReasoningConfig:
"""Configuration for reasoning models."""

think_start_str: Optional[str] = None
"""String that indicates the start of reasoning."""
think_end_str: Optional[str] = None
"""String that indicates the end of reasoning."""
    think_start_token_ids: Optional[list[int]] = None
    """Token IDs that indicate the start of reasoning."""
    think_end_token_ids: Optional[list[int]] = None
    """Token IDs that indicate the end of reasoning."""

def __init__(self,
think_start_str: Optional[str] = None,
think_end_str: Optional[str] = None,
                 think_start_token_ids: Optional[list[int]] = None,
                 think_end_token_ids: Optional[list[int]] = None):
self.think_start_str = think_start_str
self.think_end_str = think_end_str
self.think_start_token_ids = think_start_token_ids
self.think_end_token_ids = think_end_token_ids


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
@@ -4461,6 +4484,8 @@ class VllmConfig:
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
reasoning_config: Optional[ReasoningConfig] = None
Collaborator

ditto

"""The configurations for reasoning model."""
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
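
For context, a minimal sketch of how this new config might be instantiated for a model whose reasoning section is delimited by <think> ... </think>. The token IDs below are placeholders rather than values from any real tokenizer, and the list form follows how the logits processor later in this PR consumes them.

    from vllm.config import ReasoningConfig

    # Placeholder values; look up the real IDs with the model's tokenizer.
    reasoning_config = ReasoningConfig(
        think_start_str="<think>",
        think_end_str="</think>",
        think_start_token_ids=[151667],
        think_end_token_ids=[151668],
    )
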
14 changes: 11 additions & 3 deletions vllm/engine/arg_utils.py
@@ -31,9 +31,9 @@
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, TokenizerPoolConfig,
VllmConfig, get_attr_docs, get_field)
ReasoningConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -464,6 +464,8 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
kv_events_config: Optional[KVEventsConfig] = None

reasoning_config: Optional[ReasoningConfig] = None

generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = \
@@ -934,6 +936,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--reasoning-config",
**vllm_kwargs["reasoning_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])

@@ -1332,6 +1336,9 @@ def create_engine_config(
collect_detailed_traces=self.collect_detailed_traces,
)

reasoning_config_dict = json.loads(self.reasoning_config)
reasoning_config = ReasoningConfig(**reasoning_config_dict)

config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@@ -1347,6 +1354,7 @@
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
reasoning_config=reasoning_config,
additional_config=self.additional_config,
)

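Based on the json.loads call above, the --reasoning-config CLI flag is evidently expected to carry a JSON object whose keys match ReasoningConfig's fields. A minimal sketch of that round trip; the JSON payload and token IDs are illustrative assumptions, not values from this PR:

    import json

    from vllm.config import ReasoningConfig

    # Roughly what create_engine_config does with the CLI value.
    raw = '{"think_start_token_ids": [151667], "think_end_token_ids": [151668]}'
    reasoning_config = ReasoningConfig(**json.loads(raw))
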
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
bad_words: list[str] = Field(default_factory=list)
max_think_tokens: Optional[int] = None
Collaborator

Can we introduce some heuristic based on reasoning_effort? I'm thinking:

  • low -> 1024
  • medium -> 2048
  • high -> 8192

Then we can also expose this as an additional extra_body field for users to override if they have a custom context length set on the vLLM server, as sketched below.
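
A sketch of that heuristic using the numbers from this comment; the function name, signature, and fallback behavior are assumptions, not anything in this PR:

    REASONING_EFFORT_BUDGETS = {"low": 1024, "medium": 2048, "high": 8192}

    def resolve_thinking_budget(reasoning_effort, thinking_tokens_budget):
        # An explicit budget passed via extra_body overrides the effort-based default.
        if thinking_tokens_budget is not None:
            return thinking_tokens_budget
        if reasoning_effort is not None:
            return REASONING_EFFORT_BUDGETS.get(reasoning_effort)
        return None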

Contributor Author
@llsj14 llsj14 Jul 14, 2025

Sounds reasonable. So the user would only provide "reasoning_effort" (low, medium, high) as the sampling parameter? What I'm a bit concerned about is that this is hard to control at the token level, and the mapping is only configurable when the server loads.

Collaborator
@aarnphm aarnphm Jul 14, 2025

reasoning_effort is mostly for the OpenAI-compatible endpoint. If users want more control, we then respect thinking_tokens_budget (or some similar name) in the request body instead of reasoning_effort.

Two scenarios:

  • Users who already use reasoning_effort from the OpenAI frontend: nothing changes for them.
  • If they want to increase the thinking budget, knowing that the model context length supports it:

    client.chat.completions.create(
        ...,
        reasoning_effort="medium",  # ignored here in favor of thinking_tokens_budget
        extra_body={"thinking_tokens_budget": 16384},
    )

Collaborator

Also, this should be factored into the max_tokens calculation as well.

# --8<-- [end:chat-completion-sampling-params]

# --8<-- [start:chat-completion-extra-params]
@@ -670,6 +671,7 @@ def to_sampling_params(
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
bad_words= self.bad_words,
max_think_tokens=self.max_think_tokens,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
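
Since max_think_tokens is not a standard OpenAI field, a client would pass it through extra_body. A hedged usage sketch; the server URL and model name are placeholders:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = client.chat.completions.create(
        model="my-reasoning-model",
        messages=[{"role": "user", "content": "Solve 23 * 17 step by step."}],
        extra_body={"max_think_tokens": 1024},  # field added in this PR
    )
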
7 changes: 7 additions & 0 deletions vllm/sampling_params.py
@@ -200,6 +200,7 @@ class SamplingParams(
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations, plugins, etc. Not used by any in-tree
sampling implementations.
        max_think_tokens: Maximum number of tokens allowed for thinking.
"""

n: int = 1
@@ -248,6 +249,9 @@
bad_words: Optional[list[str]] = None
_bad_words_token_ids: Optional[list[list[int]]] = None

# Maximum number of tokens allowed for thinking operations.
max_think_tokens: Optional[int] = None


Can we rename this to thinking_budget? That would help keep the naming consistent, since the max here refers to the thinking budget provided by the user.


@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -263,6 +267,7 @@ def from_optional(
stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[list[str]] = None,
max_think_tokens: Optional[int] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
@@ -306,6 +311,7 @@
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
max_think_tokens=max_think_tokens,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -574,6 +580,7 @@ def __repr__(self) -> str:
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
f"max_think_tokens={self.max_think_tokens}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
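
For offline use, the new field would be set directly on SamplingParams. A sketch, assuming the engine was started with the reasoning-token configuration from this PR so the processor knows the think-token IDs; the model name is a placeholder:

    from vllm import LLM, SamplingParams

    params = SamplingParams(max_tokens=512, max_think_tokens=256)
    llm = LLM(model="my-reasoning-model")
    outputs = llm.generate(["Why is the sky blue?"], params)
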
191 changes: 176 additions & 15 deletions vllm/v1/sample/logits_processor.py
@@ -6,12 +6,13 @@
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
from vllm.config import ReasoningConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -24,9 +25,10 @@ class MoveDirectionality(Enum):
SWAP = 1


# (index, params, output_tok_ids) tuples for new
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int],
list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@@ -43,9 +45,9 @@ class BatchUpdate:
# within the persistent batch.
#
# Note: each added request is represented as
# (index, params, output_tok_ids)
# Key assumption: output_tok_ids is a reference to the
# request's running output tokens list; in this way
# (index, params, prompt_tok_ids, output_tok_ids)
# Key assumption: prompt_tok_ids, output_tok_ids is a reference to the
# request's prompt and running output tokens list; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
@@ -260,7 +262,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

needs_update = False
# Process added requests.
for index, params, _ in batch_update.added:
for index, params, _, _ in batch_update.added:
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
if self.min_p_cpu[index] != min_p:
needs_update = True
@@ -337,7 +339,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

# Process added requests.
needs_update = bool(batch_update.added)
for index, params, _ in batch_update.added:
for index, params, _, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
self.biases[index] = lb
@@ -420,7 +422,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
# Process added requests.
needs_update |= bool(batch_update.added)
for index, params, output_tok_ids in batch_update.added:
for index, params, _, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
and len(output_tok_ids) < min_tokens):
@@ -493,8 +495,159 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
return logits


def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device: torch.device) -> LogitsProcessorManager:
class MaxThinkTokensLogitsProcessor(LogitsProcessor):
Contributor

It seems more appropriate to split this into separate files.

Contributor Author

I think it’s good to separate the files, but I’m just concerned about the divergence of different kinds of logits processors at the moment, since some are declared in the ops directory (e.g., bad words, penalties, top-k, top-p), while the built-in logits processors are declared in this logits_processor.py file.

Collaborator

You can probably create a logit_processors dir and put the different logits processors there.

The default ones can just live under logit_processors/__init__.py, and the others can each have their own file.

Contributor Author

Good, I’ll update it.
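
For reference, the layout being suggested might look roughly like this; the directory and file names are illustrative and not part of this PR:

    vllm/v1/sample/logit_processors/
        __init__.py          # built-in processors (min_tokens, logit_bias, min_p, ...)
        max_think_tokens.py  # MaxThinkTokensLogitsProcessor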

"""A logits processor that limits the maximum number of thinking tokens."""

def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
device: torch.device):
"""
Args:
reasoning_config: Configuration for reasoning, which includes
the token IDs for thinking start and end.
pin_memory (bool): Whether to use pinned memory for tensors.
device (torch.device): Device to use for tensor operations.
"""
super().__init__()
self.think_start_token_id = reasoning_config.think_start_token_id
self.think_end_token_id = reasoning_config.think_end_token_id
self.pin_memory = pin_memory
self.device = device
self._state: dict[int, dict[str, Any]] = {}

def _find_first_token_index(self, target_list: list[int], token_id: int) -> int:
"""
Find the last occurrence of a single token in the list of tokens.

Args:
target_list (list[int]): The list of token IDs.
token_id (int): The token ID to find.
"""
try:
return len(target_list) - target_list[::-1].index(token_id) - 1
except ValueError:
return -1

def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int:
"""
Find the last occurrence of the sequence of token_ids in tokens.

Args:
target_list (list[int]): The list of token IDs.
token_ids (list[int]): The sequence of token IDs to find.
"""
index = self._find_first_token_index(target_list, token_ids[0])
if index != -1:
i = 1
for token_id in token_ids[1:]:
if index + i >= len(target_list) or target_list[index + i] != token_id:
return -1
i += 1
index += 1

return index

def _init_state_entry(self, prompt_tok_ids, max_think_tokens):
last_start = self._find_last_sequence_index(
prompt_tok_ids, self.think_start_token_id)
last_end = self._find_last_sequence_index(
prompt_tok_ids, self.think_end_token_id)
in_think = last_start > last_end
think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0

return {
"in_think": in_think,
"in_end": False,
"think_count": think_count,
"end_count": 0,
"prompt_tok_ids": prompt_tok_ids,
"output_tok_ids": [],
"max_think_tokens": max_think_tokens,
}

def _update_think_state(self, state):
output = state["output_tok_ids"]
if not output:
return

sliced_output1 = output[-1 + len(self.think_start_token_id):]
sliced_output2 = output[-1 + len(self.think_end_token_id):]

if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1:
state["in_think"] = True
state["think_count"] = 0
elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1:
state["in_think"] = False
state["think_count"] = 0
else:
state["think_count"] += 1

if state["in_end"]:
state["end_count"] += 1
if state["end_count"] >= len(self.think_end_token_id):
state["in_end"] = False
state["end_count"] = 0
else:
if state["in_think"] and state["think_count"] >= state["max_think_tokens"]:
state["in_think"] = False
state["in_end"] = True
state["end_count"] = 0

def is_argmax_invariant(self) -> bool:
"""This logits processor can change the outcome of
greedy sampling by forcing that the thinking section
ends after a certain number of tokens."""
return False

def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added:
max_think_tokens = (params.max_think_tokens if isinstance(
params, SamplingParams) else None)
if max_think_tokens is not None:
self._state[index] = self._init_state_entry(
prompt_tok_ids, max_think_tokens)
self._state[index]["output_tok_ids"] = output_tok_ids

for index in batch_update.removed:
self._state.pop(index, {})

for i1, i2, direction in batch_update.moved:
if direction == MoveDirectionality.SWAP:
self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
else:
self._state[i2] = self._state.pop(i1, {})

for state in self._state.values():
self._update_think_state(state)

def apply(self, logits: torch.Tensor) -> torch.Tensor:
batch_size = logits.size(0)
if not self._state:
return logits

mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
end_token_id = self.think_end_token_id

for index in range(batch_size):
state = self._state.get(index)
if not state:
continue

force_end_token_id = None
if state["in_end"]:
force_end_token_id = self.think_end_token_id[state["end_count"]]
mask[index] = True

if mask.any():
logits[mask] = -float("inf")
logits[mask, force_end_token_id] = 0.0

return logits


def init_builtin_logitsprocs(
pin_memory_available: bool, max_num_reqs: int, device: torch.device,
reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.

@@ -516,10 +669,18 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)

non_argmax_invariant = [min_tokens_logitproc, logit_bias_logitproc]

if reasoning_config is not None:
max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
reasoning_config=reasoning_config,
pin_memory=pin_memory_available,
device=device,
)
non_argmax_invariant.append(max_think_tokens_logitproc)

return LogitsProcessorManager(
non_argmax_invariant=[
min_tokens_logitproc,
logit_bias_logitproc,
],
non_argmax_invariant=non_argmax_invariant,
argmax_invariant=[min_p_logitproc],
)
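
To make the forcing mechanism in apply() concrete, here is a tiny standalone sketch, independent of vLLM's batching machinery, of the same idea: once the thinking budget is exhausted, each subsequent step masks a row's logits so the next token of the end-of-think sequence is the only possible choice, until the whole sequence has been emitted. The token IDs and vocabulary size are made up.

    import torch

    think_end_token_ids = [42, 43]  # e.g. the tokens that encode "</think>"
    vocab_size = 100

    def force_token(logits_row: torch.Tensor, token_id: int) -> torch.Tensor:
        # Same trick as apply(): -inf everywhere except the forced token.
        forced = torch.full_like(logits_row, float("-inf"))
        forced[token_id] = 0.0
        return forced

    for tok in think_end_token_ids:
        logits_row = force_token(torch.randn(vocab_size), tok)
        assert int(torch.argmax(logits_row)) == tok  # greedy decoding now emits `tok`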