From f4274ab20979bcf4367ba429266a6e625aa25c28 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:14:00 +0000 Subject: [PATCH 1/8] feat: limit thinking tokens Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config.py | 13 ++ vllm/entrypoints/openai/protocol.py | 2 + .../reasoning/deepseek_r1_reasoning_parser.py | 32 ++--- vllm/sampling_params.py | 7 + vllm/v1/engine/core.py | 1 + vllm/v1/sample/logits_processor.py | 130 ++++++++++++++++-- vllm/v1/worker/gpu_input_batch.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 20 +++ 8 files changed, 185 insertions(+), 27 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d9f356c5c60..9ef217bcc1d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4404,6 +4404,17 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention_with_output", ] +class ReasoningConfig: + """Configuration for reasoning models.""" + + think_start_token_id: Optional[int] = None + """Token ID that indicates the start of reasoning.""" + think_end_token_id: Optional[int] = None + """Token ID that indicates the end of reasoning.""" + + def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None): + self.think_start_token_id = think_start_token_id + self.think_end_token_id = think_end_token_id @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) @@ -4461,6 +4472,8 @@ class VllmConfig: # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. + reasoning_config: Optional[ReasoningConfig] = None + """The configurations for reasoning model.""" additional_config: Union[dict, SupportsHash] = field(default_factory=dict) """Additional config for specified platform. Different platforms may support different configs. Make sure the configs are valid for the platform diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index fdac6ccd19e..e0c7bbf21db 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) + max_think_tokens: Optional[int] = None # --8<-- [end:chat-completion-sampling-params] # --8<-- [start:chat-completion-extra-params] @@ -670,6 +671,7 @@ def to_sampling_params( guided_decoding=guided_decoding, logit_bias=self.logit_bias, bad_words= self.bad_words, + max_think_tokens=self.max_think_tokens, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py index 1a5ca46a60f..96bb50f3817 100644 --- a/vllm/reasoning/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser): text. This parser extracts the reasoning content from the model output. 
""" - start_token_id: int - end_token_id: int + think_start_token_id: int + think_end_token_id: int start_token: str = "" end_token: str = "" @@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): "The model tokenizer must be passed to the ReasoningParser " "constructor during construction.") - self.start_token_id = self.vocab.get(self.start_token) - self.end_token_id = self.vocab.get(self.end_token) - if self.start_token_id is None or self.end_token_id is None: + self.think_start_token_id = self.vocab.get(self.start_token) + self.think_end_token_id = self.vocab.get(self.end_token) + if self.think_start_token_id is None or self.think_end_token_id is None: raise RuntimeError( "DeepSeek R1 reasoning parser could not locate think start/end " "tokens in the tokenizer!") def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.end_token_id in input_ids + return self.think_end_token_id in input_ids def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ Extract the content after the end tokens """ - if self.end_token_id not in input_ids[:-1]: + if self.think_end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.end_token_id) + 1:] + return input_ids[input_ids.index(self.think_end_token_id) + 1:] def extract_reasoning_content_streaming( self, @@ -75,14 +75,14 @@ def extract_reasoning_content_streaming( """ # Skip single special tokens if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.start_token_id, self.end_token_id + self.think_start_token_id, self.think_end_token_id ]): return None # Check if is present in previous or delta. # Keep compatibility with models that don't generate tokens. - if self.start_token_id in previous_token_ids: - if self.end_token_id in delta_token_ids: + if self.think_start_token_id in previous_token_ids: + if self.think_end_token_id in delta_token_ids: # in previous, in delta, # extract reasoning content end_index = delta_text.find(self.end_token) @@ -92,7 +92,7 @@ def extract_reasoning_content_streaming( reasoning_content=reasoning_content, content=content if content else None, ) - elif self.end_token_id in previous_token_ids: + elif self.think_end_token_id in previous_token_ids: # in previous, in previous, # reasoning content continues return DeltaMessage(content=delta_text) @@ -100,8 +100,8 @@ def extract_reasoning_content_streaming( # in previous, no in previous or delta, # reasoning content continues return DeltaMessage(reasoning_content=delta_text) - elif self.start_token_id in delta_token_ids: - if self.end_token_id in delta_token_ids: + elif self.think_start_token_id in delta_token_ids: + if self.think_end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content start_index = delta_text.find(self.start_token) end_index = delta_text.find(self.end_token) @@ -120,7 +120,7 @@ def extract_reasoning_content_streaming( # No in previous or delta, also need to check for . 
# Because the model may have generated without # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.end_token_id in delta_token_ids: + if self.think_end_token_id in delta_token_ids: # in delta with more tokens, # extract reasoning content and content end_index = delta_text.find(self.end_token) @@ -130,7 +130,7 @@ def extract_reasoning_content_streaming( reasoning_content=reasoning_content, content=content if content else None, ) - elif self.end_token_id in previous_token_ids: + elif self.think_end_token_id in previous_token_ids: # in previous, thinking content ends return DeltaMessage(content=delta_text) else: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a9a862384d1..49cdbdde45c 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -200,6 +200,7 @@ class SamplingParams( extra_args: Arbitrary additional args, that can be used by custom sampling implementations, plugins, etc. Not used by any in-tree sampling implementations. + max_think_tokens: Maximum number of tokens allowed for thinking """ n: int = 1 @@ -248,6 +249,9 @@ class SamplingParams( bad_words: Optional[list[str]] = None _bad_words_token_ids: Optional[list[list[int]]] = None + # Maximum number of tokens allowed for thinking operations. + max_think_tokens: Optional[int] = None + @staticmethod def from_optional( n: Optional[int] = 1, @@ -263,6 +267,7 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, + max_think_tokens: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, @@ -306,6 +311,7 @@ def from_optional( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, + max_think_tokens=max_think_tokens, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, max_tokens=max_tokens, @@ -574,6 +580,7 @@ def __repr__(self) -> str: f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"bad_words={self.bad_words}, " + f"max_think_tokens={self.max_think_tokens}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11..0eb0d8018ae 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -86,6 +86,7 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) + # EngineCore holds StructuredOutputManager to handle and it has vllm config as an arg. self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 3a4c25964e7..04912b358b3 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -12,6 +12,7 @@ from torch._prims_common import DeviceLikeType from vllm import PoolingParams, SamplingParams +from vllm.config import ReasoningConfig from vllm.logger import init_logger logger = init_logger(__name__) @@ -24,9 +25,9 @@ class MoveDirectionality(Enum): SWAP = 1 -# (index, params, output_tok_ids) tuples for new +# (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. 
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]] +AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]] # (index 1, index 2, directionality) tuples representing # one-way moves or two-way swaps of requests in batch MovedRequest = tuple[int, int, MoveDirectionality] @@ -43,9 +44,9 @@ class BatchUpdate: # within the persistent batch. # # Note: each added request is represented as - # (index, params, output_tok_ids) - # Key assumption: output_tok_ids is a reference to the - # request's running output tokens list; in this way + # (index, params, prompt_tok_ids, output_tok_ids) + # Key assumption: prompt_tok_ids, output_tok_ids is a reference to the + # request's prompt and running output tokens list; in this way # the logits processors always see the latest list of # generated tokens removed: Sequence[RemovedRequest] @@ -260,7 +261,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = False # Process added requests. - for index, params, _ in batch_update.added: + for index, params, _, _ in batch_update.added: min_p = params.min_p if isinstance(params, SamplingParams) else 0.0 if self.min_p_cpu[index] != min_p: needs_update = True @@ -337,7 +338,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): # Process added requests. needs_update = bool(batch_update.added) - for index, params, _ in batch_update.added: + for index, params, _, _ in batch_update.added: if isinstance(params, SamplingParams) and (lb := params.logit_bias): self.biases[index] = lb @@ -420,7 +421,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: # Process added requests. needs_update |= bool(batch_update.added) - for index, params, output_tok_ids in batch_update.added: + for index, params, _, output_tok_ids in batch_update.added: if (isinstance(params, SamplingParams) and (min_tokens := params.min_tokens) and len(output_tok_ids) < min_tokens): @@ -493,8 +494,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits +class MaxThinkTokensLogitsProcessor(LogitsProcessor): + """A logits processor that limits the maximum number of thinking tokens.""" + + def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): + """ + Args: + think_start_token_id (int): Token ID for the start of thinking section. + think_end_token_id (int): Token ID for the end of thinking section. + pin_memory (bool): Whether to use pinned memory for tensors. + device (torch.device): Device to use for tensor operations. 
+ """ + super().__init__() + self.think_start_token_id = reasoning_config.think_start_token_id + self.think_end_token_id = reasoning_config.think_end_token_id + self.pin_memory = pin_memory + self.device = device + self._state = {} + + def _find_last_token_index(self, tokens, token_id): + try: + return len(tokens) - tokens[::-1].index(token_id) - 1 + except ValueError: + return -1 + + def is_argmax_invariant(self) -> bool: + """This logits processor can change the outcome of greedy sampling + by forcing that the thinking section ends after a certain number of tokens.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if batch_update is None: + return + + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + + if max_think_tokens is None: + continue + + last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) + last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + + in_think = False + count = 0 + + if last_think_start_idx > last_think_end_idx: + in_think = True + count = len(prompt_tok_ids) - (last_think_start_idx + 1) + + self._state[index] = { + "in_think": in_think, + "count": count, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": output_tok_ids, + "max_think_tokens": max_think_tokens, + } + + for index in batch_update.removed: + self._state.pop(index, None) + + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + else: + self._state[i2] = self._state.pop(i1, None) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + batch_size = logits.size(0) + if batch_size == 0: + return logits + + mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) + end_token_id = self.think_end_token_id + + for index in range(batch_size): + state = self._state.get(index, None) + if not state or not state.get("output_tok_ids"): + continue + + last_tok = state["output_tok_ids"][-1] + in_think = state["in_think"] + count = state["count"] + + if last_tok == self.think_start_token_id: + in_think = True + count = 0 + elif last_tok == self.think_end_token_id: + in_think = False + count = 0 + elif in_think: + count += 1 + + state["in_think"] = in_think + state["count"] = count + + if state["in_think"] and state["count"] >= state["max_think_tokens"]: + mask[index] = True + + if mask.any(): + logits[mask] = -float("inf") + logits[mask, end_token_id] = 0.0 + + return logits + + def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device) -> LogitsProcessorManager: + device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager: """Construct 'builtin' vLLM logitsprocs which the engine loads by default. 
@@ -516,10 +622,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, device=device, # +1 for temporary swap space max_num_reqs=max_num_reqs + 1) + max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor( + reasoning_config=reasoning_config, + pin_memory=pin_memory_available, + device=device, + ) return LogitsProcessorManager( non_argmax_invariant=[ min_tokens_logitproc, logit_bias_logitproc, + max_think_tokens_logitproc ], argmax_invariant=[min_p_logitproc], ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..daf339bdd5b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -8,6 +8,7 @@ import numpy as np import torch +from vllm.config import ReasoningConfig from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.pooling_params import PoolingParams @@ -71,6 +72,7 @@ def __init__( block_sizes: list[int], # The block_size of each kv cache group is_spec_decode: bool = False, logits_processing_needs_token_ids: bool = False, + reasoning_config: ReasoningConfig = None, ): self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs @@ -221,7 +223,8 @@ def __init__( self.logitsprocs = init_builtin_logitsprocs( pin_memory_available=pin_memory, max_num_reqs=max_num_reqs + 1, - device=device) + device=device, + reasoning_config=reasoning_config) # TODO convert this to LogitsProcessor self.has_allowed_token_ids: set[str] = set() @@ -260,7 +263,7 @@ def _register_add_request(self, request: "CachedRequestState") -> int: params = (request.sampling_params if request.sampling_params else request.pooling_params) self.batch_update_builder.added.append( - (req_index, params, request.output_token_ids)) + (req_index, params, request.prompt_token_ids, request.output_token_ids)) return req_index def add_request( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4551cb2df98..8a37a30f2e9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -71,6 +71,10 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm.config import ReasoningConfig +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + if TYPE_CHECKING: import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 @@ -105,6 +109,20 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'): + tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) + reasoning_backend = \ + self.vllm_config.decoding_config.reasoning_backend + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoning_parser = reasoner_cls(tokenizer=tokenizer) + self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id, + think_end_token_id=reasoning_parser.think_end_token_id) + from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( 
int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -212,6 +230,7 @@ def __init__( vocab_size=self.model_config.get_vocab_size(), block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), + reasoning_config=self.vllm_config.reasoning_config, ) self.use_cuda_graph = ( @@ -2384,6 +2403,7 @@ def may_reinitialize_input_batch(self, vocab_size=self.model_config.get_vocab_size(), block_sizes=block_sizes, is_spec_decode=bool(self.vllm_config.speculative_config), + reasoning_config=self.vllm_config.reasoning_config, ) def _allocate_kv_cache_tensors( From 84aee5bda0ac43df5f3c749aa31157b79874f368 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:24:57 +0000 Subject: [PATCH 2/8] remove comment Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/engine/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0eb0d8018ae..e2fdf6f8a11 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -86,7 +86,6 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - # EngineCore holds StructuredOutputManager to handle and it has vllm config as an arg. self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. From f2e195a46168cc23c84156d1334b16ab5c117a0f Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 14 Jul 2025 04:06:23 +0000 Subject: [PATCH 3/8] update states only in update_state method Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor.py | 96 +++++++++++++----------------- 1 file changed, 43 insertions(+), 53 deletions(-) diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 04912b358b3..2b4e518e1a8 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from enum import Enum from itertools import chain -from typing import Optional, Union +from typing import Any, Optional, Union import torch from torch._prims_common import DeviceLikeType @@ -510,9 +510,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: self.think_end_token_id = reasoning_config.think_end_token_id self.pin_memory = pin_memory self.device = device - self._state = {} + self._state: dict[int, dict[str, Any]] = {} - def _find_last_token_index(self, tokens, token_id): + def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: try: return len(tokens) - tokens[::-1].index(token_id) - 1 except ValueError: @@ -524,71 +524,61 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: Optional[BatchUpdate]): - if batch_update is None: - return - - for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None - - if max_think_tokens is None: - continue - - last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) - last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) - - in_think = False - count = 0 + if batch_update: + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + if 
max_think_tokens is not None: + last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + in_think = last_start > last_end + count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + + self._state[index] = { + "in_think": in_think, + "count": count, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": output_tok_ids, + "max_think_tokens": max_think_tokens, + } - if last_think_start_idx > last_think_end_idx: - in_think = True - count = len(prompt_tok_ids) - (last_think_start_idx + 1) + for index in batch_update.removed: + self._state.pop(index, None) - self._state[index] = { - "in_think": in_think, - "count": count, - "prompt_tok_ids": prompt_tok_ids, - "output_tok_ids": output_tok_ids, - "max_think_tokens": max_think_tokens, - } + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + else: + self._state[i2] = self._state.pop(i1, None) - for index in batch_update.removed: - self._state.pop(index, None) + # Update in_think and count for all active requests + for state in self._state.values(): + output = state["output_tok_ids"] + if not output: + continue - for i1, i2, direction in batch_update.moved: - if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[i2], self._state[i1] - else: - self._state[i2] = self._state.pop(i1, None) + last_tok = output[-1] + if last_tok == self.think_start_token_id: + state["in_think"] = True + state["count"] = 0 + elif last_tok == self.think_end_token_id: + state["in_think"] = False + state["count"] = 0 + elif state["in_think"]: + state["count"] += 1 def apply(self, logits: torch.Tensor) -> torch.Tensor: batch_size = logits.size(0) - if batch_size == 0: + if not self._state: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) end_token_id = self.think_end_token_id for index in range(batch_size): - state = self._state.get(index, None) - if not state or not state.get("output_tok_ids"): + state = self._state.get(index) + if not state: continue - last_tok = state["output_tok_ids"][-1] - in_think = state["in_think"] - count = state["count"] - - if last_tok == self.think_start_token_id: - in_think = True - count = 0 - elif last_tok == self.think_end_token_id: - in_think = False - count = 0 - elif in_think: - count += 1 - - state["in_think"] = in_think - state["count"] = count - if state["in_think"] and state["count"] >= state["max_think_tokens"]: mask[index] = True From 4c4251d2b3fa64f0368a75d10fdd969286be7243 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 14 Jul 2025 04:40:44 +0000 Subject: [PATCH 4/8] make precommit and lint Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config.py | 6 +++- vllm/v1/sample/logits_processor.py | 52 ++++++++++++++++++------------ vllm/v1/worker/gpu_input_batch.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 16 ++++----- 4 files changed, 46 insertions(+), 31 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9ef217bcc1d..160b7fe5d15 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4404,6 +4404,7 @@ def set_splitting_ops_for_v1(self): "vllm.unified_attention_with_output", ] + class ReasoningConfig: """Configuration for reasoning models.""" @@ -4412,10 +4413,13 @@ class ReasoningConfig: think_end_token_id: Optional[int] = None """Token ID that 
indicates the end of reasoning.""" - def __init__(self, think_start_token_id: Optional[int] = None, think_end_token_id: Optional[int] = None): + def __init__(self, + think_start_token_id: Optional[int] = None, + think_end_token_id: Optional[int] = None): self.think_start_token_id = think_start_token_id self.think_end_token_id = think_end_token_id + @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 2b4e518e1a8..62f615bbd68 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -27,7 +27,8 @@ class MoveDirectionality(Enum): # (index, params, prompt_tok_ids, output_tok_ids) tuples for new # requests added to the batch. -AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]] +AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], + list[int]] # (index 1, index 2, directionality) tuples representing # one-way moves or two-way swaps of requests in batch MovedRequest = tuple[int, int, MoveDirectionality] @@ -497,13 +498,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MaxThinkTokensLogitsProcessor(LogitsProcessor): """A logits processor that limits the maximum number of thinking tokens.""" - def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): + def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, + device: torch.device): """ Args: - think_start_token_id (int): Token ID for the start of thinking section. - think_end_token_id (int): Token ID for the end of thinking section. - pin_memory (bool): Whether to use pinned memory for tensors. - device (torch.device): Device to use for tensor operations. + reasoning_config: Configuration for reasoning, which includes + the token IDs for thinking start and end. + pin_memory (bool): Whether to use pinned memory for tensors. + device (torch.device): Device to use for tensor operations. 
""" super().__init__() self.think_start_token_id = reasoning_config.think_start_token_id @@ -519,19 +521,25 @@ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: return -1 def is_argmax_invariant(self) -> bool: - """This logits processor can change the outcome of greedy sampling - by forcing that the thinking section ends after a certain number of tokens.""" + """This logits processor can change the outcome of + greedy sampling by forcing that the thinking section + ends after a certain number of tokens.""" return False def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: - for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + for (index, params, prompt_tok_ids, + output_tok_ids) in batch_update.added: + max_think_tokens = (params.max_think_tokens if isinstance( + params, SamplingParams) else None) if max_think_tokens is not None: - last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) - last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + last_start = self._find_last_token_index( + prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_token_index( + prompt_tok_ids, self.think_end_token_id) in_think = last_start > last_end - count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + count = len(prompt_tok_ids) - (last_start + + 1) if in_think else 0 self._state[index] = { "in_think": in_think, @@ -542,13 +550,14 @@ def update_state(self, batch_update: Optional[BatchUpdate]): } for index in batch_update.removed: - self._state.pop(index, None) + self._state.pop(index, {}) for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + self._state[i1], self._state[i2] = self._state[ + i2], self._state[i1] else: - self._state[i2] = self._state.pop(i1, None) + self._state[i2] = self._state.pop(i1, {}) # Update in_think and count for all active requests for state in self._state.values(): @@ -579,7 +588,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: if not state: continue - if state["in_think"] and state["count"] >= state["max_think_tokens"]: + if state["in_think"] and state["count"] >= state[ + "max_think_tokens"]: mask[index] = True if mask.any(): @@ -589,8 +599,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits -def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, - device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager: +def init_builtin_logitsprocs( + pin_memory_available: bool, max_num_reqs: int, device: torch.device, + reasoning_config: ReasoningConfig) -> LogitsProcessorManager: """Construct 'builtin' vLLM logitsprocs which the engine loads by default. 
@@ -619,8 +630,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int, ) return LogitsProcessorManager( non_argmax_invariant=[ - min_tokens_logitproc, - logit_bias_logitproc, + min_tokens_logitproc, logit_bias_logitproc, max_think_tokens_logitproc ], argmax_invariant=[min_p_logitproc], diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index daf339bdd5b..58b59181f3b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -263,7 +263,8 @@ def _register_add_request(self, request: "CachedRequestState") -> int: params = (request.sampling_params if request.sampling_params else request.pooling_params) self.batch_update_builder.added.append( - (req_index, params, request.prompt_token_ids, request.output_token_ids)) + (req_index, params, request.prompt_token_ids, + request.output_token_ids)) return req_index def add_request( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8a37a30f2e9..61c9fb57149 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -18,7 +18,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.config import (CompilationLevel, VllmConfig, +from vllm.config import (CompilationLevel, ReasoningConfig, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -39,8 +39,10 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality from vllm.pooling_params import PoolingParams +from vllm.reasoning import ReasoningParserManager from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, @@ -71,10 +73,6 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) -from vllm.config import ReasoningConfig -from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - if TYPE_CHECKING: import xgrammar as xgr import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile # noqa: E501 @@ -109,7 +107,8 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config - if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'): + if self.vllm_config.decoding_config.reasoning_backend in ( + 'deepseek_r1', 'qwen'): tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, @@ -120,8 +119,9 @@ def __init__( reasoner_cls = ReasoningParserManager.get_reasoning_parser( reasoning_backend) reasoning_parser = reasoner_cls(tokenizer=tokenizer) - self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id, - think_end_token_id=reasoning_parser.think_end_token_id) + self.vllm_config.reasoning_config = ReasoningConfig( + 
think_start_token_id=reasoning_parser.think_start_token_id, + think_end_token_id=reasoning_parser.think_end_token_id) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 398471146508c22c240d8335728c81db9c0d93a7 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 14 Jul 2025 10:11:40 +0000 Subject: [PATCH 5/8] disable max think tokens logits processor if reasoning parser info is missing Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 62f615bbd68..6f4f2bcf0f2 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -623,15 +623,18 @@ def init_builtin_logitsprocs( device=device, # +1 for temporary swap space max_num_reqs=max_num_reqs + 1) - max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor( - reasoning_config=reasoning_config, - pin_memory=pin_memory_available, - device=device, - ) + + non_argmax_invariant = [min_tokens_logitproc, logit_bias_logitproc] + + if reasoning_config is not None: + max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor( + reasoning_config=reasoning_config, + pin_memory=pin_memory_available, + device=device, + ) + non_argmax_invariant.append(max_think_tokens_logitproc) + return LogitsProcessorManager( - non_argmax_invariant=[ - min_tokens_logitproc, logit_bias_logitproc, - max_think_tokens_logitproc - ], + non_argmax_invariant=non_argmax_invariant, argmax_invariant=[min_p_logitproc], ) From 5636a129cba8f266fb09ef48d7a4c6cdbfc828f4 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Tue, 15 Jul 2025 12:49:18 +0000 Subject: [PATCH 6/8] revert change of deepseek reasoning parser Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- .../reasoning/deepseek_r1_reasoning_parser.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py index 96bb50f3817..1a5ca46a60f 100644 --- a/vllm/reasoning/deepseek_r1_reasoning_parser.py +++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py @@ -23,8 +23,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser): text. This parser extracts the reasoning content from the model output. 
""" - think_start_token_id: int - think_end_token_id: int + start_token_id: int + end_token_id: int start_token: str = "" end_token: str = "" @@ -37,24 +37,24 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): "The model tokenizer must be passed to the ReasoningParser " "constructor during construction.") - self.think_start_token_id = self.vocab.get(self.start_token) - self.think_end_token_id = self.vocab.get(self.end_token) - if self.think_start_token_id is None or self.think_end_token_id is None: + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: raise RuntimeError( "DeepSeek R1 reasoning parser could not locate think start/end " "tokens in the tokenizer!") def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.think_end_token_id in input_ids + return self.end_token_id in input_ids def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ Extract the content after the end tokens """ - if self.think_end_token_id not in input_ids[:-1]: + if self.end_token_id not in input_ids[:-1]: return [] else: - return input_ids[input_ids.index(self.think_end_token_id) + 1:] + return input_ids[input_ids.index(self.end_token_id) + 1:] def extract_reasoning_content_streaming( self, @@ -75,14 +75,14 @@ def extract_reasoning_content_streaming( """ # Skip single special tokens if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ - self.think_start_token_id, self.think_end_token_id + self.start_token_id, self.end_token_id ]): return None # Check if is present in previous or delta. # Keep compatibility with models that don't generate tokens. - if self.think_start_token_id in previous_token_ids: - if self.think_end_token_id in delta_token_ids: + if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: # in previous, in delta, # extract reasoning content end_index = delta_text.find(self.end_token) @@ -92,7 +92,7 @@ def extract_reasoning_content_streaming( reasoning_content=reasoning_content, content=content if content else None, ) - elif self.think_end_token_id in previous_token_ids: + elif self.end_token_id in previous_token_ids: # in previous, in previous, # reasoning content continues return DeltaMessage(content=delta_text) @@ -100,8 +100,8 @@ def extract_reasoning_content_streaming( # in previous, no in previous or delta, # reasoning content continues return DeltaMessage(reasoning_content=delta_text) - elif self.think_start_token_id in delta_token_ids: - if self.think_end_token_id in delta_token_ids: + elif self.start_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content start_index = delta_text.find(self.start_token) end_index = delta_text.find(self.end_token) @@ -120,7 +120,7 @@ def extract_reasoning_content_streaming( # No in previous or delta, also need to check for . 
# Because the model may have generated without # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f - if self.think_end_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: # in delta with more tokens, # extract reasoning content and content end_index = delta_text.find(self.end_token) @@ -130,7 +130,7 @@ def extract_reasoning_content_streaming( reasoning_content=reasoning_content, content=content if content else None, ) - elif self.think_end_token_id in previous_token_ids: + elif self.end_token_id in previous_token_ids: # in previous, thinking content ends return DeltaMessage(content=delta_text) else: From 6b424adbb1fff5743371b17b03b25ffe4bd06ce5 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 16 Jul 2025 06:33:40 +0000 Subject: [PATCH 7/8] support think start/end as token sequences Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config.py | 20 +++-- vllm/engine/arg_utils.py | 14 +++- vllm/v1/sample/logits_processor.py | 122 ++++++++++++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 16 ++-- 4 files changed, 117 insertions(+), 55 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 160b7fe5d15..d11d1034238 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4408,16 +4408,24 @@ def set_splitting_ops_for_v1(self): class ReasoningConfig: """Configuration for reasoning models.""" - think_start_token_id: Optional[int] = None + think_start_str: Optional[str] = None + """String that indicates the start of reasoning.""" + think_end_str: Optional[str] = None + """String that indicates the end of reasoning.""" + think_start_token_ids: Optional[int] = None """Token ID that indicates the start of reasoning.""" - think_end_token_id: Optional[int] = None + think_end_token_ids: Optional[int] = None """Token ID that indicates the end of reasoning.""" def __init__(self, - think_start_token_id: Optional[int] = None, - think_end_token_id: Optional[int] = None): - self.think_start_token_id = think_start_token_id - self.think_end_token_id = think_end_token_id + think_start_str: Optional[str] = None, + think_end_str: Optional[str] = None, + think_start_token_ids: Optional[int] = None, + think_end_token_ids: Optional[int] = None): + self.think_start_str = think_start_str + self.think_end_str = think_end_str + self.think_start_token_ids = think_start_token_ids + self.think_end_token_ids = think_end_token_ids @config diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f47499309d8..65133c300e4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -31,9 +31,9 @@ ModelConfig, ModelDType, ModelImpl, MultiModalConfig, ObservabilityConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, PromptAdapterConfig, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, TokenizerPoolConfig, - VllmConfig, get_attr_docs, get_field) + ReasoningConfig, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, TaskOption, TokenizerMode, + TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field) from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins @@ -464,6 +464,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None + reasoning_config: Optional[ReasoningConfig] = None + generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool 
= ModelConfig.enable_sleep_mode override_generation_config: dict[str, Any] = \ @@ -934,6 +936,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **vllm_kwargs["kv_events_config"]) vllm_group.add_argument("--compilation-config", "-O", **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--reasoning-config", + **vllm_kwargs["reasoning_config"]) vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) @@ -1332,6 +1336,9 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + reasoning_config_dict = json.loads(self.reasoning_config) + reasoning_config = ReasoningConfig(**reasoning_config_dict) + config = VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1347,6 +1354,7 @@ def create_engine_config( compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, + reasoning_config=reasoning_config, additional_config=self.additional_config, ) diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 6f4f2bcf0f2..b846de0481e 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -514,12 +514,84 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, self.device = device self._state: dict[int, dict[str, Any]] = {} - def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: + def _find_first_token_index(self, target_list: list[int], token_id: int) -> int: + """ + Find the last occurrence of a single token in the list of tokens. + + Args: + target_list (list[int]): The list of token IDs. + token_id (int): The token ID to find. + """ try: - return len(tokens) - tokens[::-1].index(token_id) - 1 + return len(target_list) - target_list[::-1].index(token_id) - 1 except ValueError: return -1 + def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int: + """ + Find the last occurrence of the sequence of token_ids in tokens. + + Args: + target_list (list[int]): The list of token IDs. + token_ids (list[int]): The sequence of token IDs to find. 
+ """ + index = self._find_first_token_index(target_list, token_ids[0]) + if index != -1: + i = 1 + for token_id in token_ids[1:]: + if index + i >= len(target_list) or target_list[index + i] != token_id: + return -1 + i += 1 + index += 1 + + return index + + def _init_state_entry(self, prompt_tok_ids, max_think_tokens): + last_start = self._find_last_sequence_index( + prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_sequence_index( + prompt_tok_ids, self.think_end_token_id) + in_think = last_start > last_end + think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + + return { + "in_think": in_think, + "in_end": False, + "think_count": think_count, + "end_count": 0, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": [], + "max_think_tokens": max_think_tokens, + } + + def _update_think_state(self, state): + output = state["output_tok_ids"] + if not output: + return + + sliced_output1 = output[-1 + len(self.think_start_token_id):] + sliced_output2 = output[-1 + len(self.think_end_token_id):] + + if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1: + state["in_think"] = True + state["think_count"] = 0 + elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1: + state["in_think"] = False + state["think_count"] = 0 + else: + state["think_count"] += 1 + + if state["in_end"]: + state["end_count"] += 1 + if state["end_count"] >= len(self.think_end_token_id): + state["in_end"] = False + state["end_count"] = 0 + else: + if state["in_think"] and state["think_count"] >= state["max_think_tokens"]: + state["in_think"] = False + state["in_end"] = True + state["end_count"] = 0 + def is_argmax_invariant(self) -> bool: """This logits processor can change the outcome of greedy sampling by forcing that the thinking section @@ -528,52 +600,25 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: - for (index, params, prompt_tok_ids, - output_tok_ids) in batch_update.added: + for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: max_think_tokens = (params.max_think_tokens if isinstance( params, SamplingParams) else None) if max_think_tokens is not None: - last_start = self._find_last_token_index( - prompt_tok_ids, self.think_start_token_id) - last_end = self._find_last_token_index( - prompt_tok_ids, self.think_end_token_id) - in_think = last_start > last_end - count = len(prompt_tok_ids) - (last_start + - 1) if in_think else 0 - - self._state[index] = { - "in_think": in_think, - "count": count, - "prompt_tok_ids": prompt_tok_ids, - "output_tok_ids": output_tok_ids, - "max_think_tokens": max_think_tokens, - } + self._state[index] = self._init_state_entry( + prompt_tok_ids, max_think_tokens) + self._state[index]["output_tok_ids"] = output_tok_ids for index in batch_update.removed: self._state.pop(index, {}) for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[ - i2], self._state[i1] + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] else: self._state[i2] = self._state.pop(i1, {}) - # Update in_think and count for all active requests for state in self._state.values(): - output = state["output_tok_ids"] - if not output: - continue - - last_tok = output[-1] - if last_tok == self.think_start_token_id: - state["in_think"] = True - state["count"] = 0 - elif last_tok == self.think_end_token_id: - state["in_think"] = False - 
state["count"] = 0 - elif state["in_think"]: - state["count"] += 1 + self._update_think_state(state) def apply(self, logits: torch.Tensor) -> torch.Tensor: batch_size = logits.size(0) @@ -588,13 +633,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: if not state: continue - if state["in_think"] and state["count"] >= state[ - "max_think_tokens"]: + force_end_token_id = None + if state["in_end"]: + force_end_token_id = self.think_end_token_id[state["end_count"]] mask[index] = True if mask.any(): logits[mask] = -float("inf") - logits[mask, end_token_id] = 0.0 + logits[mask, force_end_token_id] = 0.0 return logits diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 61c9fb57149..ad48cc5151d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -114,14 +114,14 @@ def __init__( scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) - reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) - reasoning_parser = reasoner_cls(tokenizer=tokenizer) - self.vllm_config.reasoning_config = ReasoningConfig( - think_start_token_id=reasoning_parser.think_start_token_id, - think_end_token_id=reasoning_parser.think_end_token_id) + reasoning_config = self.vllm_config.reasoning_config + if reasoning_config is not None: + reasoning_config.think_start_token_id = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_start_str)) + reasoning_config.think_end_token_id = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_end_str)) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 2e6cff6fe6e3611810a8e6f07029fb22026b7908 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:47:22 +0000 Subject: [PATCH 8/8] refactor and change logic faster Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor.py | 90 ++++++++++++------------------ vllm/v1/worker/gpu_model_runner.py | 4 +- 2 files changed, 39 insertions(+), 55 deletions(-) diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index b846de0481e..e6c429bb8d2 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -496,7 +496,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MaxThinkTokensLogitsProcessor(LogitsProcessor): - """A logits processor that limits the maximum number of thinking tokens.""" + """Limits the number of tokens allowed inside a 'thinking' section.""" def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): @@ -508,82 +508,68 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device (torch.device): Device to use for tensor operations. 
""" super().__init__() - self.think_start_token_id = reasoning_config.think_start_token_id - self.think_end_token_id = reasoning_config.think_end_token_id + self.think_start_token_ids = reasoning_config.think_start_token_ids + self.think_end_token_ids = reasoning_config.think_end_token_ids self.pin_memory = pin_memory self.device = device self._state: dict[int, dict[str, Any]] = {} - def _find_first_token_index(self, target_list: list[int], token_id: int) -> int: + @staticmethod + def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: """ - Find the last occurrence of a single token in the list of tokens. + Returns the index of the last occurrence of token_ids in target_list. Args: target_list (list[int]): The list of token IDs. - token_id (int): The token ID to find. + token_ids (list[int]): The sequence of token IDs to find. """ - try: - return len(target_list) - target_list[::-1].index(token_id) - 1 - except ValueError: + if not token_ids: return -1 - def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int: - """ - Find the last occurrence of the sequence of token_ids in tokens. + for i in range(len(target_list) - len(token_ids), -1, -1): + if target_list[i:i + len(token_ids)] == token_ids: + return i + return -1 - Args: - target_list (list[int]): The list of token IDs. - token_ids (list[int]): The sequence of token IDs to find. - """ - index = self._find_first_token_index(target_list, token_ids[0]) - if index != -1: - i = 1 - for token_id in token_ids[1:]: - if index + i >= len(target_list) or target_list[index + i] != token_id: - return -1 - i += 1 - index += 1 - - return index - - def _init_state_entry(self, prompt_tok_ids, max_think_tokens): + def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> dict[str, Any]: + """Initializes the tracking state for a given sequence index.""" last_start = self._find_last_sequence_index( - prompt_tok_ids, self.think_start_token_id) + prompt_tok_ids, self.think_start_token_ids) last_end = self._find_last_sequence_index( - prompt_tok_ids, self.think_end_token_id) + prompt_tok_ids, self.think_end_token_ids) in_think = last_start > last_end think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 return { - "in_think": in_think, - "in_end": False, - "think_count": think_count, - "end_count": 0, + "in_think": in_think, # Currently in thinking mode + "in_end": False, # Currently forcing end tokens + "think_count": think_count, # Number of tokens in thinking section + "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], "max_think_tokens": max_think_tokens, } - def _update_think_state(self, state): + def _update_think_state(self, state: dict[str, Any]): + """Updates the state based on generated output tokens.""" output = state["output_tok_ids"] if not output: return - sliced_output1 = output[-1 + len(self.think_start_token_id):] - sliced_output2 = output[-1 + len(self.think_end_token_id):] - - if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1: + # Check if recent output matches start or end sequences + if output[-len(self.think_start_token_ids):] == self.think_start_token_ids: state["in_think"] = True state["think_count"] = 0 - elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1: + elif output[-len(self.think_end_token_ids):] == self.think_end_token_ids: state["in_think"] = False state["think_count"] = 0 - else: + elif 
state["in_think"]: state["think_count"] += 1 + # Transition into end mode if thinking token limit exceeded if state["in_end"]: state["end_count"] += 1 - if state["end_count"] >= len(self.think_end_token_id): + if state["end_count"] >= len(self.think_end_token_ids): state["in_end"] = False state["end_count"] = 0 else: @@ -626,21 +612,19 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - end_token_id = self.think_end_token_id - - for index in range(batch_size): - state = self._state.get(index) - if not state: - continue + force_token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=logits.device) - force_end_token_id = None - if state["in_end"]: - force_end_token_id = self.think_end_token_id[state["end_count"]] - mask[index] = True + for i in range(batch_size): + state = self._state.get(i) + if state and state["in_end"]: + mask[i] = True + force_token_ids[i] = self.think_end_token_ids[state["end_count"]] if mask.any(): logits[mask] = -float("inf") - logits[mask, force_end_token_id] = 0.0 + row_indices = torch.arange(batch_size, device=logits.device)[mask] + col_indices = force_token_ids[mask] + logits[row_indices, col_indices] = 0.0 return logits diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ad48cc5151d..1d16abed80a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -116,10 +116,10 @@ def __init__( ).get_lora_tokenizer(None) reasoning_config = self.vllm_config.reasoning_config if reasoning_config is not None: - reasoning_config.think_start_token_id = \ + reasoning_config.think_start_token_ids = \ tokenizer.convert_tokens_to_ids( tokenizer.tokenize(reasoning_config.think_start_str)) - reasoning_config.think_end_token_id = \ + reasoning_config.think_end_token_ids = \ tokenizer.convert_tokens_to_ids( tokenizer.tokenize(reasoning_config.think_end_str))