@@ -526,10 +526,20 @@ def is_argmax_invariant(self) -> bool:
     def update_state(self, batch_update: Optional[BatchUpdate]):
         if batch_update:
             for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                max_think_tokens = (
+                    params.max_think_tokens
+                    if isinstance(params, SamplingParams)
+                    else None
+                )
                 if max_think_tokens is not None:
-                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    last_start = self._find_last_token_index(
+                        prompt_tok_ids,
+                        self.think_start_token_id
+                    )
+                    last_end = self._find_last_token_index(
+                        prompt_tok_ids,
+                        self.think_end_token_id
+                    )
                     in_think = last_start > last_end
                     count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
 
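Note: `_find_last_token_index` is not shown in this diff. A minimal sketch of the behavior the `in_think` check relies on (returning -1 when the token is absent, so a prompt with neither marker yields `in_think = False`) could look like this; the implementation below is an assumption, not the PR's code:

    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
        # Walk backwards; -1 signals "not found", so prompts without any
        # think markers never count as being inside a thinking segment.
        for i in range(len(tokens) - 1, -1, -1):
            if tokens[i] == token_id:
                return i
        return -1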
@@ -542,13 +552,13 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
                     }
 
             for index in batch_update.removed:
-                self._state.pop(index, None)
+                self._state.pop(index, {})
 
             for i1, i2, direction in batch_update.moved:
                 if direction == MoveDirectionality.SWAP:
                     self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
                 else:
-                    self._state[i2] = self._state.pop(i1, {})
+                    self._state[i2] = self._state.pop(i1, {})
 
         # Update in_think and count for all active requests
         for state in self._state.values():
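The switch from `pop(..., None)` to `pop(..., {})` keeps every slot's value dict-typed, so the `if not state: continue` guard in `apply` (next hunk) treats missing and empty state the same way. A small illustration of the bookkeeping above, with hypothetical values not taken from the PR:

    state = {0: {"count": 3}, 1: {"count": 7}}
    state.pop(2, {})                         # removing an untracked index is a no-op
    state[0], state[1] = state[1], state[0]  # SWAP: exchange two batch slots
    state[1] = state.pop(0, {})              # one-way move: slot 0's state lands in slot 1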
@@ -579,7 +589,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
             if not state:
                 continue
 
-            if state["in_think"] and state["count"] >= state["max_think_tokens"]:
+            if state["in_think"] and state["count"] >= state[
+                    "max_think_tokens"]:
                 mask[index] = True
 
         if mask.any():
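This hunk only records which rows exhausted their thinking budget in `mask`; the actual logit edit sits outside the diff. One plausible realization, an assumption rather than code shown in the PR, forces the end-of-think token on every masked row:

    # Hypothetical sketch: disallow everything except think_end_token_id
    # on rows whose thinking budget is exhausted.
    if mask.any():
        logits[mask] = -float("inf")
        logits[mask, self.think_end_token_id] = 0.0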
@@ -589,8 +600,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
 
-def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
-                             device: torch.device, reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
+def init_builtin_logitsprocs(
+        pin_memory_available: bool, max_num_reqs: int, device: torch.device,
+        reasoning_config: ReasoningConfig) -> LogitsProcessorManager:
     """Construct 'builtin' vLLM logitsprocs which the engine
     loads by default.
 
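`ReasoningConfig` is referenced by the new signature but not defined in this diff. A minimal stand-in for reading the code, with field names inferred from the processor above rather than confirmed by the PR:

    from dataclasses import dataclass

    @dataclass
    class ReasoningConfig:
        # Token ids that open and close the model's thinking segment.
        think_start_token_id: int
        think_end_token_id: int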
@@ -619,8 +631,7 @@ def init_builtin_logitsprocs(pin_memory_available: bool, max_num_reqs: int,
     )
     return LogitsProcessorManager(
         non_argmax_invariant=[
-            min_tokens_logitproc,
-            logit_bias_logitproc,
+            min_tokens_logitproc, logit_bias_logitproc,
             max_think_tokens_logitproc
         ],
         argmax_invariant=[min_p_logitproc],
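For context, a hypothetical call site matching the new signature; every argument value here is illustrative, not taken from the PR:

    manager = init_builtin_logitsprocs(
        pin_memory_available=True,
        max_num_reqs=256,
        device=torch.device("cuda"),
        reasoning_config=ReasoningConfig(think_start_token_id=100,
                                         think_end_token_id=101),
    )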