@@ -27,7 +27,8 @@ class MoveDirectionality(Enum):
 # (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int], list[int]]
+AddedRequest = tuple[int, Union[SamplingParams, PoolingParams], list[int],
+                     list[int]]

 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]
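For context, an AddedRequest under the reflowed alias is still the same four-element tuple. A hypothetical entry (field values invented for illustration, and assuming SamplingParams accepts max_think_tokens as a keyword argument, as the field read later in this PR suggests) might look like:

# (batch index, per-request params, prompt token IDs, output token IDs)
added: AddedRequest = (
    0,
    SamplingParams(max_think_tokens=256),
    [101, 2023, 2003, 102],  # prompt token IDs (arbitrary values)
    [],                      # a newly added request has no output tokens yet
)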
@@ -497,13 +498,14 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
 class MaxThinkTokensLogitsProcessor(LogitsProcessor):
     """A logits processor that limits the maximum number of thinking tokens."""

-    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device):
+    def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
+                 device: torch.device):
         """
         Args:
-            think_start_token_id (int): Token ID for the start of thinking section.
-            think_end_token_id (int): Token ID for the end of thinking section.
-            pin_memory (bool): Whether to use pinned memory for tensors.
-            device (torch.device): Device to use for tensor operations.
+            reasoning_config: Configuration for reasoning, which includes
+                the token IDs for thinking start and end.
+            pin_memory (bool): Whether to use pinned memory for tensors.
+            device (torch.device): Device to use for tensor operations.
         """
         super().__init__()
         self.think_start_token_id = reasoning_config.think_start_token_id
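A rough sketch of how the constructor described above might be called; ReasoningConfig is assumed to expose think_start_token_id and think_end_token_id (the assignments right after the docstring read those attributes), but its exact constructor is not shown in this diff:

import torch

# Hypothetical construction; the keyword names mirror the attributes
# that __init__ reads from the config.
reasoning_config = ReasoningConfig(think_start_token_id=151648,
                                   think_end_token_id=151649)

processor = MaxThinkTokensLogitsProcessor(
    reasoning_config=reasoning_config,
    pin_memory=torch.cuda.is_available(),  # pinned host memory only pays off with a GPU
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)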
@@ -519,19 +521,25 @@ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         return -1

     def is_argmax_invariant(self) -> bool:
-        """This logits processor can change the outcome of greedy sampling
-        by forcing that the thinking section ends after a certain number of tokens."""
+        """This logits processor can change the outcome of
+        greedy sampling by forcing that the thinking section
+        ends after a certain number of tokens."""
         return False

     def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
-            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+            for (index, params, prompt_tok_ids,
+                 output_tok_ids) in batch_update.added:
+                max_think_tokens = (params.max_think_tokens if isinstance(
+                    params, SamplingParams) else None)
                 if max_think_tokens is not None:
-                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    last_start = self._find_last_token_index(
+                        prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(
+                        prompt_tok_ids, self.think_end_token_id)
                     in_think = last_start > last_end
-                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+                    count = len(prompt_tok_ids) - (last_start +
+                                                   1) if in_think else 0

                     self._state[index] = {
                         "in_think": in_think,
@@ -542,13 +550,14 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
                     }

             for index in batch_update.removed:
-                self._state.pop(index, None)
+                self._state.pop(index, {})

             for i1, i2, direction in batch_update.moved:
                 if direction == MoveDirectionality.SWAP:
-                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                    self._state[i1], self._state[i2] = self._state[
+                        i2], self._state[i1]
                 else:
-                    self._state[i2] = self._state.pop(i1, None)
+                    self._state[i2] = self._state.pop(i1, {})

         # Update in_think and count for all active requests
         for state in self._state.values():
@@ -579,7 +588,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
             if not state:
                 continue

-            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+            if state["in_think"] and state["count"] >= state[
+                    "max_think_tokens"]:
                 mask[index] = True

         if mask.any():
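The hunk above only builds the per-row boolean mask; what happens once mask.any() is true falls outside this diff. One common way to end the thinking section at that point, sketched here as an assumption rather than the PR's actual code, is to leave only the think-end token as a viable choice on the masked rows:

import torch

# mask: bool tensor of shape [num_reqs]; logits: [num_reqs, vocab_size]
forced = torch.full_like(logits, float("-inf"))
forced[:, think_end_token_id] = 0.0
logits = torch.where(mask.unsqueeze(1), forced, logits)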
@@ -589,8 +599,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits


-def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
+def init_builtin_logitsprocs(
+        pin_memory_available: bool, max_num_reqs: int, device: torch.device,
+        reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.

@@ -619,8 +630,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
     )
     return LogitsProcessorManager(
         non_argmax_invariant=[
-            min_tokens_logitproc,
-            logit_bias_logitproc,
+            min_tokens_logitproc, logit_bias_logitproc,
             max_think_tokens_logitproc
         ],
         argmax_invariant=[min_p_logitproc],
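Taken together, the new init_builtin_logitsprocs signature and the max_think_tokens read in update_state suggest usage roughly like the following; only the signature and the attribute name come from this diff, the surrounding values are placeholders:

logitsprocs = init_builtin_logitsprocs(
    pin_memory_available=True,
    max_num_reqs=256,
    device=torch.device("cuda"),
    reasoning_config=reasoning_config,  # e.g. the config built earlier
)

# Per request: cap the thinking section at 512 tokens (placeholder value).
params = SamplingParams(max_think_tokens=512)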