diff --git a/vllm/config.py b/vllm/config.py
index d9f356c5c60..d11d1034238 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4405,6 +4405,29 @@ def set_splitting_ops_for_v1(self):
         ]
 
 
+class ReasoningConfig:
+    """Configuration for reasoning models."""
+
+    think_start_str: Optional[str] = None
+    """String that indicates the start of reasoning."""
+    think_end_str: Optional[str] = None
+    """String that indicates the end of reasoning."""
+    think_start_token_ids: Optional[list[int]] = None
+    """Token IDs that indicate the start of reasoning."""
+    think_end_token_ids: Optional[list[int]] = None
+    """Token IDs that indicate the end of reasoning."""
+
+    def __init__(self,
+                 think_start_str: Optional[str] = None,
+                 think_end_str: Optional[str] = None,
+                 think_start_token_ids: Optional[list[int]] = None,
+                 think_end_token_ids: Optional[list[int]] = None):
+        self.think_start_str = think_start_str
+        self.think_end_str = think_end_str
+        self.think_start_token_ids = think_start_token_ids
+        self.think_end_token_ids = think_end_token_ids
+
+
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
@@ -4461,6 +4484,8 @@ class VllmConfig:
     # some opaque config, only used to provide additional information
     # for the hash computation, mainly used for testing, debugging or out of
     # tree config registration.
+    reasoning_config: Optional[ReasoningConfig] = None
+    """The configuration for reasoning models."""
     additional_config: Union[dict, SupportsHash] = field(default_factory=dict)
     """Additional config for specified platform. Different platforms may
     support different configs. Make sure the configs are valid for the platform
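A quick illustration of the new config object: for a DeepSeek-R1-style chat template the marker strings below are plausible values, but they are assumptions for this sketch, not defaults shipped by the diff. The token-ID fields are normally left unset here; the gpu_model_runner.py hunk further down derives them from the tokenizer.

# Sketch only: "<think>"/"</think>" are assumed marker strings for the example.
from vllm.config import ReasoningConfig

reasoning_config = ReasoningConfig(
    think_start_str="<think>",
    think_end_str="</think>",
)
# think_start_token_ids / think_end_token_ids stay None at this point; they
# are populated later from the tokenizer by the model runner.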
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f47499309d8..65133c300e4 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -31,9 +31,9 @@
                          ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
                          PrefixCachingHashAlgo, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, TokenizerPoolConfig,
-                         VllmConfig, get_attr_docs, get_field)
+                         ReasoningConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -464,6 +464,8 @@ class EngineArgs:
     kv_transfer_config: Optional[KVTransferConfig] = None
     kv_events_config: Optional[KVEventsConfig] = None
 
+    reasoning_config: Optional[ReasoningConfig] = None
+
     generation_config: str = ModelConfig.generation_config
     enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
     override_generation_config: dict[str, Any] = \
@@ -934,6 +936,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 **vllm_kwargs["kv_events_config"])
         vllm_group.add_argument("--compilation-config", "-O",
                                 **vllm_kwargs["compilation_config"])
+        vllm_group.add_argument("--reasoning-config",
+                                **vllm_kwargs["reasoning_config"])
         vllm_group.add_argument("--additional-config",
                                 **vllm_kwargs["additional_config"])
 
@@ -1332,6 +1336,9 @@ def create_engine_config(
             collect_detailed_traces=self.collect_detailed_traces,
         )
 
+        reasoning_config = None if self.reasoning_config is None else \
+            ReasoningConfig(**json.loads(self.reasoning_config))
+
         config = VllmConfig(
             model_config=model_config,
             cache_config=cache_config,
@@ -1347,6 +1354,7 @@ def create_engine_config(
             compilation_config=self.compilation_config,
             kv_transfer_config=self.kv_transfer_config,
             kv_events_config=self.kv_events_config,
+            reasoning_config=reasoning_config,
             additional_config=self.additional_config,
         )
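The new --reasoning-config flag carries a JSON object whose keys mirror ReasoningConfig.__init__, and create_engine_config above simply JSON-decodes it. A minimal sketch of that parsing path, with an assumed example value:

# Sketch only: the JSON string is an assumed example value for --reasoning-config.
import json

from vllm.config import ReasoningConfig

raw = '{"think_start_str": "<think>", "think_end_str": "</think>"}'
reasoning_config = ReasoningConfig(**json.loads(raw))
print(reasoning_config.think_start_str)  # <think>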
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fdac6ccd19e..e0c7bbf21db 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -404,6 +404,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
+    max_think_tokens: Optional[int] = None
     # --8<-- [end:chat-completion-sampling-params]
     # --8<-- [start:chat-completion-extra-params]
 
@@ -670,6 +671,7 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
             bad_words= self.bad_words,
+            max_think_tokens=self.max_think_tokens,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
         )
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index a9a862384d1..49cdbdde45c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -200,6 +200,7 @@ class SamplingParams(
         extra_args: Arbitrary additional args, that can be used by custom
             sampling implementations, plugins, etc. Not used by any in-tree
             sampling implementations.
+        max_think_tokens: Maximum number of tokens allowed for thinking
     """
 
     n: int = 1
@@ -248,6 +249,9 @@ class SamplingParams(
     bad_words: Optional[list[str]] = None
     _bad_words_token_ids: Optional[list[list[int]]] = None
 
+    # Maximum number of tokens allowed for thinking operations.
+    max_think_tokens: Optional[int] = None
+
    @staticmethod
    def from_optional(
        n: Optional[int] = 1,
@@ -263,6 +267,7 @@ def from_optional(
         stop: Optional[Union[str, list[str]]] = None,
         stop_token_ids: Optional[list[int]] = None,
         bad_words: Optional[list[str]] = None,
+        max_think_tokens: Optional[int] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
         max_tokens: Optional[int] = 16,
@@ -306,6 +311,7 @@ def from_optional(
             stop=stop,
             stop_token_ids=stop_token_ids,
             bad_words=bad_words,
+            max_think_tokens=max_think_tokens,
             include_stop_str_in_output=include_stop_str_in_output,
             ignore_eos=ignore_eos,
             max_tokens=max_tokens,
@@ -574,6 +580,7 @@ def __repr__(self) -> str:
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
             f"bad_words={self.bad_words}, "
+            f"max_think_tokens={self.max_think_tokens}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "
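On the request path, max_think_tokens is forwarded untouched from ChatCompletionRequest.to_sampling_params into SamplingParams. A small sketch of the equivalent server-side construction; the numeric budgets are arbitrary example values:

# Sketch only: 1024/256 are arbitrary example budgets.
from vllm import SamplingParams

params = SamplingParams.from_optional(
    max_tokens=1024,
    max_think_tokens=256,  # cap on tokens spent inside the thinking section
)
print(params.max_think_tokens)  # 256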
+ """ + if not token_ids: + return -1 + + for i in range(len(target_list) - len(token_ids), -1, -1): + if target_list[i:i + len(token_ids)] == token_ids: + return i + return -1 + + def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> dict[str, Any]: + """Initializes the tracking state for a given sequence index.""" + last_start = self._find_last_sequence_index( + prompt_tok_ids, self.think_start_token_ids) + last_end = self._find_last_sequence_index( + prompt_tok_ids, self.think_end_token_ids) + in_think = last_start > last_end + think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + + return { + "in_think": in_think, # Currently in thinking mode + "in_end": False, # Currently forcing end tokens + "think_count": think_count, # Number of tokens in thinking section + "end_count": 0, # Number of end tokens forced so far + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": [], + "max_think_tokens": max_think_tokens, + } + + def _update_think_state(self, state: dict[str, Any]): + """Updates the state based on generated output tokens.""" + output = state["output_tok_ids"] + if not output: + return + + # Check if recent output matches start or end sequences + if output[-len(self.think_start_token_ids):] == self.think_start_token_ids: + state["in_think"] = True + state["think_count"] = 0 + elif output[-len(self.think_end_token_ids):] == self.think_end_token_ids: + state["in_think"] = False + state["think_count"] = 0 + elif state["in_think"]: + state["think_count"] += 1 + + # Transition into end mode if thinking token limit exceeded + if state["in_end"]: + state["end_count"] += 1 + if state["end_count"] >= len(self.think_end_token_ids): + state["in_end"] = False + state["end_count"] = 0 + else: + if state["in_think"] and state["think_count"] >= state["max_think_tokens"]: + state["in_think"] = False + state["in_end"] = True + state["end_count"] = 0 + + def is_argmax_invariant(self) -> bool: + """This logits processor can change the outcome of + greedy sampling by forcing that the thinking section + ends after a certain number of tokens.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if batch_update: + for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: + max_think_tokens = (params.max_think_tokens if isinstance( + params, SamplingParams) else None) + if max_think_tokens is not None: + self._state[index] = self._init_state_entry( + prompt_tok_ids, max_think_tokens) + self._state[index]["output_tok_ids"] = output_tok_ids + + for index in batch_update.removed: + self._state.pop(index, {}) + + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + else: + self._state[i2] = self._state.pop(i1, {}) + + for state in self._state.values(): + self._update_think_state(state) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + batch_size = logits.size(0) + if not self._state: + return logits + + mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) + force_token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=logits.device) + + for i in range(batch_size): + state = self._state.get(i) + if state and state["in_end"]: + mask[i] = True + force_token_ids[i] = self.think_end_token_ids[state["end_count"]] + + if mask.any(): + logits[mask] = -float("inf") + row_indices = torch.arange(batch_size, device=logits.device)[mask] + col_indices = force_token_ids[mask] + 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 1a79d72be0a..58b59181f3b 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 
+from vllm.config import ReasoningConfig
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.pooling_params import PoolingParams
@@ -71,6 +72,7 @@ def __init__(
         block_sizes: list[int],  # The block_size of each kv cache group
         is_spec_decode: bool = False,
         logits_processing_needs_token_ids: bool = False,
+        reasoning_config: Optional[ReasoningConfig] = None,
     ):
         self.is_spec_decode = is_spec_decode
         self.max_num_reqs = max_num_reqs
@@ -221,7 +223,8 @@ def __init__(
         self.logitsprocs = init_builtin_logitsprocs(
             pin_memory_available=pin_memory,
             max_num_reqs=max_num_reqs + 1,
-            device=device)
+            device=device,
+            reasoning_config=reasoning_config)
 
         # TODO convert this to LogitsProcessor
         self.has_allowed_token_ids: set[str] = set()
@@ -260,7 +263,8 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
         params = (request.sampling_params
                   if request.sampling_params else request.pooling_params)
         self.batch_update_builder.added.append(
-            (req_index, params, request.output_token_ids))
+            (req_index, params, request.prompt_token_ids,
+             request.output_token_ids))
         return req_index
 
     def add_request(
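The batch-update entry registered above is now a 4-tuple that carries the prompt token IDs next to the live output token list; the think-limit processor seeds its state from the former and watches the latter grow. A pure-Python sketch of that reference-sharing contract, with illustrative token IDs:

# Sketch only: token IDs are illustrative; index 0 stands in for req_index.
from vllm import SamplingParams

prompt_token_ids = [1, 2, 3]
output_token_ids: list[int] = []
params = SamplingParams(max_think_tokens=4)

# Shape of an AddedRequest entry after this change:
added = (0, params, prompt_token_ids, output_token_ids)

# output_token_ids is shared by reference, so a logits processor holding
# added[3] sees each new token as soon as the runner appends it.
output_token_ids.append(42)
assert added[3] == [42]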
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4551cb2df98..1d16abed80a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -18,7 +18,7 @@
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
 from vllm.compilation.counter import compilation_counter
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, ReasoningConfig, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -39,8 +39,10 @@
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
+from vllm.reasoning import ReasoningParserManager
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                         check_use_alibi, get_dtype_size,
@@ -105,6 +107,22 @@ def __init__(
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
 
+        if self.vllm_config.decoding_config.reasoning_backend in (
+                'deepseek_r1', 'qwen'):
+            tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_config = self.vllm_config.reasoning_config
+            if reasoning_config is not None:
+                reasoning_config.think_start_token_ids = \
+                    tokenizer.convert_tokens_to_ids(
+                        tokenizer.tokenize(reasoning_config.think_start_str))
+                reasoning_config.think_end_token_ids = \
+                    tokenizer.convert_tokens_to_ids(
+                        tokenizer.tokenize(reasoning_config.think_end_str))
+
         from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -212,6 +230,7 @@ def __init__(
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.cache_config.block_size],
             is_spec_decode=bool(self.vllm_config.speculative_config),
+            reasoning_config=self.vllm_config.reasoning_config,
         )
 
         self.use_cuda_graph = (
@@ -2384,6 +2403,7 @@ def may_reinitialize_input_batch(self,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=block_sizes,
             is_spec_decode=bool(self.vllm_config.speculative_config),
+            reasoning_config=self.vllm_config.reasoning_config,
         )
 
     def _allocate_kv_cache_tensors(
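Putting the pieces together, a hedged end-to-end sketch: the server is launched with a reasoning parser plus the new --reasoning-config JSON, and a client caps thinking per request through the new max_think_tokens field. The model name, URL, marker strings, and the pairing with --reasoning-parser are assumptions for the example, not requirements stated by this diff.

# Sketch only. Assumed server launch, shown as a comment:
#   vllm serve deepseek-ai/DeepSeek-R1 \
#       --reasoning-parser deepseek_r1 \
#       --reasoning-config '{"think_start_str": "<think>", "think_end_str": "</think>"}'
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "How many primes are below 100?"}],
    extra_body={"max_think_tokens": 256},  # new per-request cap from this diff
)
print(resp.choices[0].message.content)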