Commit f4274ab

feat: limit thinking tokens
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
1 parent 66f6fbd commit f4274ab

8 files changed, +185 -27 lines changed

vllm/config.py

Lines changed: 13 additions & 0 deletions

@@ -4404,6 +4404,17 @@ def set_splitting_ops_for_v1(self):
             "vllm.unified_attention_with_output",
         ]
 
+class ReasoningConfig:
+    """Configuration for reasoning models."""
+
+    think_start_token_id: Optional[int] = None
+    """Token ID that indicates the start of reasoning."""
+    think_end_token_id: Optional[int] = None
+    """Token ID that indicates the end of reasoning."""
+
+    def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None):
+        self.think_start_token_id = think_start_token_id
+        self.think_end_token_id = think_end_token_id
 
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4461,6 +4472,8 @@ class VllmConfig:
     # some opaque config, only used to provide additional information
     # for the hash computation, mainly used for testing, debugging or out of
     # tree config registration.
+    reasoning_config: Optional[ReasoningConfig] = None
+    """The configurations for reasoning model."""
     additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
     """Additional config for specified platform. Different platforms may
     support different configs. Make sure the configs are valid for the platform
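
Note: ReasoningConfig is a plain container for the two special-token IDs that bracket the thinking section. A minimal sketch of how it could be populated from a tokenizer; the model name and the <think>/</think> vocabulary lookup below are illustrative assumptions, not part of this diff:

# Hypothetical wiring: look up DeepSeek-R1-style think tokens in a tokenizer
# and wrap them in the new ReasoningConfig container.
from transformers import AutoTokenizer

from vllm.config import ReasoningConfig

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")  # example model
vocab = tokenizer.get_vocab()

reasoning_config = ReasoningConfig(
    think_start_token_id=vocab.get("<think>"),
    think_end_token_id=vocab.get("</think>"),
)
print(reasoning_config.think_start_token_id, reasoning_config.think_end_token_id)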

vllm/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions

@@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
+    max_think_tokens: Optional[int] = None
     # --8<-- [end:chat-completion-sampling-params]
 
     # --8<-- [start:chat-completion-extra-params]
@@ -670,6 +671,7 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
             bad_words=self.bad_words,
+            max_think_tokens=self.max_think_tokens,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
         )
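
Note: because max_think_tokens is declared in the chat-completion sampling-params block, it can be sent as a top-level field of the request body. A sketch of a request against a vLLM OpenAI-compatible server, assuming the server was launched with a reasoning model; the URL and model name are examples:

# The openai client forwards fields it does not know about via extra_body.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "How many primes are below 50?"}],
    max_tokens=1024,
    extra_body={"max_think_tokens": 128},  # new per-request cap added in this commit
)
print(response.choices[0].message)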

vllm/reasoning/deepseek_r1_reasoning_parser.py

Lines changed: 16 additions & 16 deletions

@@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
     text. This parser extracts the reasoning content from the model output.
     """
 
-    start_token_id: int
-    end_token_id: int
+    think_start_token_id: int
+    think_end_token_id: int
 
     start_token: str = "<think>"
     end_token: str = "</think>"
@@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
                 "The model tokenizer must be passed to the ReasoningParser "
                 "constructor during construction.")
 
-        self.start_token_id = self.vocab.get(self.start_token)
-        self.end_token_id = self.vocab.get(self.end_token)
-        if self.start_token_id is None or self.end_token_id is None:
+        self.think_start_token_id = self.vocab.get(self.start_token)
+        self.think_end_token_id = self.vocab.get(self.end_token)
+        if self.think_start_token_id is None or self.think_end_token_id is None:
             raise RuntimeError(
                 "DeepSeek R1 reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.end_token_id in input_ids
+        return self.think_end_token_id in input_ids
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
         """
         Extract the content after the end tokens
         """
-        if self.end_token_id not in input_ids[:-1]:
+        if self.think_end_token_id not in input_ids[:-1]:
             return []
         else:
-            return input_ids[input_ids.index(self.end_token_id) + 1:]
+            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
 
     def extract_reasoning_content_streaming(
         self,
@@ -75,14 +75,14 @@ def extract_reasoning_content_streaming(
         """
         # Skip single special tokens
         if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.start_token_id, self.end_token_id
+                self.think_start_token_id, self.think_end_token_id
         ]):
             return None
 
         # Check if <think> is present in previous or delta.
         # Keep compatibility with models that don't generate <think> tokens.
-        if self.start_token_id in previous_token_ids:
-            if self.end_token_id in delta_token_ids:
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
                 # <think> in previous, </think> in delta,
                 # extract reasoning content
                 end_index = delta_text.find(self.end_token)
@@ -92,16 +92,16 @@ def extract_reasoning_content_streaming(
                     reasoning_content=reasoning_content,
                     content=content if content else None,
                 )
-            elif self.end_token_id in previous_token_ids:
+            elif self.think_end_token_id in previous_token_ids:
                 # <think> in previous, </think> in previous,
                 # reasoning content continues
                 return DeltaMessage(content=delta_text)
             else:
                 # <think> in previous, no </think> in previous or delta,
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
-        elif self.start_token_id in delta_token_ids:
-            if self.end_token_id in delta_token_ids:
+        elif self.think_start_token_id in delta_token_ids:
+            if self.think_end_token_id in delta_token_ids:
                 # <think> in delta, </think> in delta, extract reasoning content
                 start_index = delta_text.find(self.start_token)
                 end_index = delta_text.find(self.end_token)
@@ -120,7 +120,7 @@ def extract_reasoning_content_streaming(
         # No <think> in previous or delta, also need to check for </think>.
         # Because the model may have generated </think> without <think>
         # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.end_token_id in delta_token_ids:
+        if self.think_end_token_id in delta_token_ids:
             # </think> in delta with more tokens,
             # extract reasoning content and content
             end_index = delta_text.find(self.end_token)
@@ -130,7 +130,7 @@ def extract_reasoning_content_streaming(
                 reasoning_content=reasoning_content,
                 content=content if content else None,
             )
-        elif self.end_token_id in previous_token_ids:
+        elif self.think_end_token_id in previous_token_ids:
             # </think> in previous, thinking content ends
             return DeltaMessage(content=delta_text)
         else:
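
Note: this file only renames start_token_id/end_token_id to think_start_token_id/think_end_token_id; behavior is unchanged. A small sketch of the two helpers' semantics using made-up token IDs (100 = <think>, 101 = </think>):

# Standalone mirror of is_reasoning_end / extract_content_ids with toy IDs.
THINK_START, THINK_END = 100, 101

def is_reasoning_end(input_ids: list[int]) -> bool:
    # Reasoning is considered finished once </think> has been emitted.
    return THINK_END in input_ids

def extract_content_ids(input_ids: list[int]) -> list[int]:
    # Content is returned only once at least one token follows </think>,
    # i.e. </think> is not the very last token.
    if THINK_END not in input_ids[:-1]:
        return []
    return input_ids[input_ids.index(THINK_END) + 1:]

print(is_reasoning_end([100, 7, 8]))          # False: still thinking
print(extract_content_ids([100, 7, 8, 101]))  # []: </think> is the last token
print(extract_content_ids([100, 7, 101, 9]))  # [9]: content after </think>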

vllm/sampling_params.py

Lines changed: 7 additions & 0 deletions

@@ -200,6 +200,7 @@ class SamplingParams(
         extra_args: Arbitrary additional args, that can be used by custom
             sampling implementations, plugins, etc. Not used by any in-tree
             sampling implementations.
+        max_think_tokens: Maximum number of tokens allowed for thinking
     """
 
     n: int = 1
@@ -248,6 +249,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None
 
+    # Maximum number of tokens allowed for thinking operations.
+    max_think_tokens: Optional[int] = None
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -263,6 +267,7 @@ def from_optional(
         stop: Optional[Union[str, list[str]]] = None,
         stop_token_ids: Optional[list[int]] = None,
         bad_words: Optional[list[str]] = None,
+        max_think_tokens: Optional[int] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
@@ -306,6 +311,7 @@ def from_optional(
             stop=stop,
             stop_token_ids=stop_token_ids,
             bad_words=bad_words,
+            max_think_tokens=max_think_tokens,
             include_stop_str_in_output=include_stop_str_in_output,
             ignore_eos=ignore_eos,
             max_tokens=max_tokens,
@@ -574,6 +580,7 @@ def __repr__(self) -> str:
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
             f"bad_words={self.bad_words}, "
+            f"max_think_tokens={self.max_think_tokens}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ def __init__(self,
         self.collective_rpc("initialize_cache",
                             args=(num_gpu_blocks, num_cpu_blocks))
 
+        # EngineCore holds the StructuredOutputManager, which takes the vllm config as an argument.
         self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.

vllm/v1/sample/logits_processor.py

Lines changed: 121 additions & 9 deletions

@@ -12,6 +12,7 @@
 from torch._prims_common import DeviceLikeType
 
 from vllm import PoolingParams, SamplingParams
+from vllm.config import ReasoningConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -24,9 +25,9 @@ class MoveDirectionality(Enum):
     SWAP = 1
 
 
-# (index, params, output_tok_ids) tuples for new
+# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]
@@ -43,9 +44,9 @@ class BatchUpdate:
     # within the persistent batch.
     #
     # Note: each added request is represented as
-    # (index, params, output_tok_ids)
-    # Key assumption: output_tok_ids is a reference to the
-    # request's running output tokens list; in this way
+    # (index, params, prompt_tok_ids, output_tok_ids)
+    # Key assumption: prompt_tok_ids and output_tok_ids are references to the
+    # request's prompt and running output token lists; in this way
     # the logits processors always see the latest list of
     # generated tokens
     removed: Sequence[RemovedRequest]
@@ -260,7 +261,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
         needs_update = False
         # Process added requests.
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
             min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
             if self.min_p_cpu[index] != min_p:
                 needs_update = True
@@ -337,7 +338,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
         # Process added requests.
         needs_update = bool(batch_update.added)
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
             if isinstance(params, SamplingParams) and (lb :=
                                                        params.logit_bias):
                 self.biases[index] = lb
@@ -420,7 +421,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
             # Process added requests.
             needs_update |= bool(batch_update.added)
-            for index, params, output_tok_ids in batch_update.added:
+            for index, params, _, output_tok_ids in batch_update.added:
                 if (isinstance(params, SamplingParams)
                         and (min_tokens := params.min_tokens)
                         and len(output_tok_ids) < min_tokens):
@@ -493,8 +494,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
 
+class MaxThinkTokensLogitsProcessor(LogitsProcessor):
+    """A logits processor that limits the maximum number of thinking tokens."""
+
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+        """
+        Args:
+            reasoning_config (ReasoningConfig): Reasoning config that holds the
+                think start/end token IDs for the thinking section.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
+        """
+        super().__init__()
+        self.think_start_token_id = reasoning_config.think_start_token_id
+        self.think_end_token_id = reasoning_config.think_end_token_id
+        self.pin_memory = pin_memory
+        self.device = device
+        self._state = {}
+
+    def _find_last_token_index(self, tokens, token_id):
+        try:
+            return len(tokens) - tokens[::-1].index(token_id) - 1
+        except ValueError:
+            return -1
+
+    def is_argmax_invariant(self) -> bool:
+        """This logits processor can change the outcome of greedy sampling
+        by forcing the thinking section to end after a certain number of tokens."""
+        return False
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        if batch_update is None:
+            return
+
+        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+
+            if max_think_tokens is None:
+                continue
+
+            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+
+            in_think = False
+            count = 0
+
+            if last_think_start_idx > last_think_end_idx:
+                in_think = True
+                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+
+            self._state[index] = {
+                "in_think": in_think,
+                "count": count,
+                "prompt_tok_ids": prompt_tok_ids,
+                "output_tok_ids": output_tok_ids,
+                "max_think_tokens": max_think_tokens,
+            }
+
+        for index in batch_update.removed:
+            self._state.pop(index, None)
+
+        for i1, i2, direction in batch_update.moved:
+            if direction == MoveDirectionality.SWAP:
+                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+            else:
+                self._state[i2] = self._state.pop(i1, None)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        batch_size = logits.size(0)
+        if batch_size == 0:
+            return logits
+
+        mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
+        end_token_id = self.think_end_token_id
+
+        for index in range(batch_size):
+            state = self._state.get(index, None)
+            if not state or not state.get("output_tok_ids"):
+                continue
+
+            last_tok = state["output_tok_ids"][-1]
+            in_think = state["in_think"]
+            count = state["count"]
+
+            if last_tok == self.think_start_token_id:
+                in_think = True
+                count = 0
+            elif last_tok == self.think_end_token_id:
+                in_think = False
+                count = 0
+            elif in_think:
+                count += 1
+
+            state["in_think"] = in_think
+            state["count"] = count
+
+            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+                mask[index] = True
+
+        if mask.any():
+            logits[mask] = -float("inf")
+            logits[mask, end_token_id] = 0.0
+
+        return logits
+
+
 def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device) -> LogitsProcessorManager:
+                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.
 
@@ -516,10 +622,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
         device=device,
         # +1 for temporary swap space
         max_num_reqs=max_num_reqs + 1)
+    max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+        reasoning_config=reasoning_config,
+        pin_memory=pin_memory_available,
+        device=device,
+    )
     return LogitsProcessorManager(
         non_argmax_invariant=[
             min_tokens_logitproc,
             logit_bias_logitproc,
+            max_think_tokens_logitproc
         ],
         argmax_invariant=[min_p_logitproc],
     )
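
Note: the enforcement happens in apply(): rows whose thinking budget is exhausted have every logit set to -inf except the </think> column, so the next token closes the thinking section regardless of sampling settings. A standalone sketch of just that masking step, with toy sizes and a made-up token ID:

# Toy reproduction of the masking in MaxThinkTokensLogitsProcessor.apply().
import torch

THINK_END_ID = 5                            # made-up </think> token id
logits = torch.randn(3, 8)                  # 3 requests, vocab size 8
mask = torch.tensor([False, True, False])   # request 1 hit its thinking budget

logits[mask] = -float("inf")                # rule out every token...
logits[mask, THINK_END_ID] = 0.0            # ...except </think>, which now wins

print(logits.argmax(dim=-1))                # row 1 is guaranteed to be 5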
