@@ -518,10 +518,20 @@ def is_argmax_invariant(self) -> bool:
518
518
def update_state (self , batch_update : Optional [BatchUpdate ]):
519
519
if batch_update :
520
520
for index , params , prompt_tok_ids , output_tok_ids in batch_update .added :
521
- max_think_tokens = params .max_think_tokens if isinstance (params , SamplingParams ) else None
521
+ max_think_tokens = (
522
+ params .max_think_tokens
523
+ if isinstance (params , SamplingParams )
524
+ else None
525
+ )
522
526
if max_think_tokens is not None :
523
- last_start = self ._find_last_token_index (prompt_tok_ids , self .think_start_token_id )
524
- last_end = self ._find_last_token_index (prompt_tok_ids , self .think_end_token_id )
527
+ last_start = self ._find_last_token_index (
528
+ prompt_tok_ids ,
529
+ self .think_start_token_id
530
+ )
531
+ last_end = self ._find_last_token_index (
532
+ prompt_tok_ids ,
533
+ self .think_end_token_id
534
+ )
525
535
in_think = last_start > last_end
526
536
count = len (prompt_tok_ids ) - (last_start + 1 ) if in_think else 0
527
537
@@ -534,13 +544,13 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
534
544
}
535
545
536
546
for index in batch_update .removed :
537
- self ._state .pop (index , None )
547
+ self ._state .pop (index , {} )
538
548
539
549
for i1 , i2 , direction in batch_update .moved :
540
550
if direction == MoveDirectionality .SWAP :
541
551
self ._state [i1 ], self ._state [i2 ] = self ._state [i2 ], self ._state [i1 ]
542
552
else :
543
- self ._state [i2 ] = self ._state .pop (i1 , None )
553
+ self ._state [i2 ] = self ._state .pop (i1 , {} )
544
554
545
555
# Update in_think and count for all active requests
546
556
for state in self ._state .values ():
0 commit comments