Commit 07ebf75

feat: limit thinking tokens

1 parent 0e3fe89 commit 07ebf75

8 files changed: +185 additions, -27 deletions

8 files changed

+185
-27
lines changed

vllm/config.py

Lines changed: 13 additions & 0 deletions

@@ -4289,6 +4289,17 @@ def set_splitting_ops_for_v1(self):
             "vllm.unified_attention_with_output",
         ]
 
+class ReasoningConfig:
+    """Configuration for reasoning models."""
+
+    think_start_token_id: Optional[int] = None
+    """Token ID that indicates the start of reasoning."""
+    think_end_token_id: Optional[int] = None
+    """Token ID that indicates the end of reasoning."""
+
+    def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None):
+        self.think_start_token_id = think_start_token_id
+        self.think_end_token_id = think_end_token_id
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))

@@ -4346,6 +4357,8 @@ class VllmConfig:
     # some opaque config, only used to provide additional information
     # for the hash computation, mainly used for testing, debugging or out of
     # tree config registration.
+    reasoning_config: Optional[ReasoningConfig] = None
+    """The configuration for reasoning models."""
     additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
     """Additional config for specified platform. Different platforms may
     support different configs. Make sure the configs are valid for the platform
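The new ReasoningConfig only carries the two think-token IDs; how it gets populated is outside this diff. As a rough sketch (not taken from this commit), the IDs would typically be looked up from the model's tokenizer and the resulting object stored on the new VllmConfig.reasoning_config field; the model name below is a placeholder:

    from transformers import AutoTokenizer
    from vllm.config import ReasoningConfig

    # Placeholder model; the special-token IDs depend entirely on the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")
    reasoning_config = ReasoningConfig(
        think_start_token_id=tokenizer.convert_tokens_to_ids("<think>"),
        think_end_token_id=tokenizer.convert_tokens_to_ids("</think>"),
    )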

vllm/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions

@@ -272,6 +272,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
+    max_think_tokens: Optional[int] = None
     # --8<-- [end:chat-completion-sampling-params]
 
     # --8<-- [start:chat-completion-extra-params]

@@ -538,6 +539,7 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
             bad_words=self.bad_words,
+            max_think_tokens=self.max_think_tokens,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
         )
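With max_think_tokens accepted on ChatCompletionRequest and forwarded into SamplingParams, a client can cap reasoning length per request. A hypothetical request against a running vLLM OpenAI-compatible server (model name and limit are placeholders; the official OpenAI client has to pass non-standard fields via extra_body):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1",  # placeholder model name
        messages=[{"role": "user", "content": "How many primes are below 100?"}],
        extra_body={"max_think_tokens": 256},  # vLLM extension, not part of the OpenAI schema
    )
    print(resp.choices[0].message.content)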

vllm/reasoning/deepseek_r1_reasoning_parser.py

Lines changed: 16 additions & 16 deletions

@@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
     text. This parser extracts the reasoning content from the model output.
     """
 
-    start_token_id: int
-    end_token_id: int
+    think_start_token_id: int
+    think_end_token_id: int
 
     start_token: str = "<think>"
     end_token: str = "</think>"

@@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
                 "The model tokenizer must be passed to the ReasoningParser "
                 "constructor during construction.")
 
-        self.start_token_id = self.vocab.get(self.start_token)
-        self.end_token_id = self.vocab.get(self.end_token)
-        if self.start_token_id is None or self.end_token_id is None:
+        self.think_start_token_id = self.vocab.get(self.start_token)
+        self.think_end_token_id = self.vocab.get(self.end_token)
+        if self.think_start_token_id is None or self.think_end_token_id is None:
            raise RuntimeError(
                "DeepSeek R1 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")
 
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.end_token_id in input_ids
+        return self.think_end_token_id in input_ids
 
    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
-        if self.end_token_id not in input_ids[:-1]:
+        if self.think_end_token_id not in input_ids[:-1]:
            return []
        else:
-            return input_ids[input_ids.index(self.end_token_id) + 1:]
+            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
 
    def extract_reasoning_content_streaming(
        self,

@@ -75,14 +75,14 @@ def extract_reasoning_content_streaming(
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.start_token_id, self.end_token_id
+                self.think_start_token_id, self.think_end_token_id
        ]):
            return None
 
        # Check if <think> is present in previous or delta.
        # Keep compatibility with models that don't generate <think> tokens.
-        if self.start_token_id in previous_token_ids:
-            if self.end_token_id in delta_token_ids:
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.end_token)

@@ -92,16 +92,16 @@ def extract_reasoning_content_streaming(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
-            elif self.end_token_id in previous_token_ids:
+            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous,
                # reasoning content continues
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
-        elif self.start_token_id in delta_token_ids:
-            if self.end_token_id in delta_token_ids:
+        elif self.think_start_token_id in delta_token_ids:
+            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.start_token)
                end_index = delta_text.find(self.end_token)

@@ -120,7 +120,7 @@ def extract_reasoning_content_streaming(
            # No <think> in previous or delta, also need to check for </think>.
            # Because the model may have generated </think> without <think>
            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.end_token_id in delta_token_ids:
+            if self.think_end_token_id in delta_token_ids:
                # </think> in delta with more tokens,
                # extract reasoning content and content
                end_index = delta_text.find(self.end_token)

@@ -130,7 +130,7 @@ def extract_reasoning_content_streaming(
                    reasoning_content=reasoning_content,
                    content=content if content else None,
                )
-            elif self.end_token_id in previous_token_ids:
+            elif self.think_end_token_id in previous_token_ids:
                # </think> in previous, thinking content ends
                return DeltaMessage(content=delta_text)
            else:
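The rename from start_token_id/end_token_id to think_start_token_id/think_end_token_id is mechanical, but the slicing rule it feeds is easy to pin down. A tiny self-contained sketch of the extract_content_ids logic with made-up token IDs (the real IDs come from the tokenizer vocab):

    # Placeholder IDs: 100 stands in for <think>, 101 for </think>.
    THINK_START, THINK_END = 100, 101

    input_ids = [1, THINK_START, 7, 8, 9, THINK_END, 42, 43]

    # Same rule as DeepSeekR1ReasoningParser.extract_content_ids: everything
    # after the first </think> is answer content, unless </think> is the very
    # last token generated so far.
    if THINK_END not in input_ids[:-1]:
        content_ids = []
    else:
        content_ids = input_ids[input_ids.index(THINK_END) + 1:]

    assert content_ids == [42, 43]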

vllm/sampling_params.py

Lines changed: 7 additions & 0 deletions

@@ -200,6 +200,7 @@ class SamplingParams(
         extra_args: Arbitrary additional args, that can be used by custom
             sampling implementations, plugins, etc. Not used by any in-tree
             sampling implementations.
+        max_think_tokens: Maximum number of tokens allowed for thinking
     """
 
     n: int = 1

@@ -248,6 +249,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None
 
+    # Maximum number of tokens allowed for thinking operations.
+    max_think_tokens: Optional[int] = None
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,

@@ -263,6 +267,7 @@ def from_optional(
         stop: Optional[Union[str, list[str]]] = None,
         stop_token_ids: Optional[list[int]] = None,
         bad_words: Optional[list[str]] = None,
+        max_think_tokens: Optional[int] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,

@@ -306,6 +311,7 @@ def from_optional(
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
+            max_think_tokens=max_think_tokens,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,

@@ -574,6 +580,7 @@ def __repr__(self) -> str:
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
+            f"max_think_tokens={self.max_think_tokens}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ def __init__(self,
         self.collective_rpc("initialize_cache",
                             args=(num_gpu_blocks, num_cpu_blocks))
 
+        # EngineCore holds the StructuredOutputManager, which takes the vllm config as an argument.
         self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.

vllm/v1/sample/logits_processor.py

Lines changed: 121 additions & 9 deletions

@@ -12,6 +12,7 @@
 from torch._prims_common import DeviceLikeType
 
 from vllm import PoolingParams, SamplingParams
+from vllm.config import ReasoningConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)

@@ -24,9 +25,9 @@ class MoveDirectionality(Enum):
     SWAP = 1
 
 
-# (index, params, output_tok_ids) tuples for new
+# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]

@@ -43,9 +44,9 @@ class BatchUpdate:
     # within the persistent batch.
     #
     # Note: each added request is represented as
-    # (index, params, output_tok_ids)
-    # Key assumption: output_tok_ids is a reference to the
-    # request's running output tokens list; in this way
+    # (index, params, prompt_tok_ids, output_tok_ids)
+    # Key assumption: prompt_tok_ids and output_tok_ids are references to the
+    # request's prompt and running output token lists; in this way
     # the logits processors always see the latest list of
     # generated tokens
     removed: Sequence[RemovedRequest]

@@ -254,7 +255,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
         needs_update = False
         # Process added requests.
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
            min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
            if self.min_p_cpu[index] != min_p:
                needs_update = True

@@ -329,7 +330,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
        # Process added requests.
        needs_update = bool(batch_update.added)
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
            if isinstance(params, SamplingParams) and (lb :=
                                                       params.logit_bias):
                self.biases[index] = lb

@@ -412,7 +413,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
        if batch_update:
            # Process added requests.
            needs_update |= bool(batch_update.added)
-            for index, params, output_tok_ids in batch_update.added:
+            for index, params, _, output_tok_ids in batch_update.added:
                if (isinstance(params, SamplingParams)
                        and (min_tokens := params.min_tokens)
                        and len(output_tok_ids) < min_tokens):

@@ -485,8 +486,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
 
+class MaxThinkTokensLogitsProcessor(LogitsProcessor):
+    """A logits processor that limits the maximum number of thinking tokens."""
+
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+        """
+        Args:
+            reasoning_config (ReasoningConfig): Holds the think start and
+                think end token IDs used to delimit the thinking section.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
+        """
+        super().__init__()
+        self.think_start_token_id = reasoning_config.think_start_token_id
+        self.think_end_token_id = reasoning_config.think_end_token_id
+        self.pin_memory = pin_memory
+        self.device = device
+        self._state = {}
+
+    def _find_last_token_index(self, tokens, token_id):
+        try:
+            return len(tokens) - tokens[::-1].index(token_id) - 1
+        except ValueError:
+            return -1
+
+    def is_argmax_invariant(self) -> bool:
+        """This logits processor can change the outcome of greedy sampling
+        by forcing the thinking section to end after a certain number of tokens."""
+        return False
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        if batch_update is None:
+            return
+
+        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+
+            if max_think_tokens is None:
+                continue
+
+            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+
+            in_think = False
+            count = 0
+
+            if last_think_start_idx > last_think_end_idx:
+                in_think = True
+                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+
+            self._state[index] = {
+                "in_think": in_think,
+                "count": count,
+                "prompt_tok_ids": prompt_tok_ids,
+                "output_tok_ids": output_tok_ids,
+                "max_think_tokens": max_think_tokens,
+            }
+
+        for index in batch_update.removed:
+            self._state.pop(index, None)
+
+        for i1, i2, direction in batch_update.moved:
+            if direction == MoveDirectionality.SWAP:
+                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+            else:
+                self._state[i2] = self._state.pop(i1, None)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        batch_size = logits.size(0)
+        if batch_size == 0:
+            return logits
+
+        mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
+        end_token_id = self.think_end_token_id
+
+        for index in range(batch_size):
+            state = self._state.get(index, None)
+            if not state or not state.get("output_tok_ids"):
+                continue
+
+            last_tok = state["output_tok_ids"][-1]
+            in_think = state["in_think"]
+            count = state["count"]
+
+            if last_tok == self.think_start_token_id:
+                in_think = True
+                count = 0
+            elif last_tok == self.think_end_token_id:
+                in_think = False
+                count = 0
+            elif in_think:
+                count += 1
+
+            state["in_think"] = in_think
+            state["count"] = count
+
+            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+                mask[index] = True
+
+        if mask.any():
+            logits[mask] = -float("inf")
+            logits[mask, end_token_id] = 0.0
+
+        return logits
+
+
 def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device) -> LogitsProcessorManager:
+                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.

@@ -508,10 +614,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
        device=device,
        # +1 for temporary swap space
        max_num_reqs=max_num_reqs + 1)
+    max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+        reasoning_config=reasoning_config,
+        pin_memory=pin_memory_available,
+        device=device,
+    )
    return LogitsProcessorManager(
        non_argmax_invariant=[
            min_tokens_logitproc,
            logit_bias_logitproc,
+            max_think_tokens_logitproc
        ],
        argmax_invariant=[min_p_logitproc],
    )
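The enforcement itself happens in MaxThinkTokensLogitsProcessor.apply: once a request is inside a think block and its counted thinking tokens reach max_think_tokens, every logit in that row is pushed to -inf except the think-end token, so the next sampled token must be </think>. A standalone sketch of just that masking step, with a toy vocabulary and a placeholder token ID rather than anything from a real tokenizer:

    import torch

    think_end_token_id = 3               # placeholder ID for </think>
    logits = torch.randn(2, 8)           # 2 requests, vocab size 8
    mask = torch.tensor([True, False])   # request 0 has exhausted its thinking budget

    # Same masking as in apply(): suppress everything, then re-enable only the
    # think-end token for the masked rows.
    logits[mask] = -float("inf")
    logits[mask, think_end_token_id] = 0.0

    # Greedy sampling on the masked row can now only pick </think>.
    assert torch.argmax(logits[0]).item() == think_end_token_id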
