Commit f2e195a

update states only in update_state method
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
1 parent 84aee5b commit f2e195a

1 file changed

vllm/v1/sample/logits_processor.py

Lines changed: 43 additions & 53 deletions
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from itertools import chain
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch._prims_common import DeviceLikeType
@@ -510,9 +510,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device:
         self.think_end_token_id = reasoning_config.think_end_token_id
         self.pin_memory = pin_memory
         self.device = device
-        self._state = {}
+        self._state: dict[int, dict[str, Any]] = {}

-    def _find_last_token_index(self, tokens, token_id):
+    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         try:
             return len(tokens) - tokens[::-1].index(token_id) - 1
         except ValueError:
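Read in isolation, the newly typed helper is easy to sanity-check. A minimal sketch follows; the hunk cuts off before the except body, so the -1 fallback (and the standalone name find_last_token_index) is an assumption for illustration, not necessarily what the file does:

    def find_last_token_index(tokens: list[int], token_id: int) -> int:
        try:
            # .index() on the reversed list finds the last occurrence;
            # map that hit back to a forward index.
            return len(tokens) - tokens[::-1].index(token_id) - 1
        except ValueError:
            return -1  # assumed fallback: token not present

    assert find_last_token_index([5, 9, 5, 9], 9) == 3  # last occurrence wins
    assert find_last_token_index([5, 9], 7) == -1       # absent token

A -1 fallback would also be consistent with the new in_think = last_start > last_end check in the next hunk: if neither marker occurs in the prompt, both indices are -1 and the request starts outside a think block.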
@@ -524,71 +524,61 @@ def is_argmax_invariant(self) -> bool:
         return False

     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if batch_update is None:
-            return
-
-        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
-
-            if max_think_tokens is None:
-                continue
-
-            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
-
-            in_think = False
-            count = 0
+        if batch_update:
+            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                if max_think_tokens is not None:
+                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    in_think = last_start > last_end
+                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+                    self._state[index] = {
+                        "in_think": in_think,
+                        "count": count,
+                        "prompt_tok_ids": prompt_tok_ids,
+                        "output_tok_ids": output_tok_ids,
+                        "max_think_tokens": max_think_tokens,
+                    }

-            if last_think_start_idx > last_think_end_idx:
-                in_think = True
-                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+            for index in batch_update.removed:
+                self._state.pop(index, None)

-            self._state[index] = {
-                "in_think": in_think,
-                "count": count,
-                "prompt_tok_ids": prompt_tok_ids,
-                "output_tok_ids": output_tok_ids,
-                "max_think_tokens": max_think_tokens,
-            }
+            for i1, i2, direction in batch_update.moved:
+                if direction == MoveDirectionality.SWAP:
+                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                else:
+                    self._state[i2] = self._state.pop(i1, None)

-        for index in batch_update.removed:
-            self._state.pop(index, None)
+        # Update in_think and count for all active requests
+        for state in self._state.values():
+            output = state["output_tok_ids"]
+            if not output:
+                continue

-        for i1, i2, direction in batch_update.moved:
-            if direction == MoveDirectionality.SWAP:
-                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
-            else:
-                self._state[i2] = self._state.pop(i1, None)
+            last_tok = output[-1]
+            if last_tok == self.think_start_token_id:
+                state["in_think"] = True
+                state["count"] = 0
+            elif last_tok == self.think_end_token_id:
+                state["in_think"] = False
+                state["count"] = 0
+            elif state["in_think"]:
+                state["count"] += 1

     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         batch_size = logits.size(0)
-        if batch_size == 0:
+        if not self._state:
             return logits

         mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
         end_token_id = self.think_end_token_id

         for index in range(batch_size):
-            state = self._state.get(index, None)
-            if not state or not state.get("output_tok_ids"):
+            state = self._state.get(index)
+            if not state:
                 continue

-            last_tok = state["output_tok_ids"][-1]
-            in_think = state["in_think"]
-            count = state["count"]
-
-            if last_tok == self.think_start_token_id:
-                in_think = True
-                count = 0
-            elif last_tok == self.think_end_token_id:
-                in_think = False
-                count = 0
-            elif in_think:
-                count += 1
-
-            state["in_think"] = in_think
-            state["count"] = count
-
             if state["in_think"] and state["count"] >= state["max_think_tokens"]:
                 mask[index] = True

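Net effect of the commit: update_state now owns all per-step bookkeeping (entering and leaving the think block, counting think tokens), and apply only reads that finished state to build the mask. A self-contained sketch of the state machine, using hypothetical names (ThinkBudget, step) and made-up marker ids rather than vLLM's actual API:

    THINK_START, THINK_END = 100, 101  # assumed ids for the think markers

    class ThinkBudget:
        def __init__(self, max_think_tokens: int):
            self.max_think_tokens = max_think_tokens
            self.in_think = False  # currently inside a think block?
            self.count = 0         # think tokens emitted in the current block

        def step(self, last_token: int) -> bool:
            """Advance one decode step; True means the think budget is
            exhausted and the end-think token should be forced."""
            if last_token == THINK_START:
                self.in_think, self.count = True, 0
            elif last_token == THINK_END:
                self.in_think, self.count = False, 0
            elif self.in_think:
                self.count += 1
            return self.in_think and self.count >= self.max_think_tokens

    budget = ThinkBudget(max_think_tokens=3)
    forced = [budget.step(t) for t in [THINK_START, 7, 8, 9]]
    assert forced == [False, False, False, True]  # budget hit on the third think token

In the diff itself, the rows this check flags are collected in mask; how mask and end_token_id are then used to pin the next token falls outside the shown hunk.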