@@ -518,10 +518,20 @@ def is_argmax_invariant(self) -> bool:
     def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
             for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                max_think_tokens = (
+                    params.max_think_tokens
+                    if isinstance(params, SamplingParams)
+                    else None
+                )
                 if max_think_tokens is not None:
-                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    last_start = self._find_last_token_index(
+                        prompt_tok_ids,
+                        self.think_start_token_id
+                    )
+                    last_end = self._find_last_token_index(
+                        prompt_tok_ids,
+                        self.think_end_token_id
+                    )
                     in_think = last_start > last_end
                     count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
 
@@ -534,13 +544,13 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
                     }
 
             for index in batch_update.removed:
-                self._state.pop(index, None)
+                self._state.pop(index, {})
 
             for i1, i2, direction in batch_update.moved:
                 if direction == MoveDirectionality.SWAP:
                     self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
                 else:
-                    self._state[i2] = self._state.pop(i1, None)
+                    self._state[i2] = self._state.pop(i1, {})
 
         # Update in_think and count for all active requests
         for state in self._state.values():
@@ -571,7 +581,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
             if not state:
                 continue
 
-            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+            if state["in_think"] and state["count"] >= state[
+                    "max_think_tokens"]:
                 mask[index] = True
 
         if mask.any():
@@ -581,8 +592,9 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
         return logits
 
 
-def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
+def init_builtin_logitsprocs(
+        pin_memory_available: bool, max_num_reqs: int, device: torch.device,
+        reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.
 
@@ -611,8 +623,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
     )
     return LogitsProcessorManager(
         non_argmax_invariant=[
-            min_tokens_logitproc,
-            logit_bias_logitproc,
+            min_tokens_logitproc, logit_bias_logitproc,
             max_think_tokens_logitproc
         ],
         argmax_invariant=[min_p_logitproc],
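
For reference, the bookkeeping the first hunk touches can be sketched in isolation. This is a minimal stand-alone rendition, not the vLLM implementation: `_find_last_token_index` is not shown in this diff, so the helper below is an assumed equivalent, and the token IDs are placeholders.

```python
# Sketch of the think-budget state set up in update_state() above.
# Assumptions: find_last_token_index mirrors the private helper referenced
# in the diff; THINK_START_ID / THINK_END_ID are made-up placeholder ids.
from typing import Sequence

THINK_START_ID = 1000  # hypothetical <think> token id
THINK_END_ID = 1001    # hypothetical </think> token id


def find_last_token_index(tokens: Sequence[int], token_id: int) -> int:
    """Return the last position of token_id in tokens, or -1 if absent."""
    for i in range(len(tokens) - 1, -1, -1):
        if tokens[i] == token_id:
            return i
    return -1


def initial_think_state(prompt_tok_ids: Sequence[int],
                        max_think_tokens: int) -> dict:
    """Compute the in_think/count fields the way the hunk does."""
    last_start = find_last_token_index(prompt_tok_ids, THINK_START_ID)
    last_end = find_last_token_index(prompt_tok_ids, THINK_END_ID)
    # An unclosed <think> (last start after the last end) means the request
    # is already mid-reasoning; count the tokens emitted since it opened.
    in_think = last_start > last_end
    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
    return {"in_think": in_think, "count": count,
            "max_think_tokens": max_think_tokens}


# A prompt that ends inside an open think block, 3 tokens after <think>:
state = initial_think_state([5, THINK_START_ID, 7, 8, 9], max_think_tokens=128)
assert state == {"in_think": True, "count": 3, "max_think_tokens": 128}
```

Once `count` reaches `max_think_tokens`, `apply()` flags that row in `mask`; since capping the think budget can change which token has the highest logit, the processor is registered under `non_argmax_invariant` in `init_builtin_logitsprocs`.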