 from torch._prims_common import DeviceLikeType
 
 from vllm import PoolingParams, SamplingParams
+from vllm.config import ReasoningConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -24,9 +25,9 @@ class MoveDirectionality(Enum):
     SWAP = 1
 
 
-# (index, params, output_tok_ids) tuples for new
+# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]
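
For illustration, a minimal sketch of the new 4-tuple shape carried by BatchUpdate.added, using plain-Python stand-ins (None in the params slot) rather than real vLLM objects:

    # (index, params, prompt_tok_ids, output_tok_ids) stand-in entries
    added = [
        (0, None, [1, 5, 9, 42], [7]),
    ]

    for index, params, prompt_tok_ids, output_tok_ids in added:
        # prompt_tok_ids is fixed context; output_tok_ids is a live reference
        # that the engine keeps appending to as new tokens are sampled.
        print(index, len(prompt_tok_ids), len(output_tok_ids))
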
@@ -43,9 +44,9 @@ class BatchUpdate:
     # within the persistent batch.
     #
     # Note: each added request is represented as
-    # (index, params, output_tok_ids)
-    # Key assumption: output_tok_ids is a reference to the
-    # request's running output tokens list; in this way
+    # (index, params, prompt_tok_ids, output_tok_ids)
+    # Key assumption: prompt_tok_ids and output_tok_ids are references to the
+    # request's prompt and running output token lists; in this way
     # the logits processors always see the latest list of
     # generated tokens
     removed: Sequence[RemovedRequest]
@@ -260,7 +261,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
         needs_update = False
         # Process added requests.
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
             min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
             if self.min_p_cpu[index] != min_p:
                 needs_update = True
@@ -337,7 +338,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
 
         # Process added requests.
         needs_update = bool(batch_update.added)
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
             if isinstance(params, SamplingParams) and (lb :=
                                                        params.logit_bias):
                 self.biases[index] = lb
@@ -420,7 +421,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
             # Process added requests.
             needs_update |= bool(batch_update.added)
-            for index, params, output_tok_ids in batch_update.added:
+            for index, params, _, output_tok_ids in batch_update.added:
                 if (isinstance(params, SamplingParams)
                         and (min_tokens := params.min_tokens)
                         and len(output_tok_ids) < min_tokens):
@@ -493,8 +494,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
 
+class MaxThinkTokensLogitsProcessor(LogitsProcessor):
+    """A logits processor that limits the maximum number of thinking tokens."""
+
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+        """
+        Args:
+            reasoning_config (ReasoningConfig): Config providing the
+                think_start_token_id and think_end_token_id.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
+        """
+        super().__init__()
+        self.think_start_token_id = reasoning_config.think_start_token_id
+        self.think_end_token_id = reasoning_config.think_end_token_id
+        self.pin_memory = pin_memory
+        self.device = device
+        self._state = {}
+
+    def _find_last_token_index(self, tokens, token_id):
+        try:
+            return len(tokens) - tokens[::-1].index(token_id) - 1
+        except ValueError:
+            return -1
+
+    def is_argmax_invariant(self) -> bool:
+        """This logits processor can change the outcome of greedy sampling
+        by forcing the thinking section to end after a certain number of tokens."""
+        return False
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        if batch_update is None:
+            return
+
+        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+
+            if max_think_tokens is None:
+                continue
+
+            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+
+            in_think = False
+            count = 0
+
+            if last_think_start_idx > last_think_end_idx:
+                in_think = True
+                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+
+            self._state[index] = {
+                "in_think": in_think,
+                "count": count,
+                "prompt_tok_ids": prompt_tok_ids,
+                "output_tok_ids": output_tok_ids,
+                "max_think_tokens": max_think_tokens,
+            }
+
+        for index in batch_update.removed:
+            self._state.pop(index, None)
+
+        for i1, i2, direction in batch_update.moved:
+            if direction == MoveDirectionality.SWAP:
+                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+            else:
+                self._state[i2] = self._state.pop(i1, None)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        batch_size = logits.size(0)
+        if batch_size == 0:
+            return logits
+
+        mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
+        end_token_id = self.think_end_token_id
+
+        for index in range(batch_size):
+            state = self._state.get(index, None)
+            if not state or not state.get("output_tok_ids"):
+                continue
+
+            last_tok = state["output_tok_ids"][-1]
+            in_think = state["in_think"]
+            count = state["count"]
+
+            if last_tok == self.think_start_token_id:
+                in_think = True
+                count = 0
+            elif last_tok == self.think_end_token_id:
+                in_think = False
+                count = 0
+            elif in_think:
+                count += 1
+
+            state["in_think"] = in_think
+            state["count"] = count
+
+            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+                mask[index] = True
+
+        if mask.any():
+            logits[mask] = -float("inf")
+            logits[mask, end_token_id] = 0.0
+
+        return logits
+
+
 def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device) -> LogitsProcessorManager:
+                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.
 
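
For illustration, a standalone sketch of the thinking-budget state machine the new processor implements: seed in_think/count from the prompt (everything after the last unclosed think-start token counts as already-spent budget), then advance one generated token at a time. Token ids 100/101 are hypothetical stand-ins for think_start_token_id/think_end_token_id; no real vLLM classes are used.

    THINK_START, THINK_END = 100, 101  # hypothetical marker token ids

    def seed_state(prompt_tok_ids):
        # Locate the last occurrence of each marker in the prompt (-1 if absent).
        def last_index(tokens, tok):
            return len(tokens) - 1 - tokens[::-1].index(tok) if tok in tokens else -1
        start = last_index(prompt_tok_ids, THINK_START)
        end = last_index(prompt_tok_ids, THINK_END)
        in_think = start > end
        count = len(prompt_tok_ids) - (start + 1) if in_think else 0
        return in_think, count

    def step(in_think, count, last_tok):
        # Advance the state by one generated token, as apply() does per request.
        if last_tok == THINK_START:
            return True, 0
        if last_tok == THINK_END:
            return False, 0
        return in_think, count + 1 if in_think else count

    # Prompt ends inside an open thinking section with two tokens already spent.
    in_think, count = seed_state([1, 2, THINK_START, 7, 8])
    assert (in_think, count) == (True, 2)
    in_think, count = step(in_think, count, 9)          # one more thinking token
    assert (in_think, count) == (True, 3)
    in_think, count = step(in_think, count, THINK_END)  # reasoning section closed
    assert (in_think, count) == (False, 0)

Once count reaches the request's max_think_tokens while in_think is still true, the processor masks the logits so the only available continuation is the think-end token.
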
@@ -516,10 +622,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
         device=device,
         # +1 for temporary swap space
         max_num_reqs=max_num_reqs + 1)
+    max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+        reasoning_config=reasoning_config,
+        pin_memory=pin_memory_available,
+        device=device,
+    )
     return LogitsProcessorManager(
         non_argmax_invariant=[
             min_tokens_logitproc,
             logit_bias_logitproc,
+            max_think_tokens_logitproc,
         ],
         argmax_invariant=[min_p_logitproc],
     )
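
For illustration, a small sketch of the forcing step at the end of apply(): rows whose thinking budget is exhausted get every logit pushed to -inf except the think-end token, so both greedy and sampled decoding emit it next. The vocab size and token id 101 are made up for the example.

    import torch

    THINK_END = 101                               # hypothetical think-end token id
    logits = torch.randn(2, 128)                  # batch of two requests
    over_budget = torch.tensor([True, False])     # request 0 has exhausted its budget

    logits[over_budget] = -float("inf")           # suppress every token...
    logits[over_budget, THINK_END] = 0.0          # ...except the forced think-end token

    assert logits[0].argmax().item() == THINK_END
    assert torch.isfinite(logits[1]).all()        # the other request is untouched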