[Reasoning] Add thinking budget support #20949

Draft · wants to merge 6 commits into main
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -272,7 +272,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
bad_words: list[str] = Field(default_factory=list)
bad_words: list[str] = Field(default_factory=list)
thinking_budget: Optional[int] = -1
# --8<-- [end:chat-completion-sampling-params]

# --8<-- [start:chat-completion-extra-params]
@@ -563,6 +564,7 @@ def to_sampling_params(
guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
bad_words=self.bad_words,
thinking_budget=self.thinking_budget,
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
)
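With this field in place, a client can set the budget per request through the chat-completions body. A minimal sketch, assuming a vLLM server is already running with a reasoning model and its reasoning parser enabled; the model name, URL, and budget value are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# thinking_budget is the new sampling field added above; the default of -1
# leaves the reasoning length unrestricted.
response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    extra_body={"thinking_budget": 256},
)
print(response.choices[0].message.content)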
20 changes: 20 additions & 0 deletions vllm/reasoning/abs_reasoning_parsers.py
@@ -35,6 +35,16 @@ def vocab(self) -> dict[str, int]:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()


@property
def special_token_ids(self) -> dict:
"""
Returns a dictionary of special token IDs for model-agnostic access.
Example: {"start_token_id": int, "end_token_id": int}
"""
# Default: not implemented, override in subclasses if needed
return {}

@abstractmethod
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
"""
@@ -121,6 +131,16 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
raise KeyError(
f"reasoning helper: '{name}' not found in reasoning_parsers")

@staticmethod
def get_special_token_ids(parser) -> dict:
"""
Get special token IDs from a parser instance, if available.
Returns an empty dict if not present.
"""
if parser is not None and hasattr(parser, "special_token_ids"):
return parser.special_token_ids
return {}
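The helper is duck-typed: any parser object that exposes a special_token_ids attribute participates, and anything else (including None) falls back to an empty dict. A small illustrative sketch with a toy stand-in class that is not part of the PR; the token ids are made up:

from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager

class ToyParser:
    # Stand-in for a ReasoningParser subclass overriding the property.
    special_token_ids = {"start_token_id": 100, "end_token_id": 200}

print(ReasoningParserManager.get_special_token_ids(ToyParser()))  # both ids
print(ReasoningParserManager.get_special_token_ids(None))         # {}
print(ReasoningParserManager.get_special_token_ids(object()))     # {}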

@classmethod
def _register_module(
cls,
11 changes: 11 additions & 0 deletions vllm/reasoning/deepseek_r1_reasoning_parser.py
@@ -44,6 +44,17 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")

@property
def special_token_ids(self) -> dict:
"""
Returns a dictionary of special token IDs for model-agnostic access.
Example: {"start_token_id": int, "end_token_id": int}
"""
return {
"start_token_id": self.start_token_id,
"end_token_id": self.end_token_id,
}

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids

5 changes: 4 additions & 1 deletion vllm/sampling_params.py
@@ -43,7 +43,8 @@ class GuidedDecodingParams:
disable_any_whitespace: bool = False
disable_additional_properties: bool = False
whitespace_pattern: Optional[str] = None
structural_tag: Optional[str] = None
structural_tag: Optional[str] = None
thinking_budget: Optional[int] = -1

@staticmethod
def from_optional(
@@ -55,6 +56,7 @@ def from_optional(
backend: Optional[str] = None,
whitespace_pattern: Optional[str] = None,
structural_tag: Optional[str] = None,
thinking_budget: Optional[int] = -1,
) -> Optional["GuidedDecodingParams"]:
if all(arg is None for arg in (json, regex, choice, grammar,
json_object, structural_tag)):
@@ -71,6 +73,7 @@ def from_optional(
backend=backend,
whitespace_pattern=whitespace_pattern,
structural_tag=structural_tag,
thinking_budget=thinking_budget,
)

def __post_init__(self):
6 changes: 6 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -151,5 +151,11 @@ class SchedulerOutput:
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]

# Remaining thinking budget per request id (request id -> tokens left)
# for requests whose end-of-thinking token may need to be forced.
requests_with_remaining_budget: dict[str, int]

# Token id that marks the end of the thinking section.
end_thinking_token_id: int

# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
35 changes: 34 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -48,6 +48,7 @@ def __init__(
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
special_token_ids: Optional[dict] = None, # NEW
) -> None:
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
@@ -57,7 +58,9 @@ def __init__(
self.kv_events_config = vllm_config.kv_events_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager

self.special_token_ids = special_token_ids or {}
self.start_thinking_token_id = self.special_token_ids.get("start_token_id")
self.end_thinking_token_id = self.special_token_ids.get("end_token_id")
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
@@ -160,6 +163,17 @@ def __init__(
enable_kv_cache_events=self.enable_kv_cache_events,
)


def get_current_usage(self, output_tokens, start_thinking_token_id):
    """Return how many tokens have been generated after the
    start-thinking token, or None if thinking has not started yet."""
    try:
        start_thinking_token_index = output_tokens.index(
            start_thinking_token_id)
        return len(output_tokens) - start_thinking_token_index - 1
    except ValueError:
        # The start-thinking token has not been generated yet.
        return None

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -536,6 +550,25 @@ def schedule(self) -> SchedulerOutput:
structured_output_request_ids,
scheduled_spec_decode_tokens,
)

# Requests that still have thinking budget left and may need the
# end-of-thinking token forced during this step.
requests_with_remaining_budget: dict[str, int] = {}
for request_id, request in self.requests.items():
    if request.sampling_params.thinking_budget >= 0:
        thinking_budget_used = self.get_current_usage(
            request.output_token_ids, self.start_thinking_token_id)
        if (thinking_budget_used is not None
                and thinking_budget_used
                <= request.sampling_params.thinking_budget
                and self.end_thinking_token_id
                not in request.output_token_ids
                and len(request.output_token_ids) > 0):
            current_remaining_budget = (
                request.sampling_params.thinking_budget
                - thinking_budget_used)
            requests_with_remaining_budget[request_id] = current_remaining_budget



# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
5 changes: 5 additions & 0 deletions vllm/v1/engine/core.py
@@ -105,13 +105,18 @@ def __init__(self,
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)

from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
parser = getattr(self.model_executor, "reasoning_parser", None)
special_token_ids = ReasoningParserManager.get_special_token_ids(parser)

self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
special_token_ids=special_token_ids,
)

# Setup MM Input Mapper.
31 changes: 31 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1153,6 +1153,33 @@ def apply_grammar_bitmask(
indices=out_indices,
)

def _force_thinking(self, scheduler_output: "SchedulerOutput",
                    valid_sampled_token_ids: list[list[int]]):
    """
    Enforce the thinking budget for reasoning models.

    For each request that the scheduler flagged with a remaining
    thinking budget, truncate the newly sampled tokens so they do not
    exceed that budget, then append end_thinking_token_id. For example,
    with a remaining budget of 2 and sampled tokens [1, 2, 3, 4, 5],
    the last 3 tokens are dropped and the result is
    [1, 2, end_thinking_token_id].
    """
    if scheduler_output.requests_with_remaining_budget:
        for req_id, remaining_budget in \
                scheduler_output.requests_with_remaining_budget.items():
            req_index = self.input_batch.req_id_to_index[req_id]
            sampled_tokens = valid_sampled_token_ids[req_index]
            if len(sampled_tokens) > remaining_budget:
                num_to_drop = len(sampled_tokens) - remaining_budget
                del sampled_tokens[-num_to_drop:]
                sampled_tokens.append(
                    scheduler_output.end_thinking_token_id)
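The truncation can be checked in isolation on plain lists; the end-of-thinking token id of 200 below is an assumed value for illustration:

def force_thinking(sampled_tokens, remaining_budget, end_thinking_token_id):
    # Mirror of the per-request logic above, applied to a plain list.
    if len(sampled_tokens) > remaining_budget:
        num_to_drop = len(sampled_tokens) - remaining_budget
        del sampled_tokens[-num_to_drop:]
        sampled_tokens.append(end_thinking_token_id)
    return sampled_tokens

print(force_thinking([1, 2, 3, 4, 5], 2, 200))  # [1, 2, 200]
print(force_thinking([1, 2], 2, 200))           # [1, 2] (within budget)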



def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
@@ -1509,6 +1536,10 @@ def execute_model(
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()

self._force_thinking(
scheduler_output,
valid_sampled_token_ids)

if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None