@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from itertools import chain
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch._prims_common import DeviceLikeType
@@ -510,9 +510,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device:
         self.think_end_token_id = reasoning_config.think_end_token_id
         self.pin_memory = pin_memory
         self.device = device
-        self._state = {}
+        self._state: dict[int, dict[str, Any]] = {}

-    def _find_last_token_index(self, tokens, token_id):
+    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         try:
             return len(tokens) - tokens[::-1].index(token_id) - 1
         except ValueError:
@@ -524,71 +524,61 @@ def is_argmax_invariant(self) -> bool:
         return False

     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if batch_update is None:
-            return
-
-        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
-
-            if max_think_tokens is None:
-                continue
-
-            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
-
-            in_think = False
-            count = 0
+        if batch_update:
+            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                if max_think_tokens is not None:
+                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    in_think = last_start > last_end
+                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+                    self._state[index] = {
+                        "in_think": in_think,
+                        "count": count,
+                        "prompt_tok_ids": prompt_tok_ids,
+                        "output_tok_ids": output_tok_ids,
+                        "max_think_tokens": max_think_tokens,
+                    }

-            if last_think_start_idx > last_think_end_idx:
-                in_think = True
-                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+            for index in batch_update.removed:
+                self._state.pop(index, None)

-            self._state[index] = {
-                "in_think": in_think,
-                "count": count,
-                "prompt_tok_ids": prompt_tok_ids,
-                "output_tok_ids": output_tok_ids,
-                "max_think_tokens": max_think_tokens,
-            }
+            for i1, i2, direction in batch_update.moved:
+                if direction == MoveDirectionality.SWAP:
+                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                else:
+                    self._state[i2] = self._state.pop(i1, None)

-        for index in batch_update.removed:
-            self._state.pop(index, None)
+        # Update in_think and count for all active requests
+        for state in self._state.values():
+            output = state["output_tok_ids"]
+            if not output:
+                continue

-        for i1, i2, direction in batch_update.moved:
-            if direction == MoveDirectionality.SWAP:
-                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
-            else:
-                self._state[i2] = self._state.pop(i1, None)
+            last_tok = output[-1]
+            if last_tok == self.think_start_token_id:
+                state["in_think"] = True
+                state["count"] = 0
+            elif last_tok == self.think_end_token_id:
+                state["in_think"] = False
+                state["count"] = 0
+            elif state["in_think"]:
+                state["count"] += 1

     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         batch_size = logits.size(0)
-        if batch_size == 0:
+        if not self._state:
             return logits

         mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
         end_token_id = self.think_end_token_id

         for index in range(batch_size):
-            state = self._state.get(index, None)
-            if not state or not state.get("output_tok_ids"):
+            state = self._state.get(index)
+            if not state:
                 continue

-            last_tok = state["output_tok_ids"][-1]
-            in_think = state["in_think"]
-            count = state["count"]
-
-            if last_tok == self.think_start_token_id:
-                in_think = True
-                count = 0
-            elif last_tok == self.think_end_token_id:
-                in_think = False
-                count = 0
-            elif in_think:
-                count += 1
-
-            state["in_think"] = in_think
-            state["count"] = count
-
             if state["in_think"] and state["count"] >= state["max_think_tokens"]:
                 mask[index] = True

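The diff moves the per-token bookkeeping (`in_think`, `count`) from `apply` into `update_state`, so `apply` only has to flag the batch rows whose thinking budget is exhausted. The hunk cuts off right after `mask[index] = True`, so how the mask is consumed is not shown here. The snippet below is only a hedged sketch of one plausible continuation, assuming the intent is to force the think-end token on flagged rows; `force_think_end` is a hypothetical helper, not part of this PR.

```python
import torch

# Hypothetical sketch (not this PR's code): on rows flagged in `mask`,
# suppress every logit except the think-end token so the model must close
# the reasoning section on the next step.
def force_think_end(logits: torch.Tensor, mask: torch.Tensor, end_token_id: int) -> torch.Tensor:
    if mask.any():
        forced = torch.full_like(logits[mask], float("-inf"))
        forced[:, end_token_id] = 0.0  # only the think-end token remains sampleable
        logits[mask] = forced
    return logits
```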