[Reasoning] Add thinking budget support #20949

Draft · wants to merge 6 commits into main
Changes from 1 commit
6 changes: 6 additions & 0 deletions vllm/v1/core/sched/output.py
@@ -151,5 +151,11 @@ class SchedulerOutput:
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]

# Mapping from request ID to the remaining thinking-token budget for
# requests whose thinking must be capped in this step.
requests_with_remaining_budget: dict[str, int]

# Token ID that ends the thinking section.
end_thinking_token_id: int

# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None
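
For context (not part of the diff): a minimal sketch of what the two new fields might hold for a batch in which one request is still inside its thinking section. The request ID and token ID below are made-up illustrative values, not values defined by this PR.

# Hypothetical example values only.
requests_with_remaining_budget = {"req-0": 5}  # "req-0" may emit 5 more thinking tokens
end_thinking_token_id = 255001                 # assumed end-of-thinking token ID
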
39 changes: 38 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -48,6 +48,8 @@ def __init__(
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
special_token_ids: Optional[dict[str, int]] = None,  # NEW: thinking start/end token IDs
thinking_budget: Optional[int] = None,  # NEW: max thinking tokens per request
) -> None:
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
@@ -57,7 +59,10 @@ def __init__(
self.kv_events_config = vllm_config.kv_events_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager

self.special_token_ids = special_token_ids or {}
self.thinking_budget = thinking_budget
self.start_thinking_token_id = self.special_token_ids.get("start_token_id")
self.end_thinking_token_id = self.special_token_ids.get("end_token_id")
# include_finished_set controls whether a separate set of finished
# request ids should be included in the EngineCoreOutputs returned
# by update_from_outputs(). This is currently used in the multi-engine
@@ -160,6 +165,17 @@ def __init__(
enable_kv_cache_events=self.enable_kv_cache_events,
)


def get_current_usage(self, output_tokens, start_thinking_token_id):
    """Return the number of tokens generated after the start-thinking
    token, or None if the request has not started thinking yet."""
    try:
        start_thinking_token_index = output_tokens.index(
            start_thinking_token_id)
    except ValueError:
        # The start-thinking token has not been generated yet.
        return None
    return len(output_tokens) - start_thinking_token_index - 1

def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -536,6 +552,27 @@ def schedule(self) -> SchedulerOutput:
structured_output_request_ids,
scheduled_spec_decode_tokens,
)

# COHERE START
# Collect every request whose thinking tokens must be capped in this
# step, mapped to its remaining thinking-token budget.
requests_with_remaining_budget: dict[str, int] = {}
if self.thinking_budget is not None and self.thinking_budget >= 0:
    for request_id, request in self.requests.items():
        thinking_budget_used = self.get_current_usage(
            request.output_token_ids, self.start_thinking_token_id)
        if (thinking_budget_used is not None
                and thinking_budget_used <= self.thinking_budget
                and self.end_thinking_token_id
                not in request.output_token_ids
                and len(request.output_token_ids) > 0):
            requests_with_remaining_budget[request_id] = (
                self.thinking_budget - thinking_budget_used)
# COHERE END

# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
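
To make the scheduler-side accounting concrete, here is a small standalone sketch of how the remaining budget is derived from a request's output tokens, mirroring get_current_usage and the loop above. It is not part of the diff; the token IDs and the budget of 4 are assumed values for the example.

from typing import Optional

START_THINKING_ID = 7  # assumed start-of-thinking token ID
END_THINKING_ID = 8    # assumed end-of-thinking token ID
THINKING_BUDGET = 4    # assumed per-request thinking budget


def thinking_tokens_used(output_tokens: list[int]) -> Optional[int]:
    # Number of tokens generated after the start-thinking token,
    # or None if the request has not started thinking yet.
    try:
        start_index = output_tokens.index(START_THINKING_ID)
    except ValueError:
        return None
    return len(output_tokens) - start_index - 1


def remaining_budget(output_tokens: list[int]) -> Optional[int]:
    # Remaining thinking budget, or None if no capping is needed
    # (thinking not started, budget already exceeded, or thinking done).
    used = thinking_tokens_used(output_tokens)
    if used is None or used > THINKING_BUDGET:
        return None
    if not output_tokens or END_THINKING_ID in output_tokens:
        return None
    return THINKING_BUDGET - used


# Three thinking tokens (101, 102, 103) follow the start-thinking token,
# so 1 token of the budget of 4 remains.
print(remaining_budget([11, 12, START_THINKING_ID, 101, 102, 103]))  # -> 1
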
7 changes: 7 additions & 0 deletions vllm/v1/engine/core.py
@@ -105,13 +105,20 @@ def __init__(self,
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)

# Resolve the thinking start/end token IDs from the reasoning parser
# (if one is configured) and the requested thinking budget.
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
parser = getattr(self.model_executor, "reasoning_parser", None)
special_token_ids = ReasoningParserManager.get_special_token_ids(parser)
thinking_budget = getattr(vllm_config, "thinking_budget", None)

self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
special_token_ids=special_token_ids,
thinking_budget=thinking_budget,
)

# Setup MM Input Mapper.
30 changes: 30 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1153,6 +1153,32 @@ def apply_grammar_bitmask(
indices=out_indices,
)

def _force_thinking(self, scheduler_output: "SchedulerOutput",
                    valid_sampled_token_ids: list[list[int]]):
    '''
    Enforce the thinking budget on the newly sampled tokens.

    For each request with a remaining thinking budget, any sampled
    tokens beyond that budget are dropped and end_thinking_token_id is
    appended, forcing the model out of its thinking section. E.g. if
    the remaining budget is 2 and the valid tokens for the request are
    [1, 2, 3, 4, 5], the last 3 tokens are removed and
    end_thinking_token_id is appended, resulting in
    [1, 2, end_thinking_token_id].
    '''
    for req_id, remaining_budget in (
            scheduler_output.requests_with_remaining_budget.items()):
        req_index = self.input_batch.req_id_to_index[req_id]
        sampled_tokens = valid_sampled_token_ids[req_index]
        if len(sampled_tokens) > remaining_budget:
            # Drop the tokens that exceed the budget and force the
            # end-of-thinking token.
            del sampled_tokens[remaining_budget:]
            sampled_tokens.append(scheduler_output.end_thinking_token_id)



def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
@@ -1509,6 +1535,10 @@ def execute_model(
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()

# Enforce the thinking budget on the sampled tokens.
self._force_thinking(scheduler_output, valid_sampled_token_ids)

if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
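
Finally, a standalone sketch of the truncation performed by _force_thinking, using the same worked example as its docstring (remaining budget of 2, sampled tokens [1, 2, 3, 4, 5]). END_THINKING_ID is an assumed token ID for illustration, not one defined by this PR.

END_THINKING_ID = 8  # assumed end-of-thinking token ID


def cap_thinking(sampled_tokens: list[int], remaining_budget: int) -> list[int]:
    # Keep at most `remaining_budget` newly sampled tokens, then force the
    # end-of-thinking token so the model leaves its thinking section.
    if len(sampled_tokens) > remaining_budget:
        del sampled_tokens[remaining_budget:]
        sampled_tokens.append(END_THINKING_ID)
    return sampled_tokens


print(cap_thinking([1, 2, 3, 4, 5], remaining_budget=2))  # -> [1, 2, 8]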