@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from itertools import chain
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch._prims_common import DeviceLikeType
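The new Any import exists for the per-request state annotation added in the next hunk; a hedged sketch of that dict's shape (the field names come from this diff, the concrete values are invented for illustration):

    from typing import Any

    # Per-request bookkeeping, keyed by batch index in self._state.
    state: dict[int, dict[str, Any]] = {
        0: {
            "in_think": True,          # currently inside a think span
            "count": 3,                # thinking tokens emitted so far
            "prompt_tok_ids": [1, 2],  # shared references, not copies
            "output_tok_ids": [7, 8],
            "max_think_tokens": 128,   # hypothetical per-request budget
        }
    }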
@@ -502,9 +502,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device:
         self.think_end_token_id = reasoning_config.think_end_token_id
         self.pin_memory = pin_memory
         self.device = device
-        self._state = {}
+        self._state: dict[int, dict[str, Any]] = {}
 
-    def _find_last_token_index(self, tokens, token_id):
+    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         try:
             return len(tokens) - tokens[::-1].index(token_id) - 1
         except ValueError:
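A standalone illustration of the reverse-scan idiom used by _find_last_token_index above (the token ids are invented; the except branch's body falls outside this hunk and presumably returns a sentinel such as -1):

    tokens = [5, 9, 7, 9, 2]
    token_id = 9
    # list.index on the reversed list finds the last occurrence in the original.
    last = len(tokens) - tokens[::-1].index(token_id) - 1
    assert last == 3  # position of the final 9; .index raises ValueError when absent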
@@ -516,71 +516,61 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if batch_update is None:
-            return
-
-        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
-
-            if max_think_tokens is None:
-                continue
-
-            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
-
-            in_think = False
-            count = 0
+        if batch_update:
+            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                if max_think_tokens is not None:
+                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    in_think = last_start > last_end
+                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+                    self._state[index] = {
+                        "in_think": in_think,
+                        "count": count,
+                        "prompt_tok_ids": prompt_tok_ids,
+                        "output_tok_ids": output_tok_ids,
+                        "max_think_tokens": max_think_tokens,
+                    }
 
-            if last_think_start_idx > last_think_end_idx:
-                in_think = True
-                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+            for index in batch_update.removed:
+                self._state.pop(index, None)
 
-            self._state[index] = {
-                "in_think": in_think,
-                "count": count,
-                "prompt_tok_ids": prompt_tok_ids,
-                "output_tok_ids": output_tok_ids,
-                "max_think_tokens": max_think_tokens,
-            }
+            for i1, i2, direction in batch_update.moved:
+                if direction == MoveDirectionality.SWAP:
+                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                else:
+                    self._state[i2] = self._state.pop(i1, None)
 
-        for index in batch_update.removed:
-            self._state.pop(index, None)
+        # Update in_think and count for all active requests
+        for state in self._state.values():
+            output = state["output_tok_ids"]
+            if not output:
+                continue
 
-        for i1, i2, direction in batch_update.moved:
-            if direction == MoveDirectionality.SWAP:
-                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
-            else:
-                self._state[i2] = self._state.pop(i1, None)
+            last_tok = output[-1]
+            if last_tok == self.think_start_token_id:
+                state["in_think"] = True
+                state["count"] = 0
+            elif last_tok == self.think_end_token_id:
+                state["in_think"] = False
+                state["count"] = 0
+            elif state["in_think"]:
+                state["count"] += 1
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         batch_size = logits.size(0)
-        if batch_size == 0:
+        if not self._state:
             return logits
 
         mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
         end_token_id = self.think_end_token_id
 
         for index in range(batch_size):
-            state = self._state.get(index, None)
-            if not state or not state.get("output_tok_ids"):
+            state = self._state.get(index)
+            if not state:
                 continue
 
-            last_tok = state["output_tok_ids"][-1]
-            in_think = state["in_think"]
-            count = state["count"]
-
-            if last_tok == self.think_start_token_id:
-                in_think = True
-                count = 0
-            elif last_tok == self.think_end_token_id:
-                in_think = False
-                count = 0
-            elif in_think:
-                count += 1
-
-            state["in_think"] = in_think
-            state["count"] = count
-
             if state["in_think"] and state["count"] >= state["max_think_tokens"]:
                 mask[index] = True
 
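The tail of apply() falls outside this diff, so the following continuation is an assumption, not the PR's code: rows flagged in mask have every logit except the end-think token suppressed, which forces the model to close the think span once its budget is spent.

    # Hypothetical continuation of apply(); not shown in this diff.
    if mask.any():
        logits[mask] = -float("inf")      # suppress all tokens on masked rows...
        logits[mask, end_token_id] = 0.0  # ...except the forced end-think token
    return logits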