From c183156e930fb49adf47e7245ad04f3937937baa Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Mon, 14 Jul 2025 22:19:56 +0000
Subject: [PATCH 1/6] Add thinking budget support for speculative and non-speculative decode methods

---
 vllm/v1/core/sched/output.py | 6 +++++
 vllm/v1/core/sched/scheduler.py | 39 +++++++++++++++++++++++++++++-
 vllm/v1/engine/core.py | 7 ++++++
 vllm/v1/worker/gpu_model_runner.py | 30 +++++++++++++++++++++++
 4 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index efc5b3012ec..49cebe1536c 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -151,5 +151,11 @@ class SchedulerOutput:
     # the bitmask for the whole batch
     grammar_bitmask: Optional[npt.NDArray[np.int32]]
 
+    # Remaining thinking budget for each scheduled request, keyed by request id.
+    requests_with_remaining_budget: dict[str, int]
+
+    # End-of-thinking token id.
+    end_thinking_token_id: int
+
     # KV Cache Connector metadata.
     kv_connector_metadata: Optional[KVConnectorMetadata] = None
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 20a40d74f31..13d999952c9 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -48,6 +48,8 @@ def __init__(
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
         include_finished_set: bool = False,
         log_stats: bool = False,
+        special_token_ids: Optional[dict] = None,  # NEW
+        thinking_budget: Optional[int] = None,  # NEW
     ) -> None:
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
@@ -57,7 +59,10 @@ def __init__(
         self.kv_events_config = vllm_config.kv_events_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
-
+        self.special_token_ids = special_token_ids or {}
+        self.thinking_budget = thinking_budget
+        self.start_thinking_token_id = self.special_token_ids.get("start_token_id")
+        self.end_thinking_token_id = self.special_token_ids.get("end_token_id")
         # include_finished_set controls whether a separate set of finished
         # request ids should be included in the EngineCoreOutputs returned
         # by update_from_outputs(). This is currently used in the multi-engine
@@ -160,6 +165,17 @@ def __init__(
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
 
+    def get_current_usage(self, output_tokens, start_thinking_token_id):
+        try:
+            start_thinking_token_index = output_tokens.index(
+                start_thinking_token_id)
+            current_usage = len(output_tokens) - start_thinking_token_index - 1
+            return current_usage
+        except ValueError:
+            # The start-of-thinking token has not been emitted yet.
+            return None
+
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -536,6 +552,27 @@ def schedule(self) -> SchedulerOutput:
             structured_output_request_ids,
             scheduled_spec_decode_tokens,
         )
+
+        # COHERE START
+        # Dictionary of all requests that require the end-of-thinking token
+        # to be forced in this step.
+        requests_with_remaining_budget: dict[str, int] = {}
+        for request_id, request in self.requests.items():
+            if self.thinking_budget >= 0:
+                thinking_budget_used = self.get_current_usage(
+                    request.output_token_ids,
+                    self.start_thinking_token_id)
+                if thinking_budget_used is not None \
+                    and thinking_budget_used <= self.thinking_budget \
+                    and self.end_thinking_token_id not in request.output_token_ids \
+                    and len(request.output_token_ids) > 0:
+                    current_remaining_budget = self.thinking_budget \
+                        - thinking_budget_used
+                    requests_with_remaining_budget[request_id] = current_remaining_budget
+
+        # COHERE END
+
+
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 453ed364dc8..8ff76f8e75f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -105,6 +105,11 @@ def __init__(self,
                 "compatibility may not be maintained.",
                 vllm_config.scheduler_config.scheduler_cls)
 
+        from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
+        parser = getattr(self.model_executor, "reasoning_parser", None)
+        special_token_ids = ReasoningParserManager.get_special_token_ids(parser)
+        thinking_budget = getattr(vllm_config, "thinking_budget", None)
+
         self.scheduler: SchedulerInterface = Scheduler(
             vllm_config=vllm_config,
             kv_cache_config=kv_cache_config,
@@ -112,6 +117,8 @@ def __init__(self,
             include_finished_set=vllm_config.parallel_config.data_parallel_size
             > 1,
             log_stats=self.log_stats,
+            special_token_ids=special_token_ids,
+            thinking_budget=thinking_budget,
         )
 
         # Setup MM Input Mapper.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 29d39de212f..24d5c983bb4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1153,6 +1153,32 @@ def apply_grammar_bitmask(
             indices=out_indices,
         )
 
+    def _force_thinking(self, scheduler_output: "SchedulerOutput",
+                        valid_sampled_token_ids: list[list[int]]):
+        '''
+        Enforce the thinking budget for Cohere models.
+        This is used to ensure that the model does not think past its budget.
+        The function uses the remaining thinking budget of each request to
+        decide whether to force the end of thinking in the valid token ids
+        for that request. E.g. if the remaining budget is 2 and the
+        valid tokens for the request are [1, 2, 3, 4, 5], we
+        remove the last 3 tokens and append end_thinking_token_id,
+        resulting in [1, 2, end_thinking_token_id].
+        '''
+        if scheduler_output.requests_with_remaining_budget:
+            for req_id, remaining_budget in \
+                scheduler_output.requests_with_remaining_budget.items():
+                req_index = self.input_batch.req_id_to_index[req_id]
+                sampled_tokens = valid_sampled_token_ids[req_index]
+                if len(sampled_tokens) > remaining_budget:
+                    clear_indices = \
+                        len(sampled_tokens) - remaining_budget
+                    if clear_indices > 0:
+                        del sampled_tokens[-clear_indices:]
+                    sampled_tokens.append(scheduler_output.end_thinking_token_id)
+
     def sync_and_slice_intermediate_tensors(
             self, num_tokens: int, intermediate_tensors: IntermediateTensors,
             sync_self: bool) -> IntermediateTensors:
@@ -1509,6 +1535,10 @@ def execute_model(
             for i in discard_sampled_tokens_req_indices:
                 valid_sampled_token_ids[i].clear()
 
+        self._force_thinking(
+            scheduler_output,
+            valid_sampled_token_ids)
+
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None

From f9a47924142456df3b3a22fd4c2bad8f153cdace Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Mon, 14 Jul 2025 22:31:27 +0000
Subject: [PATCH 2/6] Add special token IDs function to the reasoning parser

---
 vllm/reasoning/abs_reasoning_parsers.py | 20 +++++++++++++++++++
 .../reasoning/deepseek_r1_reasoning_parser.py | 11 ++++++++++
 2 files changed, 31 insertions(+)

diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index e827d381ca1..394d112894b 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -35,6 +35,16 @@ def vocab(self) -> dict[str, int]:
         # whereas all tokenizers have .get_vocab()
         return self.model_tokenizer.get_vocab()
 
+    @property
+    def special_token_ids(self) -> dict[str, int]:
+        """
+        Returns a dictionary of special token IDs for model-agnostic access.
+        Example: {"start_token_id": int, "end_token_id": int}
+        """
+        # Default: empty; override in subclasses that expose thinking tokens.
+        return {}
+
     @abstractmethod
     def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
@@ -121,6 +131,16 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
         raise KeyError(
             f"reasoning helper: '{name}' not found in reasoning_parsers")
 
+    @staticmethod
+    def get_special_token_ids(parser: "ReasoningParser | None") -> dict[str, int]:
+        """
+        Get special token IDs from a parser instance, if available.
+        Returns an empty dict if not present.
+        """
+        if parser is not None and hasattr(parser, "special_token_ids"):
+            return parser.special_token_ids
+        return {}
+
     @classmethod
     def _register_module(
         cls,
diff --git a/vllm/reasoning/deepseek_r1_reasoning_parser.py b/vllm/reasoning/deepseek_r1_reasoning_parser.py
index 1a5ca46a60f..8e2c772ab29 100644
--- a/vllm/reasoning/deepseek_r1_reasoning_parser.py
+++ b/vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -44,6 +44,17 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
                 "DeepSeek R1 reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")
 
+    @property
+    def special_token_ids(self) -> dict[str, int]:
+        """
+        Returns a dictionary of special token IDs for model-agnostic access.
+        Example: {"start_token_id": int, "end_token_id": int}
+        """
+        return {
+            "start_token_id": self.start_token_id,
+            "end_token_id": self.end_token_id,
+        }
+
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
 

From 35485daf986d6965368f0bf11ee67caa5dba694e Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Mon, 14 Jul 2025 22:48:55 +0000
Subject: [PATCH 3/6] Move thinking budget to sampling parameters

---
 vllm/entrypoints/openai/protocol.py | 4 +++-
 vllm/sampling_params.py | 5 ++++-
 vllm/v1/core/sched/scheduler.py | 8 ++++----
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 3b5281962b2..1fc92e6887b 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -272,7 +272,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
+    thinking_budget: Optional[int] = -1  # -1 disables the thinking budget
     # --8<-- [end:chat-completion-sampling-params]
 
     # --8<-- [start:chat-completion-extra-params]
@@ -563,6 +564,7 @@ def to_sampling_params(
             guided_decoding=guided_decoding,
             logit_bias=self.logit_bias,
             bad_words=self.bad_words,
+            thinking_budget=self.thinking_budget,
             allowed_token_ids=self.allowed_token_ids,
             extra_args=extra_args or None,
         )
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index a9a862384d1..af8cb50fa1c 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -43,7 +43,8 @@ class GuidedDecodingParams:
     disable_any_whitespace: bool = False
     disable_additional_properties: bool = False
     whitespace_pattern: Optional[str] = None
     structural_tag: Optional[str] = None
+    thinking_budget: Optional[int] = -1  # -1 disables the thinking budget
 
     @staticmethod
     def from_optional(
@@ -55,6 +56,7 @@ def from_optional(
         backend: Optional[str] = None,
         whitespace_pattern: Optional[str] = None,
         structural_tag: Optional[str] = None,
+        thinking_budget: Optional[int] = -1,
     ) -> Optional["GuidedDecodingParams"]:
         if all(arg is None for arg in (json, regex, choice, grammar,
                                        json_object, structural_tag)):
@@ -71,6 +73,7 @@ def from_optional(
             backend=backend,
             whitespace_pattern=whitespace_pattern,
             structural_tag=structural_tag,
+            thinking_budget=thinking_budget,
         )
 
     def __post_init__(self):
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 13d999952c9..cba1fda7c43 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -558,15 +558,15 @@ def schedule(self) -> SchedulerOutput:
         # to be forced in this step.
         requests_with_remaining_budget: dict[str, int] = {}
         for request_id, request in self.requests.items():
-            if self.thinking_budget >= 0:
+            if request.sampling_params.thinking_budget >= 0:
                 thinking_budget_used = self.get_current_usage(
                     request.output_token_ids,
                     self.start_thinking_token_id)
                 if thinking_budget_used is not None \
-                    and thinking_budget_used <= self.thinking_budget \
+                    and thinking_budget_used <= request.sampling_params.thinking_budget \
                     and self.end_thinking_token_id not in request.output_token_ids \
                     and len(request.output_token_ids) > 0:
-                    current_remaining_budget = self.thinking_budget \
+                    current_remaining_budget = request.sampling_params.thinking_budget \
                         - thinking_budget_used
                     requests_with_remaining_budget[request_id] = current_remaining_budget

From 98e0b70973b8a51c281f0ada51b68b3bff9116d8 Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Mon, 14 Jul 2025 23:12:05 +0000
Subject: [PATCH 4/6] Cover a corner case in remaining budget

---
 vllm/v1/core/sched/scheduler.py | 2 --
 vllm/v1/engine/core.py | 2 --
 vllm/v1/worker/gpu_model_runner.py | 11 +++++++----
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cba1fda7c43..ebfb8a6328e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -49,7 +49,6 @@ def __init__(
         include_finished_set: bool = False,
         log_stats: bool = False,
         special_token_ids: Optional[dict] = None,  # NEW
-        thinking_budget: Optional[int] = None,  # NEW
     ) -> None:
         self.vllm_config = vllm_config
         self.scheduler_config = vllm_config.scheduler_config
@@ -60,7 +59,6 @@ def __init__(
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
         self.special_token_ids = special_token_ids or {}
-        self.thinking_budget = thinking_budget
         self.start_thinking_token_id = self.special_token_ids.get("start_token_id")
         self.end_thinking_token_id = self.special_token_ids.get("end_token_id")
         # include_finished_set controls whether a separate set of finished
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 8ff76f8e75f..840a199506d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -108,7 +108,6 @@ def __init__(self,
         from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
         parser = getattr(self.model_executor, "reasoning_parser", None)
         special_token_ids = ReasoningParserManager.get_special_token_ids(parser)
-        thinking_budget = getattr(vllm_config, "thinking_budget", None)
 
         self.scheduler: SchedulerInterface = Scheduler(
             vllm_config=vllm_config,
@@ -118,7 +117,6 @@ def __init__(self,
             > 1,
             log_stats=self.log_stats,
             special_token_ids=special_token_ids,
-            thinking_budget=thinking_budget,
         )
 
         # Setup MM Input Mapper.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 24d5c983bb4..5fbb9f62bf5 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1156,26 +1156,29 @@ def apply_grammar_bitmask(
     def _force_thinking(self, scheduler_output: "SchedulerOutput",
                         valid_sampled_token_ids: list[list[int]]):
         '''
-        Enforce the thinking budget for Cohere models.
+        Enforce the thinking budget for reasoning models.
         This is used to ensure that the model does not think past its budget.
         The function uses the remaining thinking budget of each request to
         decide whether to force the end of thinking in the valid token ids
         for that request. E.g. if the remaining budget is 2 and the
         valid tokens for the request are [1, 2, 3, 4, 5], we
         remove the last 3 tokens and append end_thinking_token_id,
-        resulting in [1, 2, end_thinking_token_id].
-        '''
+        resulting in [1, 2, end_thinking_token_id]; the end-of-thinking
+        token replaces the tokens that were removed.
+        '''
         if scheduler_output.requests_with_remaining_budget:
             for req_id, remaining_budget in \
                 scheduler_output.requests_with_remaining_budget.items():
                 req_index = self.input_batch.req_id_to_index[req_id]
                 sampled_tokens = valid_sampled_token_ids[req_index]
-                if len(sampled_tokens) > remaining_budget:
+                if len(sampled_tokens) >= remaining_budget:
                     clear_indices = \
                         len(sampled_tokens) - remaining_budget
                     if clear_indices > 0:
                         del sampled_tokens[-clear_indices:]
                     sampled_tokens.append(scheduler_output.end_thinking_token_id)
+                elif remaining_budget == 0:
+                    sampled_tokens.append(scheduler_output.end_thinking_token_id)

From f80f8c4c9777e9274f4e26686eef19dabf3071ba Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Mon, 14 Jul 2025 23:22:09 +0000
Subject: [PATCH 5/6] Simplify the corner case

---
 vllm/v1/worker/gpu_model_runner.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5fbb9f62bf5..b97ac94df28 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1171,14 +1171,12 @@ def _force_thinking(self, scheduler_output: "SchedulerOutput",
                 scheduler_output.requests_with_remaining_budget.items():
                 req_index = self.input_batch.req_id_to_index[req_id]
                 sampled_tokens = valid_sampled_token_ids[req_index]
-                if len(sampled_tokens) >= remaining_budget:
+                if len(sampled_tokens) > remaining_budget:
                     clear_indices = \
                         len(sampled_tokens) - remaining_budget
                     if clear_indices > 0:
                         del sampled_tokens[-clear_indices:]
                     sampled_tokens.append(scheduler_output.end_thinking_token_id)
-                elif remaining_budget == 0:
-                    sampled_tokens.append(scheduler_output.end_thinking_token_id)

From b64aa6ef16c152202f42be69fbecca7d9f53dab9 Mon Sep 17 00:00:00 2001
From: rishitdholakia13
Date: Tue, 15 Jul 2025 00:38:53 +0000
Subject: [PATCH 6/6] Remove comments

---
 vllm/v1/core/sched/scheduler.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index ebfb8a6328e..e5b6e16b282 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -551,7 +551,6 @@ def schedule(self) -> SchedulerOutput:
             scheduled_spec_decode_tokens,
         )
 
-        # COHERE START
         # Dictionary of all requests that require the end-of-thinking token
         # to be forced in this step.
         requests_with_remaining_budget: dict[str, int] = {}
@@ -567,8 +566,7 @@ def schedule(self) -> SchedulerOutput:
                     current_remaining_budget = request.sampling_params.thinking_budget \
                         - thinking_budget_used
                     requests_with_remaining_budget[request_id] = current_remaining_budget
-
-        # COHERE END
+
 
         # Construct the scheduler output.
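
Reviewer note (not part of the patch series): the snippet below is a minimal,
self-contained sketch of the truncation rule that _force_thinking applies after
PATCH 5/6. It keeps at most remaining_budget of the tokens sampled in this step
and then closes the thinking section. The function name and the token id 999
are illustrative only; the actual code mutates valid_sampled_token_ids[req_index]
in place with del rather than slicing.

def enforce_thinking_budget(sampled_tokens: list[int], remaining_budget: int,
                            end_thinking_token_id: int) -> list[int]:
    # Keep only as many tokens as the remaining budget allows, then force
    # the end-of-thinking token so the request leaves the thinking phase.
    if len(sampled_tokens) > remaining_budget:
        sampled_tokens = sampled_tokens[:remaining_budget]
        sampled_tokens.append(end_thinking_token_id)
    return sampled_tokens

# Example from the docstring: a remaining budget of 2 with sampled tokens
# [1, 2, 3, 4, 5] yields [1, 2, end_thinking_token_id].
assert enforce_thinking_budget([1, 2, 3, 4, 5], 2, 999) == [1, 2, 999]
# A request still within its budget is left untouched.
assert enforce_thinking_budget([1, 2], 5, 999) == [1, 2]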