@@ -513,10 +513,10 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
513
513
"medium" : 2048 ,
514
514
"high" : 8192 ,
515
515
}
516
- self .think_start_token_ids = getattr (
517
- reasoning_config , "think_start_token_ids" , [])
518
- self .think_end_token_ids = getattr (
519
- reasoning_config , "think_end_token_ids" , [])
516
+ self .think_start_token_ids = getattr (reasoning_config ,
517
+ "think_start_token_ids" , [])
518
+ self .think_end_token_ids = getattr (reasoning_config ,
519
+ "think_end_token_ids" , [])
520
520
self .reasoning_effort_to_token_budget ['low' ] = getattr (
521
521
reasoning_config , "low_effort_token_budget" ,
522
522
self .reasoning_effort_to_token_budget ['low' ])
@@ -532,8 +532,8 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool,
532
532
self ._state : dict [int , dict [str , Any ]] = {}
533
533
534
534
@staticmethod
535
- def _find_last_sequence_index (
536
- target_list : list [ int ], token_ids : list [int ]) -> int :
535
+ def _find_last_sequence_index (target_list : list [ int ],
536
+ token_ids : list [int ]) -> int :
537
537
"""
538
538
Returns the index of the last occurrence of token_ids in target_list.
539
539
@@ -550,8 +550,8 @@ def _find_last_sequence_index(
550
550
return - 1
551
551
552
552
def _resolve_thinking_token_budget (
553
- self , reasoning_effort : Optional [str ],
554
- thinking_token_budget : Optional [int ]) -> int :
553
+ self , reasoning_effort : Optional [str ],
554
+ thinking_token_budget : Optional [int ]) -> int :
555
555
"""
556
556
Determines the final thinking token budget.
557
557
Priority:
@@ -562,30 +562,30 @@ def _resolve_thinking_token_budget(
562
562
return thinking_token_budget
563
563
564
564
if reasoning_effort is not None :
565
- budget = self .reasoning_effort_to_token_budget .get (reasoning_effort )
565
+ budget = self .reasoning_effort_to_token_budget .get (
566
+ reasoning_effort )
566
567
if budget is None :
567
568
raise ValueError (
568
569
f"Unknown reasoning_effort: { reasoning_effort } " )
569
570
return budget
570
571
571
572
return None
572
573
573
- def _init_state_entry (
574
- self , prompt_tok_ids : list [int ],
575
- thinking_token_budget : int ) -> dict [str , Any ]:
574
+ def _init_state_entry (self , prompt_tok_ids : list [int ],
575
+ thinking_token_budget : int ) -> dict [str , Any ]:
576
576
"""Initializes the tracking state for a given sequence index."""
577
- last_start = self ._find_last_sequence_index (
578
- prompt_tok_ids , self .think_start_token_ids )
579
- last_end = self ._find_last_sequence_index (
580
- prompt_tok_ids , self .think_end_token_ids )
577
+ last_start = self ._find_last_sequence_index (prompt_tok_ids ,
578
+ self .think_start_token_ids )
579
+ last_end = self ._find_last_sequence_index (prompt_tok_ids ,
580
+ self .think_end_token_ids )
581
581
in_think = last_start > last_end
582
582
think_count = len (prompt_tok_ids ) - (last_start + 1 ) if in_think else 0
583
583
584
584
return {
585
- "in_think" : in_think , # Currently in thinking mode
586
- "in_end" : False , # Currently forcing end tokens
587
- "think_count" : think_count , # Number of tokens in thinking section
588
- "end_count" : 0 , # Number of end tokens forced so far
585
+ "in_think" : in_think , # Currently in thinking mode
586
+ "in_end" : False , # Currently forcing end tokens
587
+ "think_count" : think_count , # Number of tokens in thinking section
588
+ "end_count" : 0 , # Number of end tokens forced so far
589
589
"prompt_tok_ids" : prompt_tok_ids ,
590
590
"output_tok_ids" : [],
591
591
"thinking_token_budget" : thinking_token_budget ,
@@ -635,8 +635,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
635
635
reasoning_effort = (params .reasoning_effort if isinstance (
636
636
params , SamplingParams ) else None )
637
637
thinking_token_budget = (params .thinking_token_budget
638
- if isinstance (
639
- params , SamplingParams ) else None )
638
+ if isinstance (params , SamplingParams )
639
+ else None )
640
640
resolved_thinking_token_budget = \
641
641
self ._resolve_thinking_token_budget (
642
642
reasoning_effort , thinking_token_budget )
@@ -664,8 +664,10 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
664
664
return logits
665
665
666
666
mask = torch .zeros (batch_size , dtype = torch .bool , device = logits .device )
667
- force_token_ids = torch .full ((batch_size ,), - 1 ,
668
- dtype = torch .long , device = logits .device )
667
+ force_token_ids = torch .full ((batch_size ,),
668
+ - 1 ,
669
+ dtype = torch .long ,
670
+ device = logits .device )
669
671
670
672
for i in range (batch_size ):
671
673
state = self ._state .get (i )
0 commit comments