[Feature] limit thinking tokens #20859
base: main
Changes from all commits
f4274ab
84aee5b
f2e195a
4c4251d
3984711
5636a12
6b424ad
@@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
bad_words: list[str] = Field(default_factory=list)
max_think_tokens: Optional[int] = None
Review discussion on max_think_tokens:
- Can we introduce some heuristic with … Then we can also have this as additional …
- Sounds reasonable. So the user should only provide …
- Two scenarios: …
- also this should be included in the …
# --8<-- [end:chat-completion-sampling-params]

# --8<-- [start:chat-completion-extra-params]
@@ -670,6 +671,7 @@ def to_sampling_params(
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
bad_words=self.bad_words,
max_think_tokens=self.max_think_tokens,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
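For context, a minimal client-side sketch of how this field could be exercised once merged, assuming a locally running vLLM OpenAI-compatible server and an example model name; max_think_tokens is passed via extra_body because the stock OpenAI client does not know the new field:

# Hypothetical usage sketch (not part of this diff).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="Qwen/Qwen3-8B",                # assumed reasoning-capable model
    messages=[{"role": "user", "content": "Solve 23 * 47 step by step."}],
    extra_body={"max_think_tokens": 256},  # cap the thinking section
)
print(resp.choices[0].message.content)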
@@ -200,6 +200,7 @@ class SamplingParams(
extra_args: Arbitrary additional args, that can be used by custom
    sampling implementations, plugins, etc. Not used by any in-tree
    sampling implementations.
max_think_tokens: Maximum number of tokens allowed for thinking.
"""

n: int = 1
@@ -248,6 +249,9 @@ class SamplingParams(
bad_words: Optional[list[str]] = None
_bad_words_token_ids: Optional[list[list[int]]] = None

# Maximum number of tokens allowed for thinking operations.
max_think_tokens: Optional[int] = None
Review comment: Can we rename this to thinking_budget? It would keep the naming consistent, since the maximum here refers to the thinking budget provided by the user.
@staticmethod
def from_optional(
    n: Optional[int] = 1,
@@ -263,6 +267,7 @@ def from_optional(
stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[list[str]] = None,
max_think_tokens: Optional[int] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
@@ -306,6 +311,7 @@ def from_optional(
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
max_think_tokens=max_think_tokens,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -574,6 +580,7 @@ def __repr__(self) -> str:
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
f"max_think_tokens={self.max_think_tokens}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
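For the offline API, a minimal sketch of the same parameter, assuming the engine is launched with a reasoning parser so the think start/end token IDs are known; the model name is only an example:

# Hypothetical offline usage sketch (not part of this diff).
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-8B")      # assumed reasoning-capable model
params = SamplingParams(
    max_tokens=1024,
    max_think_tokens=256,             # new field added by this PR
)
outputs = llm.generate(["Explain why the sky is blue."], params)
print(outputs[0].outputs[0].text)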
@@ -6,12 +6,13 @@
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
from vllm.config import ReasoningConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -24,9 +25,10 @@ class MoveDirectionality(Enum):
SWAP = 1


# (index, params, output_tok_ids) tuples for new
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int],
                     list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@@ -43,9 +45,9 @@ class BatchUpdate:
# within the persistent batch.
#
# Note: each added request is represented as
# (index, params, output_tok_ids)
# Key assumption: output_tok_ids is a reference to the
# request's running output tokens list; in this way
# (index, params, prompt_tok_ids, output_tok_ids)
# Key assumption: prompt_tok_ids and output_tok_ids are references to the
# request's prompt and running output token lists; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
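To make the widened AddedRequest tuple concrete, a purely illustrative sketch (index, token IDs, and budget below are made up):

# Illustrative only: what one entry of batch_update.added now carries.
from vllm import SamplingParams

prompt_tok_ids = [151644, 872, 198]   # made-up prompt token IDs
output_tok_ids: list[int] = []        # shared list, grows as tokens are sampled

added_entry = (0,                                    # index in the persistent batch
               SamplingParams(max_think_tokens=256), # new field from this PR
               prompt_tok_ids,                       # new third element
               output_tok_ids)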
@@ -260,7 +262,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

needs_update = False
# Process added requests.
for index, params, _ in batch_update.added:
for index, params, _, _ in batch_update.added:
    min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
    if self.min_p_cpu[index] != min_p:
        needs_update = True
@@ -337,7 +339,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

# Process added requests.
needs_update = bool(batch_update.added)
for index, params, _ in batch_update.added:
for index, params, _, _ in batch_update.added:
    if isinstance(params, SamplingParams) and (lb := params.logit_bias):
        self.biases[index] = lb
@@ -420,7 +422,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
    # Process added requests.
    needs_update |= bool(batch_update.added)
    for index, params, output_tok_ids in batch_update.added:
    for index, params, _, output_tok_ids in batch_update.added:
        if (isinstance(params, SamplingParams)
                and (min_tokens := params.min_tokens)
                and len(output_tok_ids) < min_tokens):
@@ -493,8 +495,159 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
return logits


def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
                             device: torch.device) -> LogitsProcessorManager:
class MaxThinkTokensLogitsProcessor(LogitsProcessor):
Review discussion on where this processor should live:
- It seems more appropriate to split this into separate files.
- I think it's good to separate the files, but I'm just concerned about the divergence of different kinds of logits processors at the moment, since some are declared in the …
- You can probably create a logit_processors dir, then put the different logits processors there. The default ones can just live under …
- Good, I'll update it.
    """A logits processor that limits the maximum number of thinking tokens."""

    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
                 device: torch.device):
        """
        Args:
            reasoning_config: Configuration for reasoning, which includes
                the token IDs for thinking start and end.
            pin_memory (bool): Whether to use pinned memory for tensors.
            device (torch.device): Device to use for tensor operations.
        """
        super().__init__()
        self.think_start_token_id = reasoning_config.think_start_token_id
        self.think_end_token_id = reasoning_config.think_end_token_id
        self.pin_memory = pin_memory
        self.device = device
        self._state: dict[int, dict[str, Any]] = {}
    def _find_first_token_index(self, target_list: list[int], token_id: int) -> int:
        """
        Find the last occurrence of a single token in the list of tokens.

        Args:
            target_list (list[int]): The list of token IDs.
            token_id (int): The token ID to find.
        """
        try:
            return len(target_list) - target_list[::-1].index(token_id) - 1
        except ValueError:
            return -1
    def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int:
        """
        Find the last occurrence of the sequence of token_ids in tokens.

        Args:
            target_list (list[int]): The list of token IDs.
            token_ids (list[int]): The sequence of token IDs to find.
        """
        index = self._find_first_token_index(target_list, token_ids[0])
        if index != -1:
            i = 1
            for token_id in token_ids[1:]:
                if index + i >= len(target_list) or target_list[index + i] != token_id:
                    return -1
                i += 1
                index += 1

        return index
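For intuition, a tiny standalone illustration of what the reversed-index search in _find_first_token_index returns (token values are made up):

# The reversed search returns the index of the LAST occurrence, or -1 if absent.
tokens = [5, 9, 7, 9, 3]
assert len(tokens) - tokens[::-1].index(9) - 1 == 3   # last 9 sits at index 3
assert 4 not in tokens                                # .index(4) would raise ValueError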
    def _init_state_entry(self, prompt_tok_ids, max_think_tokens):
        last_start = self._find_last_sequence_index(
            prompt_tok_ids, self.think_start_token_id)
        last_end = self._find_last_sequence_index(
            prompt_tok_ids, self.think_end_token_id)
        in_think = last_start > last_end
        think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0

        return {
            "in_think": in_think,
            "in_end": False,
            "think_count": think_count,
            "end_count": 0,
            "prompt_tok_ids": prompt_tok_ids,
            "output_tok_ids": [],
            "max_think_tokens": max_think_tokens,
        }
    def _update_think_state(self, state):
        output = state["output_tok_ids"]
        if not output:
            return

        sliced_output1 = output[-1 + len(self.think_start_token_id):]
        sliced_output2 = output[-1 + len(self.think_end_token_id):]

        if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1:
            state["in_think"] = True
            state["think_count"] = 0
        elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1:
            state["in_think"] = False
            state["think_count"] = 0
        else:
            state["think_count"] += 1

        if state["in_end"]:
            state["end_count"] += 1
            if state["end_count"] >= len(self.think_end_token_id):
                state["in_end"] = False
                state["end_count"] = 0
        else:
            if state["in_think"] and state["think_count"] >= state["max_think_tokens"]:
                state["in_think"] = False
                state["in_end"] = True
                state["end_count"] = 0
    def is_argmax_invariant(self) -> bool:
        """This logits processor can change the outcome of
        greedy sampling by forcing that the thinking section
        ends after a certain number of tokens."""
        return False
    def update_state(self, batch_update: Optional[BatchUpdate]):
        if batch_update:
            for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added:
                max_think_tokens = (params.max_think_tokens if isinstance(
                    params, SamplingParams) else None)
                if max_think_tokens is not None:
                    self._state[index] = self._init_state_entry(
                        prompt_tok_ids, max_think_tokens)
                    self._state[index]["output_tok_ids"] = output_tok_ids

            for index in batch_update.removed:
                self._state.pop(index, {})

            for i1, i2, direction in batch_update.moved:
                if direction == MoveDirectionality.SWAP:
                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
                else:
                    self._state[i2] = self._state.pop(i1, {})

        for state in self._state.values():
            self._update_think_state(state)
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        batch_size = logits.size(0)
        if not self._state:
            return logits

        mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
        end_token_id = self.think_end_token_id

        for index in range(batch_size):
            state = self._state.get(index)
            if not state:
                continue

            force_end_token_id = None
            if state["in_end"]:
                force_end_token_id = self.think_end_token_id[state["end_count"]]
                mask[index] = True

        if mask.any():
            logits[mask] = -float("inf")
            logits[mask, force_end_token_id] = 0.0

        return logits
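As a side note on the masking pattern used in apply above, a self-contained PyTorch sketch showing why setting a row to -inf and reopening a single column forces that token under greedy sampling (shapes and IDs are made up):

# Standalone illustration of the forcing pattern, not code from this PR.
import torch

logits = torch.randn(4, 32)                  # (batch, vocab) dummy logits
mask = torch.tensor([False, True, False, True])
forced_token_id = 7                          # e.g. first token of "</think>"

logits[mask] = -float("inf")                 # close every token in masked rows
logits[mask, forced_token_id] = 0.0          # reopen only the forced token

assert torch.all(logits[mask].argmax(dim=-1) == forced_token_id)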
def init_builtin_logitsprocs(
        pin_memory_available: bool, max_num_reqs: int, device: torch.device,
        reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
    """Construct 'builtin' vLLM logitsprocs which the engine
    loads by default.

@@ -516,10 +669,18 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)

non_argmax_invariant = [min_tokens_logitproc, logit_bias_logitproc]

if reasoning_config is not None:
    max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
        reasoning_config=reasoning_config,
        pin_memory=pin_memory_available,
        device=device,
    )
    non_argmax_invariant.append(max_think_tokens_logitproc)

return LogitsProcessorManager(
    non_argmax_invariant=[
        min_tokens_logitproc,
        logit_bias_logitproc,
    ],
    non_argmax_invariant=non_argmax_invariant,
    argmax_invariant=[min_p_logitproc],
)
Review comment: ditto
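Finally, a rough wiring sketch, with the caveat that the import path and the ReasoningConfig constructor arguments are assumptions rather than something shown in this diff; only the two token-ID fields are actually read by the processor:

# Hypothetical wiring sketch (paths, constructor kwargs, and IDs are assumptions).
import torch
from vllm.config import ReasoningConfig
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs  # path assumed

reasoning_config = ReasoningConfig(
    think_start_token_id=[151667],   # made-up "<think>" token ID(s)
    think_end_token_id=[151668],     # made-up "</think>" token ID(s)
)

manager = init_builtin_logitsprocs(
    pin_memory_available=False,
    max_num_reqs=128,
    device=torch.device("cpu"),
    reasoning_config=reasoning_config,
)
# With a reasoning_config present, MaxThinkTokensLogitsProcessor is appended
# to manager.non_argmax_invariant alongside the min_tokens and logit_bias processors.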