from torch._prims_common import DeviceLikeType

from vllm import PoolingParams, SamplingParams
+from vllm.config import ReasoningConfig
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -24,9 +25,9 @@ class MoveDirectionality(Enum):
    SWAP = 1


-# (index, params, output_tok_ids) tuples for new
+# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
# (index 1, index 2, directionality) tuples representing
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
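The widened `AddedRequest` alias now carries the request's prompt token IDs alongside its running output token IDs. A minimal sketch of what one entry looks like (the token IDs and the default `SamplingParams()` are made up for illustration):

```python
from typing import Union

from vllm import PoolingParams, SamplingParams

# New 4-tuple shape: (index, params, prompt_tok_ids, output_tok_ids).
AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]

# Hypothetical request at batch index 0: default sampling params, a short
# prompt, and an initially empty running output list.
prompt_tok_ids: list[int] = [151644, 872, 198]  # made-up token IDs
output_tok_ids: list[int] = []
added_entry: AddedRequest = (0, SamplingParams(), prompt_tok_ids, output_tok_ids)
```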
@@ -43,9 +44,9 @@ class BatchUpdate:
    # within the persistent batch.
    #
    # Note: each added request is represented as
-    # (index, params, output_tok_ids)
-    # Key assumption: output_tok_ids is a reference to the
-    # request's running output tokens list; in this way
+    # (index, params, prompt_tok_ids, output_tok_ids)
+    # Key assumption: prompt_tok_ids and output_tok_ids are references to the
+    # request's prompt and running output token lists; in this way
    # the logits processors always see the latest list of
    # generated tokens
    removed: Sequence[RemovedRequest]
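The "live reference" assumption stated in the comment above is what lets processors track per-request token state without copying. A vLLM-free sketch of the aliasing it relies on:

```python
# The engine and the logits processors share the *same* list objects, so
# appends made by the engine after a step are visible to any processor that
# stored the references when the request was added.
prompt_tok_ids = [101, 102, 103]
output_tok_ids: list[int] = []

# A processor stores the references at add time ...
stored = {"prompt": prompt_tok_ids, "output": output_tok_ids}

# ... and later the engine appends a newly sampled token in place.
output_tok_ids.append(424)

assert stored["output"] is output_tok_ids
assert stored["output"][-1] == 424  # the processor sees the latest token
```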
@@ -254,7 +255,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

        needs_update = False
        # Process added requests.
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
            min_p = params.min_p if isinstance(params, SamplingParams) else 0.0
            if self.min_p_cpu[index] != min_p:
                needs_update = True
@@ -329,7 +330,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):

        # Process added requests.
        needs_update = bool(batch_update.added)
-        for index, params, _ in batch_update.added:
+        for index, params, _, _ in batch_update.added:
            if isinstance(params, SamplingParams) and (lb :=
                                                       params.logit_bias):
                self.biases[index] = lb
@@ -412,7 +413,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
        if batch_update:
            # Process added requests.
            needs_update |= bool(batch_update.added)
-            for index, params, output_tok_ids in batch_update.added:
+            for index, params, _, output_tok_ids in batch_update.added:
                if (isinstance(params, SamplingParams)
                        and (min_tokens := params.min_tokens)
                        and len(output_tok_ids) < min_tokens):
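All three built-in processors above adapt to the widened tuple in the same way: an extra `_` absorbs the new `prompt_tok_ids` element they do not need. A small illustration with placeholder params objects (plain strings stand in for `SamplingParams`):

```python
# Entries follow the new shape: (index, params, prompt_tok_ids, output_tok_ids).
added = [
    (0, "params-0", [1, 2, 3], []),
    (1, "params-1", [4, 5], [6]),
]

# Unpacking with the old 3-element pattern would raise ValueError here; the
# extra `_` discards the prompt token IDs these processors ignore.
for index, params, _, output_tok_ids in added:
    print(index, params, len(output_tok_ids))
```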
@@ -485,8 +486,113 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
        return logits


+class MaxThinkTokensLogitsProcessor(LogitsProcessor):
+    """A logits processor that limits the maximum number of thinking tokens."""
+
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+        """
+        Args:
+            reasoning_config (ReasoningConfig): Config providing the token IDs
+                that mark the start and end of a thinking section.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
+        """
+        super().__init__()
+        self.think_start_token_id = reasoning_config.think_start_token_id
+        self.think_end_token_id = reasoning_config.think_end_token_id
+        self.pin_memory = pin_memory
+        self.device = device
+        self._state: dict[int, dict] = {}
+
+    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
+        try:
+            return len(tokens) - tokens[::-1].index(token_id) - 1
+        except ValueError:
+            return -1
+
+    def is_argmax_invariant(self) -> bool:
+        """This logits processor can change the outcome of greedy sampling
+        by forcing that the thinking section ends after a certain number of tokens."""
+        return False
+
+    def update_state(self, batch_update: Optional[BatchUpdate]):
+        if batch_update is None:
+            return
+
+        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+
+            if max_think_tokens is None:
+                continue
+
+            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+
+            in_think = False
+            count = 0
+
+            if last_think_start_idx > last_think_end_idx:
+                in_think = True
+                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+
+            self._state[index] = {
+                "in_think": in_think,
+                "count": count,
+                "prompt_tok_ids": prompt_tok_ids,
+                "output_tok_ids": output_tok_ids,
+                "max_think_tokens": max_think_tokens,
+            }
+
+        for index in batch_update.removed:
+            self._state.pop(index, None)
+
+        for i1, i2, direction in batch_update.moved:
+            if direction == MoveDirectionality.SWAP:
+                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+            else:
+                self._state[i2] = self._state.pop(i1, None)
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        batch_size = logits.size(0)
+        if batch_size == 0:
+            return logits
+
+        mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
+        end_token_id = self.think_end_token_id
+
+        for index in range(batch_size):
+            state = self._state.get(index, None)
+            if not state or not state.get("output_tok_ids"):
+                continue
+
+            last_tok = state["output_tok_ids"][-1]
+            in_think = state["in_think"]
+            count = state["count"]
+
+            if last_tok == self.think_start_token_id:
+                in_think = True
+                count = 0
+            elif last_tok == self.think_end_token_id:
+                in_think = False
+                count = 0
+            elif in_think:
+                count += 1
+
+            state["in_think"] = in_think
+            state["count"] = count
+
+            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+                mask[index] = True
+
+        if mask.any():
+            logits[mask] = -float("inf")
+            logits[mask, end_token_id] = 0.0
+
+        return logits
+
+
def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device) -> LogitsProcessorManager:
+                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
    """Construct 'builtin' vLLM logitsprocs which the engine
    loads by default.

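The core of `MaxThinkTokensLogitsProcessor.apply()` is the masking step at the end: once a row's thinking budget is exhausted, every logit is forced to `-inf` except the think-end token, so the next sampled token must close the thinking section. A standalone sketch of just that step (batch size, vocabulary size, and the token ID are made up):

```python
import torch

batch_size, vocab_size = 3, 8
think_end_token_id = 5  # made-up ID standing in for the think-end token

logits = torch.randn(batch_size, vocab_size)
mask = torch.tensor([False, True, False])  # row 1 hit its max_think_tokens budget

# Force the budget-exhausted row to emit the think-end token.
logits[mask] = -float("inf")
logits[mask, think_end_token_id] = 0.0

assert logits[1].argmax().item() == think_end_token_id
```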
@@ -508,10 +614,16 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
        device=device,
        # +1 for temporary swap space
        max_num_reqs=max_num_reqs + 1)
+    max_think_tokens_logitproc = MaxThinkTokensLogitsProcessor(
+        reasoning_config=reasoning_config,
+        pin_memory=pin_memory_available,
+        device=device,
+    )
    return LogitsProcessorManager(
        non_argmax_invariant=[
            min_tokens_logitproc,
            logit_bias_logitproc,
+            max_think_tokens_logitproc
        ],
        argmax_invariant=[min_p_logitproc],
    )
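A hedged sketch of how the new processor gets wired in end to end. The module path, the `ReasoningConfig` constructor arguments, and the `non_argmax_invariant` attribute access are assumptions inferred from this diff, not confirmed API:

```python
import torch

from vllm.config import ReasoningConfig  # import added by this PR
# Module path assumed; init_builtin_logitsprocs is defined in the file diffed here.
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs

# Assumed constructor: the two token IDs are read back in
# MaxThinkTokensLogitsProcessor.__init__ as attributes of the config.
reasoning_config = ReasoningConfig(think_start_token_id=151648,
                                   think_end_token_id=151649)

manager = init_builtin_logitsprocs(pin_memory_available=False,
                                   max_num_reqs=8,
                                   device=torch.device("cpu"),
                                   reasoning_config=reasoning_config)

# The new processor sits alongside min-tokens and logit-bias in the
# non-argmax-invariant group (attribute name assumed from the kwargs above).
print(manager.non_argmax_invariant)
```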