diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 3b5281962b2..1fc92e6887b 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -272,7 +272,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
-    bad_words: list[str] = Field(default_factory=list)
+    bad_words: list[str] = Field(default_factory=list)
+    thinking_budget: Optional[int] = -1
     # --8<-- [end:chat-completion-sampling-params]
 
     # --8<-- [start:chat-completion-extra-params]
@@ -563,6 +564,7 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
             bad_words=self.bad_words,
+            thinking_budget=self.thinking_budget,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
         )
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index e827d381ca1..394d112894b 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -35,6 +35,16 @@ def vocab(self) -> dict[str, int]:
         # whereas all tokenizers have .get_vocab()
         return self.model_tokenizer.get_vocab()
 
+    @property
+    def special_token_ids(self) -> dict:
+        """
+        Returns a dictionary of special token IDs for model-agnostic access.
+        Example: {"start_token_id": int, "end_token_id": int}
+        """
+        # Default: not implemented.
+        # Override in subclasses that expose reasoning start/end tokens.
+        return {}
+
     @abstractmethod
     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
@@ -121,6 +131,16 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
         raise KeyError(
             f"reasoning helper: '{name}' not found in reasoning_parsers")
 
+    @staticmethod
+    def get_special_token_ids(parser) -> dict:
+        """
+        Get special token IDs from a parser instance, if available.
+        Returns an empty dict if not present.
+        """
+        if parser is not None and hasattr(parser, "special_token_ids"):
+            return parser.special_token_ids
+        return {}
+
     @classmethod
     def _register_module(
         cls,
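As a quick illustration of the new hook (not part of the patch): a parser-like object that exposes `special_token_ids`, and how the new `ReasoningParserManager.get_special_token_ids` helper reads it. `MyThinkParser` and the token ids below are hypothetical, and the object is duck-typed rather than a real `ReasoningParser` subclass.

```python
# Illustrative sketch only; assumes this patch is applied.
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager


class MyThinkParser:
    """Hypothetical, duck-typed stand-in for a ReasoningParser subclass."""

    @property
    def special_token_ids(self) -> dict:
        # In a real parser these would come from the tokenizer
        # (e.g. the ids of <think> and </think>).
        return {"start_token_id": 128013, "end_token_id": 128014}


print(ReasoningParserManager.get_special_token_ids(MyThinkParser()))
# {'start_token_id': 128013, 'end_token_id': 128014}
print(ReasoningParserManager.get_special_token_ids(None))
# {} -- the helper degrades gracefully when no parser is configured
```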
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f..8e2c772ab29 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -44,6 +44,17 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
                 "DeepSeek R1 reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")
 
+    @property
+    def special_token_ids(self) -> dict:
+        """
+        Returns a dictionary of special token IDs for model-agnostic access.
+        Example: {"start_token_id": int, "end_token_id": int}
+        """
+        return {
+            "start_token_id": self.start_token_id,
+            "end_token_id": self.end_token_id,
+        }
+
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
 
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index a9a862384d1..af8cb50fa1c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -43,7 +43,8 @@ class GuidedDecodingParams:
     disable_any_whitespace: bool = False
     disable_additional_properties: bool = False
     whitespace_pattern: Optional[str] = None
-    structural_tag: Optional[str] = None
+    structural_tag: Optional[str] = None
+    thinking_budget: Optional[int] = -1
 
     @staticmethod
     def from_optional(
@@ -55,6 +56,7 @@ def from_optional(
         backend: Optional[str] = None,
         whitespace_pattern: Optional[str] = None,
         structural_tag: Optional[str] = None,
+        thinking_budget: Optional[int] = -1,
     ) -> Optional["GuidedDecodingParams"]:
         if all(arg is None for arg in (json, regex, choice, grammar,
                                        json_object, structural_tag)):
@@ -71,6 +73,7 @@ def from_optional(
             backend=backend,
             whitespace_pattern=whitespace_pattern,
             structural_tag=structural_tag,
+            thinking_budget=thinking_budget,
         )
 
     def __post_init__(self):
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index efc5b3012ec..49cebe1536c 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -151,5 +151,11 @@ class SchedulerOutput:
     # the bitmask for the whole batch
     grammar_bitmask: Optional[npt.NDArray[np.int32]]
 
+    # Per-request remaining thinking-token budget for this scheduling step.
+    requests_with_remaining_budget: dict[str, int]
+
+    # Token id that closes the thinking section (e.g. </think>).
+    end_thinking_token_id: int
+
     # KV Cache Connector metadata.
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
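End to end, the new knob is meant to travel from the chat request into the sampling parameters via to_sampling_params. Below is a minimal client-side sketch of how it might be exercised, assuming an OpenAI-compatible vLLM server running with this patch; the model name and budget value are placeholders, and -1 (the default) leaves thinking uncapped.

```python
# Illustrative only: cap the model's thinking at roughly 256 tokens.
# Assumes a vLLM OpenAI-compatible server with this patch applied;
# "deepseek-r1" is a placeholder model name.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="deepseek-r1",
    messages=[{"role": "user", "content": "How many primes are below 100?"}],
    extra_body={"thinking_budget": 256},  # -1 (default) = no cap on thinking
)
print(response.choices[0].message.content)
```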
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 20a40d74f31..e5b6e16b282 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -48,6 +48,7 @@ def __init__(
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
+        special_token_ids: Optional[dict] = None,
     ) -> None:
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
@@ -57,7 +58,9 @@ def __init__(
         self.kv_events_config = vllm_config.kv_events_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
-
+        self.special_token_ids = special_token_ids or {}
+        self.start_thinking_token_id = self.special_token_ids.get("start_token_id")
+        self.end_thinking_token_id = self.special_token_ids.get("end_token_id")
         # include_finished_set controls whether a separate set of finished
         # request ids should be included in the EngineCoreOutputs returned
         # by update_from_outputs(). This is currently used in the multi-engine
@@ -160,6 +163,17 @@ def __init__(
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
 
+    def get_current_usage(self, output_tokens, start_thinking_token_id):
+        try:
+            start_thinking_token_index = output_tokens.index(
+                start_thinking_token_id)
+            current_usage = len(output_tokens) - start_thinking_token_index - 1
+            return current_usage
+        except ValueError:
+            # The start-thinking token has not been generated yet.
+            return None
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -536,6 +550,25 @@ def schedule(self) -> SchedulerOutput:
             structured_output_request_ids,
             scheduled_spec_decode_tokens,
         )
+
+        # Requests that are still inside their thinking budget; the model
+        # runner may need to force the end-thinking token for them in this
+        # step (see SchedulerOutput.requests_with_remaining_budget).
+        requests_with_remaining_budget: dict[str, int] = {}
+        for request_id, request in self.requests.items():
+            thinking_budget = request.sampling_params.thinking_budget
+            if thinking_budget is not None and thinking_budget >= 0:
+                thinking_budget_used = self.get_current_usage(
+                    request.output_token_ids, self.start_thinking_token_id)
+                if (thinking_budget_used is not None
+                        and thinking_budget_used <= thinking_budget
+                        and self.end_thinking_token_id
+                        not in request.output_token_ids
+                        and len(request.output_token_ids) > 0):
+                    requests_with_remaining_budget[request_id] = (
+                        thinking_budget - thinking_budget_used)
+
+
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 453ed364dc8..840a199506d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -105,6 +105,10 @@ def __init__(self,
                 "compatibility may not be maintained.",
                 vllm_config.scheduler_config.scheduler_cls)
 
+        from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+        parser = getattr(self.model_executor, "reasoning_parser", None)
+        special_token_ids = ReasoningParserManager.get_special_token_ids(parser)
+
         self.scheduler: SchedulerInterface = Scheduler(
             vllm_config=vllm_config,
             kv_cache_config=kv_cache_config,
@@ -112,6 +116,7 @@ def __init__(self,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,
             log_stats=self.log_stats,
+            special_token_ids=special_token_ids,
         )
 
         # Setup MM Input Mapper.
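To make the scheduler's bookkeeping concrete, here is a worked example of the budget arithmetic performed by get_current_usage and the loop above. It is a standalone sketch with made-up token ids, and it assumes thinking_budget is reachable on the request's sampling params exactly as the hunk expects.

```python
# Standalone sketch of the budget arithmetic (made-up token ids).
def get_current_usage(output_tokens, start_thinking_token_id):
    try:
        start = output_tokens.index(start_thinking_token_id)
        return len(output_tokens) - start - 1
    except ValueError:
        return None  # thinking has not started yet


START_THINK, END_THINK = 100, 101             # e.g. ids of <think> / </think>
output_token_ids = [7, START_THINK, 5, 9, 3]  # 3 thinking tokens emitted so far
thinking_budget = 8

used = get_current_usage(output_token_ids, START_THINK)  # -> 3
remaining = thinking_budget - used                        # -> 5
assert (used, remaining) == (3, 5)
# The scheduler would record {request_id: 5} in
# SchedulerOutput.requests_with_remaining_budget for this request.
```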
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 29d39de212f..b97ac94df28 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1153,6 +1153,33 @@ def apply_grammar_bitmask(
             indices=out_indices,
         )
 
+    def _force_thinking(self, scheduler_output: "SchedulerOutput",
+                        valid_sampled_token_ids: list[list[int]]):
+        '''
+        Force the end of thinking for reasoning models.
+
+        For each request listed in requests_with_remaining_budget, trim the
+        newly sampled tokens so that the remaining thinking budget is not
+        exceeded, then append end_thinking_token_id.
+
+        Example: if the remaining budget is 2 and the valid tokens for the
+        request are [1, 2, 3, 4, 5], the last three tokens are removed and
+        the end-thinking token is appended, giving
+        [1, 2, end_thinking_token_id].
+        '''
+        if scheduler_output.requests_with_remaining_budget:
+            budgets = scheduler_output.requests_with_remaining_budget
+            for req_id, remaining_budget in budgets.items():
+                req_index = self.input_batch.req_id_to_index[req_id]
+                sampled_tokens = valid_sampled_token_ids[req_index]
+                if len(sampled_tokens) > remaining_budget:
+                    clear_indices = len(sampled_tokens) - remaining_budget
+                    if clear_indices > 0:
+                        del sampled_tokens[-clear_indices:]
+                    sampled_tokens.append(
+                        scheduler_output.end_thinking_token_id)
+
+
     def sync_and_slice_intermediate_tensors(
         self, num_tokens: int, intermediate_tensors: IntermediateTensors,
         sync_self: bool) -> IntermediateTensors:
@@ -1509,6 +1536,10 @@ def execute_model(
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()
 
+        self._force_thinking(
+            scheduler_output,
+            valid_sampled_token_ids)
+
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
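The list surgery above is easiest to see on plain Python lists, which is what valid_sampled_token_ids holds at this point in execute_model. A toy re-enactment with made-up token ids (not code from the patch):

```python
# Toy re-enactment of the truncation in _force_thinking (made-up token ids).
END_THINK = 101
remaining_budget = 2

# Tokens sampled for one request in this step.
sampled_tokens = [1, 2, 3, 4, 5]

if len(sampled_tokens) > remaining_budget:
    clear_indices = len(sampled_tokens) - remaining_budget  # 3
    if clear_indices > 0:
        del sampled_tokens[-clear_indices:]                 # -> [1, 2]
    sampled_tokens.append(END_THINK)                        # -> [1, 2, 101]

assert sampled_tokens == [1, 2, END_THINK]
```

Because the end-thinking token replaces trimmed tokens rather than being added on top, a request ends the step with at most remaining_budget + 1 tokens, never more than were originally sampled.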