From f4274ab20979bcf4367ba429266a6e625aa25c28 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Sat, 12 Jul 2025 09:14:00 +0000
Subject: [PATCH 1/8] feat: limit thinking tokens
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/config.py | 13 ++
vllm/entrypoints/openai/protocol.py | 2 +
.../reasoning/deepseek_r1_reasoning_parser.py | 32 ++---
vllm/sampling_params.py | 7 +
vllm/v1/engine/core.py | 1 +
vllm/v1/sample/logits_processor.py | 130 ++++++++++++++++--
vllm/v1/worker/gpu_input_batch.py | 7 +-
vllm/v1/worker/gpu_model_runner.py | 20 +++
8 files changed, 185 insertions(+), 27 deletions(-)
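Note (not part of the diff): the new `max_think_tokens` sampling parameter is exposed through the OpenAI-compatible chat API. A minimal usage sketch follows, assuming a vLLM server is already running with a reasoning-capable model and the reasoning parser set up as in this series; the model name, URL, and prompt are placeholders.

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1",          # placeholder model
        messages=[{"role": "user", "content": "What is 17 * 24?"}],
        extra_body={"max_think_tokens": 128},     # new field added by this patch
    )
    print(resp.choices[0].message.content)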
diff --git a/vllm/config.py b/vllm/config.py
index d9f356c5c60..9ef217bcc1d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4404,6 +4404,17 @@ def set_splitting_ops_for_v1(self):
"vllm.unified_attention_with_output",
]
+class ReasoningConfig:
+ """Configuration for reasoning models."""
+
+ think_start_token_id: Optional[int] = None
+ """Token ID that indicates the start of reasoning."""
+ think_end_token_id: Optional[int] = None
+ """Token ID that indicates the end of reasoning."""
+
+ def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None):
+ self.think_start_token_id = think_start_token_id
+ self.think_end_token_id = think_end_token_id
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
@@ -4461,6 +4472,8 @@ class VllmConfig:
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
+ reasoning_config: Optional[ReasoningConfig] = None
+ """The configuration for the reasoning model."""
additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
"""Additional config for specified platform. Different platforms may
support different configs. Make sure the configs are valid for the platform
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fdac6ccd19e..e0c7bbf21db 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
bad_words: list[str] = Field(default_factory=list)
+ max_think_tokens: Optional[int] = None
# --8<-- [end:chat-completion-sampling-params]
# --8<-- [start:chat-completion-extra-params]
@@ -670,6 +671,7 @@ def to_sampling_params(
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
bad_words=self.bad_words,
+ max_think_tokens=self.max_think_tokens,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f..96bb50f3817 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
- start_token_id: int
- end_token_id: int
+ think_start_token_id: int
+ think_end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
- self.start_token_id = self.vocab.get(self.start_token)
- self.end_token_id = self.vocab.get(self.end_token)
- if self.start_token_id is None or self.end_token_id is None:
+ self.think_start_token_id = self.vocab.get(self.start_token)
+ self.think_end_token_id = self.vocab.get(self.end_token)
+ if self.think_start_token_id is None or self.think_end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.end_token_id in input_ids
+ return self.think_end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
- if self.end_token_id not in input_ids[:-1]:
+ if self.think_end_token_id not in input_ids[:-1]:
return []
else:
- return input_ids[input_ids.index(self.end_token_id) + 1:]
+ return input_ids[input_ids.index(self.think_end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
@@ -75,14 +75,14 @@ def extract_reasoning_content_streaming(
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
- self.start_token_id, self.end_token_id
+ self.think_start_token_id, self.think_end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
- if self.start_token_id in previous_token_ids:
- if self.end_token_id in delta_token_ids:
+ if self.think_start_token_id in previous_token_ids:
+ if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.end_token)
@@ -92,7 +92,7 @@ def extract_reasoning_content_streaming(
reasoning_content=reasoning_content,
content=content if content else None,
)
- elif self.end_token_id in previous_token_ids:
+ elif self.think_end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
@@ -100,8 +100,8 @@ def extract_reasoning_content_streaming(
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
- elif self.start_token_id in delta_token_ids:
- if self.end_token_id in delta_token_ids:
+ elif self.think_start_token_id in delta_token_ids:
+ if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
@@ -120,7 +120,7 @@ def extract_reasoning_content_streaming(
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.end_token_id in delta_token_ids:
+ if self.think_end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.end_token)
@@ -130,7 +130,7 @@ def extract_reasoning_content_streaming(
reasoning_content=reasoning_content,
content=content if content else None,
)
- elif self.end_token_id in previous_token_ids:
+ elif self.think_end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index a9a862384d1..49cdbdde45c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -200,6 +200,7 @@ class SamplingParams(
extra_args: Arbitrary additional args, that can be used by custom
sampling implementations, plugins, etc. Not used by any in-tree
sampling implementations.
+ max_think_tokens: Maximum number of tokens allowed for thinking.
"""
n: int = 1
@@ -248,6 +249,9 @@ class SamplingParams(
bad_words: Optional[list[str]] = None
_bad_words_token_ids: Optional[list[list[int]]] = None
+ # Maximum number of tokens allowed for thinking operations.
+ max_think_tokens: Optional[int] = None
+
@staticmethod
def from_optional(
n: Optional[int] = 1,
@@ -263,6 +267,7 @@ def from_optional(
stop: Optional[Union[str, list[str]]] = None,
stop_token_ids: Optional[list[int]] = None,
bad_words: Optional[list[str]] = None,
+ max_think_tokens: Optional[int] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
@@ -306,6 +311,7 @@ def from_optional(
stop=stop,
stop_token_ids=stop_token_ids,
bad_words=bad_words,
+ max_think_tokens=max_think_tokens,
include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
@@ -574,6 +580,7 @@ def __repr__(self) -> str:
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, "
+ f"max_think_tokens={self.max_think_tokens}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e2fdf6f8a11..0eb0d8018ae 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -86,6 +86,7 @@ def __init__(self,
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
+ # EngineCore holds the StructuredOutputManager to handle structured output; it takes the vllm config as an arg.
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index 3a4c25964e7..04912b358b3 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -12,6 +12,7 @@
from torch._prims_common import DeviceLikeType
from vllm import PoolingParams, SamplingParams
+from vllm.config import ReasoningConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -24,9 +25,9 @@ class MoveDirectionality(Enum):
SWAP = 1
-# (index, params, output_tok_ids) tuples for new
+# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@@ -43,9 +44,9 @@ class BatchUpdate:
# within the persistent batch.
#
# Note: each added request is represented as
- # (index, params, output_tok_ids)
- # Key assumption: output_tok_ids is a reference to the
- # request's running output tokens list; in this way
+ # (index, params, prompt_tok_ids, output_tok_ids)
+ # Key assumption: prompt_tok_ids and output_tok_ids are references to the
+ # request's prompt and running output token lists; in this way
# the logits processors always see the latest list of
# generated tokens
removed: Sequence[RemovedRequest]
@@ -260,7 +261,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
needs_update = False
# Process added requests.
- for index, params, _ in batch_update.added:
+ for index, params, _, _ in batch_update.added:
min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
if self.min_p_cpu[index] != min_p:
needs_update = True
@@ -337,7 +338,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
# Process added requests.
needs_update = bool(batch_update.added)
- for index, params, _ in batch_update.added:
+ for index, params, _, _ in batch_update.added:
if isinstance(params, SamplingParams) and (lb :=
params.logit_bias):
self.biases[index] = lb
@@ -420,7 +421,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
# Process added requests.
needs_update |= bool(batch_update.added)
- for index, params, output_tok_ids in batch_update.added:
+ for index, params, _, output_tok_ids in batch_update.added:
if (isinstance(params, SamplingParams)
and (min_tokens := params.min_tokens)
and len(output_tok_ids) < min_tokens):
@@ -493,8 +494,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
return logits
+class MaxThinkTokensLogitsProcessor(LogitsProcessor):
+ """A logits processor that limits the maximum number of thinking tokens."""
+
+ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+ """
+ Args:
+ think_start_token_id (int): Token ID for the start of thinking section.
+ think_end_token_id (int): Token ID for the end of thinking section.
+ pin_memory (bool): Whether to use pinned memory for tensors.
+ device (torch.device): Device to use for tensor operations.
+ """
+ super().__init__()
+ self.think_start_token_id = reasoning_config.think_start_token_id
+ self.think_end_token_id = reasoning_config.think_end_token_id
+ self.pin_memory = pin_memory
+ self.device = device
+ self._state = {}
+
+ def _find_last_token_index(self, tokens, token_id):
+ try:
+ return len(tokens) - tokens[::-1].index(token_id) - 1
+ except ValueError:
+ return -1
+
+ def is_argmax_invariant(self) -> bool:
+ """This logits processor can change the outcome of greedy sampling
+ by forcing that the thinking section ends after a certain number of tokens."""
+ return False
+
+ def update_state(self, batch_update: Optional[BatchUpdate]):
+ if batch_update is None:
+ return
+
+ for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+ max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+
+ if max_think_tokens is None:
+ continue
+
+ last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+ last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+
+ in_think = False
+ count = 0
+
+ if last_think_start_idx > last_think_end_idx:
+ in_think = True
+ count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+
+ self._state[index] = {
+ "in_think": in_think,
+ "count": count,
+ "prompt_tok_ids": prompt_tok_ids,
+ "output_tok_ids": output_tok_ids,
+ "max_think_tokens": max_think_tokens,
+ }
+
+ for index in batch_update.removed:
+ self._state.pop(index, None)
+
+ for i1, i2, direction in batch_update.moved:
+ if direction == MoveDirectionality.SWAP:
+ self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+ else:
+ self._state[i2] = self._state.pop(i1, None)
+
+ def apply(self, logits: torch.Tensor) -> torch.Tensor:
+ batch_size = logits.size(0)
+ if batch_size == 0:
+ return logits
+
+ mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
+ end_token_id = self.think_end_token_id
+
+ for index in range(batch_size):
+ state = self._state.get(index, None)
+ if not state or not state.get("output_tok_ids"):
+ continue
+
+ last_tok = state["output_tok_ids"][-1]
+ in_think = state["in_think"]
+ count = state["count"]
+
+ if last_tok == self.think_start_token_id:
+ in_think = True
+ count = 0
+ elif last_tok == self.think_end_token_id:
+ in_think = False
+ count = 0
+ elif in_think:
+ count += 1
+
+ state["in_think"] = in_think
+ state["count"] = count
+
+ if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+ mask[index] = True
+
+ if mask.any():
+ logits[mask] = -float("inf")
+ logits[mask, end_token_id] = 0.0
+
+ return logits
+
+
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
- device: torch.device) -> LogitsProcessorManager:
+ device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.
@@ -516,10 +622,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)
+ max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+ reasoning_config=reasoning_config,
+ pin_memory=pin_memory_available,
+ device=device,
+ )
return LogitsProcessorManager(
non_argmax_invariant=[
min_tokens_logitproc,
logit_bias_logitproc,
+ max_think_tokens_logitproc
],
argmax_invariant=[min_p_logitproc],
)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 1a79d72be0a..daf339bdd5b 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
+from vllm.config import ReasoningConfig
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
@@ -71,6 +72,7 @@ def __init__(
block_sizes: list[int], # The block_size of each kv cache group
is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False,
+ reasoning_config: Optional[ReasoningConfig] = None,
):
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
@@ -221,7 +223,8 @@ def __init__(
self.logitsprocs = init_builtin_logitsprocs(
pin_memory_available=pin_memory,
max_num_reqs=max_num_reqs + 1,
- device=device)
+ device=device,
+ reasoning_config=reasoning_config)
# TODO convert this to LogitsProcessor
self.has_allowed_token_ids: set[str] = set()
@@ -260,7 +263,7 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
params = (request.sampling_params
if request.sampling_params else request.pooling_params)
self.batch_update_builder.added.append(
- (req_index, params, request.output_token_ids))
+ (req_index, params, request.prompt_token_ids, request.output_token_ids))
return req_index
def add_request(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4551cb2df98..8a37a30f2e9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -71,6 +71,10 @@
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+from vllm.config import ReasoningConfig
+from vllm.reasoning import ReasoningParserManager
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+
if TYPE_CHECKING:
import xgrammar as xgr
import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
@@ -105,6 +109,20 @@ def __init__(
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
+ if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'):
+ tokenizer = init_tokenizer_from_configs(
+ model_config=self.vllm_config.model_config,
+ scheduler_config=self.vllm_config.scheduler_config,
+ lora_config=self.vllm_config.lora_config,
+ ).get_lora_tokenizer(None)
+ reasoning_backend = \
+ self.vllm_config.decoding_config.reasoning_backend
+ reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+ reasoning_backend)
+ reasoning_parser = reasoner_cls(tokenizer=tokenizer)
+ self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id,
+ think_end_token_id=reasoning_parser.think_end_token_id)
+
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -212,6 +230,7 @@ def __init__(
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
+ reasoning_config=self.vllm_config.reasoning_config,
)
self.use_cuda_graph = (
@@ -2384,6 +2403,7 @@ def may_reinitialize_input_batch(self,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=block_sizes,
is_spec_decode=bool(self.vllm_config.speculative_config),
+ reasoning_config=self.vllm_config.reasoning_config,
)
def _allocate_kv_cache_tensors(
From 84aee5bda0ac43df5f3c749aa31157b79874f368 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Sat, 12 Jul 2025 09:24:57 +0000
Subject: [PATCH 2/8] remove comment
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/v1/engine/core.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 0eb0d8018ae..e2fdf6f8a11 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -86,7 +86,6 @@ def __init__(self,
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
- # EngineCore holds the StructuredOutputManager to handle structured output; it takes the vllm config as an arg.
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
From f2e195a46168cc23c84156d1334b16ab5c117a0f Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Mon, 14 Jul 2025 04:06:23 +0000
Subject: [PATCH 3/8] update states only in update_state method
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/v1/sample/logits_processor.py | 96 +++++++++++++-----------------
1 file changed, 43 insertions(+), 53 deletions(-)
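Note (illustrative, not part of the diff): after this patch, apply() only reads per-request state and all mutation happens in update_state(). The per-token counting rule can be restated as the small helper below; `state`, `start_id`, and `end_id` are stand-ins for the processor's internal dict and the parser-reported token IDs.

    def step(state: dict, last_tok: int, start_id: int, end_id: int) -> None:
        # Mirrors the update rule in update_state(): entering a thinking
        # section resets the counter, leaving it stops counting, and any
        # other token generated while thinking increments the count.
        if last_tok == start_id:
            state["in_think"], state["count"] = True, 0
        elif last_tok == end_id:
            state["in_think"], state["count"] = False, 0
        elif state["in_think"]:
            state["count"] += 1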
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index 04912b358b3..2b4e518e1a8 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from enum import Enum
from itertools import chain
-from typing import Optional, Union
+from typing import Any, Optional, Union
import torch
from torch._prims_common import DeviceLikeType
@@ -510,9 +510,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device:
self.think_end_token_id = reasoning_config.think_end_token_id
self.pin_memory = pin_memory
self.device = device
- self._state = {}
+ self._state: dict[int, dict[str, Any]] = {}
- def _find_last_token_index(self, tokens, token_id):
+ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
try:
return len(tokens) - tokens[::-1].index(token_id) - 1
except ValueError:
@@ -524,71 +524,61 @@ def is_argmax_invariant(self) -> bool:
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
- if batch_update is None:
- return
-
- for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
- max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
-
- if max_think_tokens is None:
- continue
-
- last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
- last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
-
- in_think = False
- count = 0
+ if batch_update:
+ for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+ max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+ if max_think_tokens is not None:
+ last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+ last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+ in_think = last_start > last_end
+ count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+ self._state[index] = {
+ "in_think": in_think,
+ "count": count,
+ "prompt_tok_ids": prompt_tok_ids,
+ "output_tok_ids": output_tok_ids,
+ "max_think_tokens": max_think_tokens,
+ }
- if last_think_start_idx > last_think_end_idx:
- in_think = True
- count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+ for index in batch_update.removed:
+ self._state.pop(index, None)
- self._state[index] = {
- "in_think": in_think,
- "count": count,
- "prompt_tok_ids": prompt_tok_ids,
- "output_tok_ids": output_tok_ids,
- "max_think_tokens": max_think_tokens,
- }
+ for i1, i2, direction in batch_update.moved:
+ if direction == MoveDirectionality.SWAP:
+ self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+ else:
+ self._state[i2] = self._state.pop(i1, None)
- for index in batch_update.removed:
- self._state.pop(index, None)
+ # Update in_think and count for all active requests
+ for state in self._state.values():
+ output = state["output_tok_ids"]
+ if not output:
+ continue
- for i1, i2, direction in batch_update.moved:
- if direction == MoveDirectionality.SWAP:
- self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
- else:
- self._state[i2] = self._state.pop(i1, None)
+ last_tok = output[-1]
+ if last_tok == self.think_start_token_id:
+ state["in_think"] = True
+ state["count"] = 0
+ elif last_tok == self.think_end_token_id:
+ state["in_think"] = False
+ state["count"] = 0
+ elif state["in_think"]:
+ state["count"] += 1
def apply(self, logits: torch.Tensor) -> torch.Tensor:
batch_size = logits.size(0)
- if batch_size == 0:
+ if not self._state:
return logits
mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
end_token_id = self.think_end_token_id
for index in range(batch_size):
- state = self._state.get(index, None)
- if not state or not state.get("output_tok_ids"):
+ state = self._state.get(index)
+ if not state:
continue
- last_tok = state["output_tok_ids"][-1]
- in_think = state["in_think"]
- count = state["count"]
-
- if last_tok == self.think_start_token_id:
- in_think = True
- count = 0
- elif last_tok == self.think_end_token_id:
- in_think = False
- count = 0
- elif in_think:
- count += 1
-
- state["in_think"] = in_think
- state["count"] = count
-
if state["in_think"] and state["count"] >= state["max_think_tokens"]:
mask[index] = True
From 4c4251d2b3fa64f0368a75d10fdd969286be7243 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Mon, 14 Jul 2025 04:40:44 +0000
Subject: [PATCH 4/8] make precommit and lint
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/config.py | 6 +++-
vllm/v1/sample/logits_processor.py | 52 ++++++++++++++++++------------
vllm/v1/worker/gpu_input_batch.py | 3 +-
vllm/v1/worker/gpu_model_runner.py | 16 ++++-----
4 files changed, 46 insertions(+), 31 deletions(-)
diff --git a/vllm/config.py b/vllm/config.py
index 9ef217bcc1d..160b7fe5d15 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4404,6 +4404,7 @@ def set_splitting_ops_for_v1(self):
"vllm.unified_attention_with_output",
]
+
class ReasoningConfig:
"""Configuration for reasoning models."""
@@ -4412,10 +4413,13 @@ class ReasoningConfig:
think_end_token_id: Optional[int] = None
"""Token ID that indicates the end of reasoning."""
- def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None):
+ def __init__(self,
+ think_start_token_id: Optional[int] = None,
+ think_end_token_id: Optional[int] = None):
self.think_start_token_id = think_start_token_id
self.think_end_token_id = think_end_token_id
+
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index 2b4e518e1a8..62f615bbd68 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -27,7 +27,8 @@ class MoveDirectionality(Enum):
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int],
+ list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
@@ -497,13 +498,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
class MaxThinkTokensLogitsProcessor(LogitsProcessor):
"""A logits processor that limits the maximum number of thinking tokens."""
- def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
+ device: torch.device):
"""
Args:
- think_start_token_id (int): Token ID for the start of thinking section.
- think_end_token_id (int): Token ID for the end of thinking section.
- pin_memory (bool): Whether to use pinned memory for tensors.
- device (torch.device): Device to use for tensor operations.
+ reasoning_config: Configuration for reasoning, which includes
+ the token IDs for thinking start and end.
+ pin_memory (bool): Whether to use pinned memory for tensors.
+ device (torch.device): Device to use for tensor operations.
"""
super().__init__()
self.think_start_token_id = reasoning_config.think_start_token_id
@@ -519,19 +521,25 @@ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
return -1
def is_argmax_invariant(self) -> bool:
- """This logits processor can change the outcome of greedy sampling
- by forcing that the thinking section ends after a certain number of tokens."""
+ """This logits processor can change the outcome of
+ greedy sampling by forcing that the thinking section
+ ends after a certain number of tokens."""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
- for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
- max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+ for (index, params, prompt_tok_ids,
+ output_tok_ids) in batch_update.added:
+ max_think_tokens = (params.max_think_tokens if isinstance(
+ params, SamplingParams) else None)
if max_think_tokens is not None:
- last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
- last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+ last_start = self._find_last_token_index(
+ prompt_tok_ids, self.think_start_token_id)
+ last_end = self._find_last_token_index(
+ prompt_tok_ids, self.think_end_token_id)
in_think = last_start > last_end
- count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+ count = len(prompt_tok_ids) - (last_start +
+ 1) if in_think else 0
self._state[index] = {
"in_think": in_think,
@@ -542,13 +550,14 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
}
for index in batch_update.removed:
- self._state.pop(index, None)
+ self._state.pop(index, {})
for i1, i2, direction in batch_update.moved:
if direction == MoveDirectionality.SWAP:
- self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+ self._state[i1], self._state[i2] = self._state[
+ i2], self._state[i1]
else:
- self._state[i2] = self._state.pop(i1, None)
+ self._state[i2] = self._state.pop(i1, {})
# Update in_think and count for all active requests
for state in self._state.values():
@@ -579,7 +588,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not state:
continue
- if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+ if state["in_think"] and state["count"] >= state[
+ "max_think_tokens"]:
mask[index] = True
if mask.any():
@@ -589,8 +599,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
return logits
-def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
- device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
+def init_builtin_logitsprocs(
+ pin_memory_available: bool, max_num_reqs: int, device: torch.device,
+ reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
"""Construct 'builtin' vLLM logitsprocs which the engine
loads by default.
@@ -619,8 +630,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
)
return LogitsProcessorManager(
non_argmax_invariant=[
- min_tokens_logitproc,
- logit_bias_logitproc,
+ min_tokens_logitproc, logit_bias_logitproc,
max_think_tokens_logitproc
],
argmax_invariant=[min_p_logitproc],
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index daf339bdd5b..58b59181f3b 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -263,7 +263,8 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
params = (request.sampling_params
if request.sampling_params else request.pooling_params)
self.batch_update_builder.added.append(
- (req_index, params, request.prompt_token_ids, request.output_token_ids))
+ (req_index, params, request.prompt_token_ids,
+ request.output_token_ids))
return req_index
def add_request(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8a37a30f2e9..61c9fb57149 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -18,7 +18,7 @@
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, ReasoningConfig, VllmConfig,
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -39,8 +39,10 @@
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.pooling_params import PoolingParams
+from vllm.reasoning import ReasoningParserManager
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, get_dtype_size,
@@ -71,10 +73,6 @@
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
-from vllm.config import ReasoningConfig
-from vllm.reasoning import ReasoningParserManager
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-
if TYPE_CHECKING:
import xgrammar as xgr
import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501
@@ -109,7 +107,8 @@ def __init__(
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
- if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'):
+ if self.vllm_config.decoding_config.reasoning_backend in (
+ 'deepseek_r1', 'qwen'):
tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
@@ -120,8 +119,9 @@ def __init__(
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoning_parser = reasoner_cls(tokenizer=tokenizer)
- self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id,
- think_end_token_id=reasoning_parser.think_end_token_id)
+ self.vllm_config.reasoning_config = ReasoningConfig(
+ think_start_token_id=reasoning_parser.think_start_token_id,
+ think_end_token_id=reasoning_parser.think_end_token_id)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
From 398471146508c22c240d8335728c81db9c0d93a7 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Mon, 14 Jul 2025 10:11:40 +0000
Subject: [PATCH 5/8] disable max think tokens logits processor if reasoning
parser info is missing
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/v1/sample/logits_processor.py | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
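Note (illustrative, not part of the diff): with this change the thinking-token processor is only registered when reasoning-parser information is available, so configurations without it keep the original builtin set. A rough check of that behaviour, assuming CPU tensors and the signatures introduced in this series:

    import torch
    from vllm.v1.sample.logits_processor import init_builtin_logitsprocs

    procs = init_builtin_logitsprocs(
        pin_memory_available=False,
        max_num_reqs=8,
        device=torch.device("cpu"),
        reasoning_config=None,  # no parser info -> processor is skipped
    )
    # Only min_tokens and logit_bias remain in the non-argmax-invariant group.
    assert len(procs.non_argmax_invariant) == 2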
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index 62f615bbd68..6f4f2bcf0f2 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -623,15 +623,18 @@ def init_builtin_logitsprocs(
device=device,
# +1 for temporary swap space
max_num_reqs=max_num_reqs + 1)
- max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
- reasoning_config=reasoning_config,
- pin_memory=pin_memory_available,
- device=device,
- )
+
+ non_argmax_invariant = [min_tokens_logitproc, logit_bias_logitproc]
+
+ if reasoning_config is not None:
+ max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+ reasoning_config=reasoning_config,
+ pin_memory=pin_memory_available,
+ device=device,
+ )
+ non_argmax_invariant.append(max_think_tokens_logitproc)
+
return LogitsProcessorManager(
- non_argmax_invariant=[
- min_tokens_logitproc, logit_bias_logitproc,
- max_think_tokens_logitproc
- ],
+ non_argmax_invariant=non_argmax_invariant,
argmax_invariant=[min_p_logitproc],
)
From 5636a129cba8f266fb09ef48d7a4c6cdbfc828f4 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Tue, 15 Jul 2025 12:49:18 +0000
Subject: [PATCH 6/8] revert change of deepseek reasoning parser
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
.../reasoning/deepseek_r1_reasoning_parser.py | 32 +++++++++----------
1 file changed, 16 insertions(+), 16 deletions(-)
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 96bb50f3817..1a5ca46a60f 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
- think_start_token_id: int
- think_end_token_id: int
+ start_token_id: int
+ end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
- self.think_start_token_id = self.vocab.get(self.start_token)
- self.think_end_token_id = self.vocab.get(self.end_token)
- if self.think_start_token_id is None or self.think_end_token_id is None:
+ self.start_token_id = self.vocab.get(self.start_token)
+ self.end_token_id = self.vocab.get(self.end_token)
+ if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.think_end_token_id in input_ids
+ return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
- if self.think_end_token_id not in input_ids[:-1]:
+ if self.end_token_id not in input_ids[:-1]:
return []
else:
- return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+ return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
@@ -75,14 +75,14 @@ def extract_reasoning_content_streaming(
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
- self.think_start_token_id, self.think_end_token_id
+ self.start_token_id, self.end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
- if self.think_start_token_id in previous_token_ids:
- if self.think_end_token_id in delta_token_ids:
+ if self.start_token_id in previous_token_ids:
+ if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.end_token)
@@ -92,7 +92,7 @@ def extract_reasoning_content_streaming(
reasoning_content=reasoning_content,
content=content if content else None,
)
- elif self.think_end_token_id in previous_token_ids:
+ elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
@@ -100,8 +100,8 @@ def extract_reasoning_content_streaming(
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
- elif self.think_start_token_id in delta_token_ids:
- if self.think_end_token_id in delta_token_ids:
+ elif self.start_token_id in delta_token_ids:
+ if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
@@ -120,7 +120,7 @@ def extract_reasoning_content_streaming(
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
- if self.think_end_token_id in delta_token_ids:
+ if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.end_token)
@@ -130,7 +130,7 @@ def extract_reasoning_content_streaming(
reasoning_content=reasoning_content,
content=content if content else None,
)
- elif self.think_end_token_id in previous_token_ids:
+ elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
From 6b424adbb1fff5743371b17b03b25ffe4bd06ce5 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Wed, 16 Jul 2025 06:33:40 +0000
Subject: [PATCH 7/8] support think start/end as token sequences
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
---
vllm/config.py | 20 +++--
vllm/engine/arg_utils.py | 14 +++-
vllm/v1/sample/logits_processor.py | 122 ++++++++++++++++++++---------
vllm/v1/worker/gpu_model_runner.py | 16 ++--
4 files changed, 117 insertions(+), 55 deletions(-)
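Note (illustrative, not part of the diff): the new --reasoning-config flag carries a JSON string that create_engine_config() loads and unpacks into ReasoningConfig; the model runner then tokenizes the start/end strings into token-ID sequences. A sketch of the expected value, using DeepSeek-R1-style tags as an example:

    import json
    from vllm.config import ReasoningConfig

    raw = '{"think_start_str": "<think>", "think_end_str": "</think>"}'
    reasoning_config = ReasoningConfig(**json.loads(raw))
    # On the command line this would be passed as, e.g.:
    #   vllm serve <model> --reasoning-config '{"think_start_str": "<think>", "think_end_str": "</think>"}'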
diff --git a/vllm/config.py b/vllm/config.py
index 160b7fe5d15..d11d1034238 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4408,16 +4408,24 @@ def set_splitting_ops_for_v1(self):
class ReasoningConfig:
"""Configuration for reasoning models."""
- think_start_token_id: Optional[int] = None
+ think_start_str: Optional[str] = None
+ """String that indicates the start of reasoning."""
+ think_end_str: Optional[str] = None
+ """String that indicates the end of reasoning."""
+ think_start_token_ids: Optional[list[int]] = None
"""Token ID that indicates the start of reasoning."""
- think_end_token_id: Optional[int] = None
+ think_end_token_ids: Optional[list[int]] = None
"""Token ID that indicates the end of reasoning."""
def __init__(self,
- think_start_token_id: Optional[int] = None,
- think_end_token_id: Optional[int] = None):
- self.think_start_token_id = think_start_token_id
- self.think_end_token_id = think_end_token_id
+ think_start_str: Optional[str] = None,
+ think_end_str: Optional[str] = None,
+ think_start_token_ids: Optional[list[int]] = None,
+ think_end_token_ids: Optional[list[int]] = None):
+ self.think_start_str = think_start_str
+ self.think_end_str = think_end_str
+ self.think_start_token_ids = think_start_token_ids
+ self.think_end_token_ids = think_end_token_ids
@config
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f47499309d8..65133c300e4 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -31,9 +31,9 @@
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
- SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
- TaskOption, TokenizerMode, TokenizerPoolConfig,
- VllmConfig, get_attr_docs, get_field)
+ ReasoningConfig, SchedulerConfig, SchedulerPolicy,
+ SpeculativeConfig, TaskOption, TokenizerMode,
+ TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -464,6 +464,8 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
kv_events_config: Optional[KVEventsConfig] = None
+ reasoning_config: Optional[ReasoningConfig] = None
+
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = \
@@ -934,6 +936,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
+ vllm_group.add_argument("--reasoning-config",
+ **vllm_kwargs["reasoning_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
@@ -1332,6 +1336,9 @@ def create_engine_config(
collect_detailed_traces=self.collect_detailed_traces,
)
+ reasoning_config = None
+ if self.reasoning_config is not None:
+     reasoning_config_dict = json.loads(self.reasoning_config)
+     reasoning_config = ReasoningConfig(**reasoning_config_dict)
+
config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
@@ -1347,6 +1354,7 @@ def create_engine_config(
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
+ reasoning_config=reasoning_config,
additional_config=self.additional_config,
)
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index 6f4f2bcf0f2..b846de0481e 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -514,12 +514,84 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
self.device = device
self._state: dict[int, dict[str, Any]] = {}
- def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
+ def _find_first_token_index(self, target_list: list[int], token_id: int) -> int:
+ """
+ Find the last occurrence of a single token in the list of tokens.
+
+ Args:
+ target_list (list[int]): The list of token IDs.
+ token_id (int): The token ID to find.
+ """
try:
- return len(tokens) - tokens[::-1].index(token_id) - 1
+ return len(target_list) - target_list[::-1].index(token_id) - 1
except ValueError:
return -1
+ def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int:
+ """
+ Find the last occurrence of the sequence of token_ids in tokens.
+
+ Args:
+ target_list (list[int]): The list of token IDs.
+ token_ids (list[int]): The sequence of token IDs to find.
+ """
+ index = self._find_first_token_index(target_list, token_ids[0])
+ if index != -1:
+ i = 1
+ for token_id in token_ids[1:]:
+ if index + i >= len(target_list) or target_list[index + i] != token_id:
+ return -1
+ i += 1
+ index += 1
+
+ return index
+
+ def _init_state_entry(self, prompt_tok_ids, max_think_tokens):
+ last_start = self._find_last_sequence_index(
+ prompt_tok_ids, self.think_start_token_id)
+ last_end = self._find_last_sequence_index(
+ prompt_tok_ids, self.think_end_token_id)
+ in_think = last_start > last_end
+ think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+ return {
+ "in_think": in_think,
+ "in_end": False,
+ "think_count": think_count,
+ "end_count": 0,
+ "prompt_tok_ids": prompt_tok_ids,
+ "output_tok_ids": [],
+ "max_think_tokens": max_think_tokens,
+ }
+
+ def _update_think_state(self, state):
+ output = state["output_tok_ids"]
+ if not output:
+ return
+
+ sliced_output1 = output[-1 + len(self.think_start_token_id):]
+ sliced_output2 = output[-1 + len(self.think_end_token_id):]
+
+ if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1:
+ state["in_think"] = True
+ state["think_count"] = 0
+ elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1:
+ state["in_think"] = False
+ state["think_count"] = 0
+ else:
+ state["think_count"] += 1
+
+ if state["in_end"]:
+ state["end_count"] += 1
+ if state["end_count"] >= len(self.think_end_token_id):
+ state["in_end"] = False
+ state["end_count"] = 0
+ else:
+ if state["in_think"] and state["think_count"] >= state["max_think_tokens"]:
+ state["in_think"] = False
+ state["in_end"] = True
+ state["end_count"] = 0
+
def is_argmax_invariant(self) -> bool:
"""This logits processor can change the outcome of
greedy sampling by forcing that the thinking section
@@ -528,52 +600,25 @@ def is_argmax_invariant(self) -> bool:
def update_state(self, batch_update: Optional[BatchUpdate]):
if batch_update:
- for (index, params, prompt_tok_ids,
- output_tok_ids) in batch_update.added:
+ for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added:
max_think_tokens = (params.max_think_tokens if isinstance(
params, SamplingParams) else None)
if max_think_tokens is not None:
- last_start = self._find_last_token_index(
- prompt_tok_ids, self.think_start_token_id)
- last_end = self._find_last_token_index(
- prompt_tok_ids, self.think_end_token_id)
- in_think = last_start > last_end
- count = len(prompt_tok_ids) - (last_start +
- 1) if in_think else 0
-
- self._state[index] = {
- "in_think": in_think,
- "count": count,
- "prompt_tok_ids": prompt_tok_ids,
- "output_tok_ids": output_tok_ids,
- "max_think_tokens": max_think_tokens,
- }
+ self._state[index] = self._init_state_entry(
+ prompt_tok_ids, max_think_tokens)
+ self._state[index]["output_tok_ids"] = output_tok_ids
for index in batch_update.removed:
self._state.pop(index, {})
for i1, i2, direction in batch_update.moved:
if direction == MoveDirectionality.SWAP:
- self._state[i1], self._state[i2] = self._state[
- i2], self._state[i1]
+ self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
else:
self._state[i2] = self._state.pop(i1, {})
- # Update in_think and count for all active requests
for state in self._state.values():
- output = state["output_tok_ids"]
- if not output:
- continue
-
- last_tok = output[-1]
- if last_tok == self.think_start_token_id:
- state["in_think"] = True
- state["count"] = 0
- elif last_tok == self.think_end_token_id:
- state["in_think"] = False
- state["count"] = 0
- elif state["in_think"]:
- state["count"] += 1
+ self._update_think_state(state)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
batch_size = logits.size(0)
@@ -588,13 +633,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
if not state:
continue
- if state["in_think"] and state["count"] >= state[
- "max_think_tokens"]:
+ force_end_token_id = None
+ if state["in_end"]:
+ force_end_token_id = self.think_end_token_id[state["end_count"]]
mask[index] = True
if mask.any():
logits[mask] = -float("inf")
- logits[mask, end_token_id] = 0.0
+ logits[mask, force_end_token_id] = 0.0
return logits
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 61c9fb57149..ad48cc5151d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -114,14 +114,14 @@ def __init__(
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
- reasoning_backend = \
- self.vllm_config.decoding_config.reasoning_backend
- reasoner_cls = ReasoningParserManager.get_reasoning_parser(
- reasoning_backend)
- reasoning_parser = reasoner_cls(tokenizer=tokenizer)
- self.vllm_config.reasoning_config = ReasoningConfig(
- think_start_token_id=reasoning_parser.think_start_token_id,
- think_end_token_id=reasoning_parser.think_end_token_id)
+ reasoning_config = self.vllm_config.reasoning_config
+ if reasoning_config is not None:
+ reasoning_config.think_start_token_id = \
+ tokenizer.convert_tokens_to_ids(
+ tokenizer.tokenize(reasoning_config.think_start_str))
+ reasoning_config.think_end_token_id = \
+ tokenizer.convert_tokens_to_ids(
+ tokenizer.tokenize(reasoning_config.think_end_str))
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(
From 2e6cff6fe6e3611810a8e6f07029fb22026b7908 Mon Sep 17 00:00:00 2001
From: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Date: Thu, 17 Jul 2025 13:47:22 +0000
Subject: [PATCH 8/8] refactor and make the logic faster
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: Sungjae Lee
---
vllm/v1/sample/logits_processor.py | 90 ++++++++++++------------------
vllm/v1/worker/gpu_model_runner.py | 4 +-
2 files changed, 39 insertions(+), 55 deletions(-)
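Note (illustrative, not part of the diff): the refactor replaces the two-step search with a single reverse slice comparison and forces the end sequence one token per decoding step through the logits mask. The search helper added here can be restated standalone as:

    def find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
        # Scan right-to-left and compare a fixed-size window against token_ids.
        if not token_ids:
            return -1
        for i in range(len(target_list) - len(token_ids), -1, -1):
            if target_list[i:i + len(token_ids)] == token_ids:
                return i
        return -1

    assert find_last_sequence_index([1, 2, 3, 2, 3], [2, 3]) == 3
    assert find_last_sequence_index([1, 2, 3], [4]) == -1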
diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py
index b846de0481e..e6c429bb8d2 100644
--- a/vllm/v1/sample/logits_processor.py
+++ b/vllm/v1/sample/logits_processor.py
@@ -496,7 +496,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
class MaxThinkTokensLogitsProcessor(LogitsProcessor):
- """A logits processor that limits the maximum number of thinking tokens."""
+ """Limits the number of tokens allowed inside a 'thinking' section."""
def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
device: torch.device):
@@ -508,82 +508,68 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
device (torch.device): Device to use for tensor operations.
"""
super().__init__()
- self.think_start_token_id = reasoning_config.think_start_token_id
- self.think_end_token_id = reasoning_config.think_end_token_id
+ self.think_start_token_ids = reasoning_config.think_start_token_ids
+ self.think_end_token_ids = reasoning_config.think_end_token_ids
self.pin_memory = pin_memory
self.device = device
self._state: dict[int, dict[str, Any]] = {}
- def _find_first_token_index(self, target_list: list[int], token_id: int) -> int:
+ @staticmethod
+ def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
"""
- Find the last occurrence of a single token in the list of tokens.
+ Returns the index of the last occurrence of token_ids in target_list.
Args:
target_list (list[int]): The list of token IDs.
- token_id (int): The token ID to find.
+ token_ids (list[int]): The sequence of token IDs to find.
"""
- try:
- return len(target_list) - target_list[::-1].index(token_id) - 1
- except ValueError:
+ if not token_ids:
return -1
- def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int:
- """
- Find the last occurrence of the sequence of token_ids in tokens.
+ for i in range(len(target_list) - len(token_ids), -1, -1):
+ if target_list[i:i + len(token_ids)] == token_ids:
+ return i
+ return -1
- Args:
- target_list (list[int]): The list of token IDs.
- token_ids (list[int]): The sequence of token IDs to find.
- """
- index = self._find_first_token_index(target_list, token_ids[0])
- if index != -1:
- i = 1
- for token_id in token_ids[1:]:
- if index + i >= len(target_list) or target_list[index + i] != token_id:
- return -1
- i += 1
- index += 1
-
- return index
-
- def _init_state_entry(self, prompt_tok_ids, max_think_tokens):
+ def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> dict[str, Any]:
+ """Initializes the tracking state for a given sequence index."""
last_start = self._find_last_sequence_index(
- prompt_tok_ids, self.think_start_token_id)
+ prompt_tok_ids, self.think_start_token_ids)
last_end = self._find_last_sequence_index(
- prompt_tok_ids, self.think_end_token_id)
+ prompt_tok_ids, self.think_end_token_ids)
in_think = last_start > last_end
think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
return {
- "in_think": in_think,
- "in_end": False,
- "think_count": think_count,
- "end_count": 0,
+ "in_think": in_think, # Currently in thinking mode
+ "in_end": False, # Currently forcing end tokens
+ "think_count": think_count, # Number of tokens in thinking section
+ "end_count": 0, # Number of end tokens forced so far
"prompt_tok_ids": prompt_tok_ids,
"output_tok_ids": [],
"max_think_tokens": max_think_tokens,
}
- def _update_think_state(self, state):
+ def _update_think_state(self, state: dict[str, Any]):
+ """Updates the state based on generated output tokens."""
output = state["output_tok_ids"]
if not output:
return
- sliced_output1 = output[-1 + len(self.think_start_token_id):]
- sliced_output2 = output[-1 + len(self.think_end_token_id):]
-
- if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1:
+ # Check if recent output matches start or end sequences
+ if output[-len(self.think_start_token_ids):] == self.think_start_token_ids:
state["in_think"] = True
state["think_count"] = 0
- elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1:
+ elif output[-len(self.think_end_token_ids):] == self.think_end_token_ids:
state["in_think"] = False
state["think_count"] = 0
- else:
+ elif state["in_think"]:
state["think_count"] += 1
+ # Force the end-token sequence once the thinking token limit is exceeded
if state["in_end"]:
state["end_count"] += 1
- if state["end_count"] >= len(self.think_end_token_id):
+ if state["end_count"] >= len(self.think_end_token_ids):
state["in_end"] = False
state["end_count"] = 0
else:
@@ -626,21 +612,19 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
return logits
mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
- end_token_id = self.think_end_token_id
-
- for index in range(batch_size):
- state = self._state.get(index)
- if not state:
- continue
+ force_token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=logits.device)
- force_end_token_id = None
- if state["in_end"]:
- force_end_token_id = self.think_end_token_id[state["end_count"]]
- mask[index] = True
+ for i in range(batch_size):
+ state = self._state.get(i)
+ if state and state["in_end"]:
+ mask[i] = True
+ force_token_ids[i] = self.think_end_token_ids[state["end_count"]]
if mask.any():
logits[mask] = -float("inf")
- logits[mask, force_end_token_id] = 0.0
+ row_indices = torch.arange(batch_size, device=logits.device)[mask]
+ col_indices = force_token_ids[mask]
+ logits[row_indices, col_indices] = 0.0
return logits
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ad48cc5151d..1d16abed80a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -116,10 +116,10 @@ def __init__(
).get_lora_tokenizer(None)
reasoning_config = self.vllm_config.reasoning_config
if reasoning_config is not None:
- reasoning_config.think_start_token_id = \
+ reasoning_config.think_start_token_ids = \
tokenizer.convert_tokens_to_ids(
tokenizer.tokenize(reasoning_config.think_start_str))
- reasoning_config.think_end_token_id = \
+ reasoning_config.think_end_token_ids = \
tokenizer.convert_tokens_to_ids(
tokenizer.tokenize(reasoning_config.think_end_str))