
Commit 366cc0c

update states only in update_state method

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>

1 parent 9fd0b63
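
This change moves the per-token in_think / count bookkeeping out of apply() and into update_state(): apply() now only reads the stored state when building its mask, instead of mutating it on every call.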

File tree: 1 file changed, +43 / -53 lines


vllm/v1/sample/logits_processor.py

Lines changed: 43 additions & 53 deletions
@@ -6,7 +6,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from itertools import chain
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch._prims_common import DeviceLikeType
@@ -502,9 +502,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device:
         self.think_end_token_id = reasoning_config.think_end_token_id
         self.pin_memory = pin_memory
         self.device = device
-        self._state = {}
+        self._state: dict[int, dict[str, Any]] = {}
 
-    def _find_last_token_index(self, tokens, token_id):
+    def _find_last_token_index(self, tokens: list[int], token_id: int) -> int:
         try:
             return len(tokens) - tokens[::-1].index(token_id) - 1
         except ValueError:
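
For reference, the reversed-list trick in _find_last_token_index maps the first match in tokens[::-1] back to the last occurrence in the original ordering. A minimal standalone sketch of the same logic — the except ValueError branch is cut off by the hunk, so the -1 sentinel and the token ids below are assumptions, not part of this commit:

    def find_last_token_index(tokens: list[int], token_id: int) -> int:
        # The first hit in the reversed list is the last hit in the original.
        try:
            return len(tokens) - tokens[::-1].index(token_id) - 1
        except ValueError:
            return -1  # assumed sentinel: token_id never occurs

    # Hypothetical ids: 100 = <think>, 101 = </think>
    assert find_last_token_index([100, 1, 2, 101, 100, 3], 100) == 4
    assert find_last_token_index([100, 1, 2, 101, 100, 3], 101) == 3
    assert find_last_token_index([1, 2, 3], 100) == -1

With a -1 sentinel, the added line in_think = last_start > last_end naturally evaluates to False when neither marker appears in the prompt.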
@@ -516,71 +516,61 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
-        if batch_update is None:
-            return
-
-        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
-            max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
-
-            if max_think_tokens is None:
-                continue
-
-            last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
-            last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
-
-            in_think = False
-            count = 0
+        if batch_update:
+            for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
+                max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None
+                if max_think_tokens is not None:
+                    last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id)
+                    last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id)
+                    in_think = last_start > last_end
+                    count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0
+
+                    self._state[index] = {
+                        "in_think": in_think,
+                        "count": count,
+                        "prompt_tok_ids": prompt_tok_ids,
+                        "output_tok_ids": output_tok_ids,
+                        "max_think_tokens": max_think_tokens,
+                    }
 
-            if last_think_start_idx > last_think_end_idx:
-                in_think = True
-                count = len(prompt_tok_ids) - (last_think_start_idx + 1)
+            for index in batch_update.removed:
+                self._state.pop(index, None)
 
-            self._state[index] = {
-                "in_think": in_think,
-                "count": count,
-                "prompt_tok_ids": prompt_tok_ids,
-                "output_tok_ids": output_tok_ids,
-                "max_think_tokens": max_think_tokens,
-            }
+            for i1, i2, direction in batch_update.moved:
+                if direction == MoveDirectionality.SWAP:
+                    self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
+                else:
+                    self._state[i2] = self._state.pop(i1, None)
 
-        for index in batch_update.removed:
-            self._state.pop(index, None)
+        # Update in_think and count for all active requests
+        for state in self._state.values():
+            output = state["output_tok_ids"]
+            if not output:
+                continue
 
-        for i1, i2, direction in batch_update.moved:
-            if direction == MoveDirectionality.SWAP:
-                self._state[i1], self._state[i2] = self._state[i2], self._state[i1]
-            else:
-                self._state[i2] = self._state.pop(i1, None)
+            last_tok = output[-1]
+            if last_tok == self.think_start_token_id:
+                state["in_think"] = True
+                state["count"] = 0
+            elif last_tok == self.think_end_token_id:
+                state["in_think"] = False
+                state["count"] = 0
+            elif state["in_think"]:
+                state["count"] += 1
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
         batch_size = logits.size(0)
-        if batch_size == 0:
+        if not self._state:
             return logits
 
         mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device)
         end_token_id = self.think_end_token_id
 
         for index in range(batch_size):
-            state = self._state.get(index, None)
-            if not state or not state.get("output_tok_ids"):
+            state = self._state.get(index)
+            if not state:
                 continue
 
-            last_tok = state["output_tok_ids"][-1]
-            in_think = state["in_think"]
-            count = state["count"]
-
-            if last_tok == self.think_start_token_id:
-                in_think = True
-                count = 0
-            elif last_tok == self.think_end_token_id:
-                in_think = False
-                count = 0
-            elif in_think:
-                count += 1
-
-            state["in_think"] = in_think
-            state["count"] = count
-
             if state["in_think"] and state["count"] >= state["max_think_tokens"]:
                 mask[index] = True
 
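Taken together, the refactor gives the processor a clean two-phase lifecycle: update_state() owns every mutation of in_think and count (on add, remove, move, and after each generated token), while apply() only reads the stored state to decide which rows have exhausted their thinking budget. The sketch below reproduces that control flow outside vLLM for a single request; the class name, the token ids, and the final masking step are assumptions (the hunk ends at mask[index] = True, so how the mask is consumed is not shown in this commit):

    import torch

    THINK_START, THINK_END = 100, 101  # hypothetical token ids

    class MaxThinkTokensSketch:
        """Toy single-request version of the refactored control flow."""

        def __init__(self, max_think_tokens: int):
            self.max_think_tokens = max_think_tokens
            self.in_think = False
            self.count = 0

        def update_state(self, output_tok_ids: list[int]) -> None:
            # All state mutation happens here, mirroring the commit's intent.
            if not output_tok_ids:
                return
            last_tok = output_tok_ids[-1]
            if last_tok == THINK_START:
                self.in_think, self.count = True, 0
            elif last_tok == THINK_END:
                self.in_think, self.count = False, 0
            elif self.in_think:
                self.count += 1

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            # Read-only phase: no state is written here.
            if self.in_think and self.count >= self.max_think_tokens:
                # Assumed masking step: force </think> by suppressing all
                # other tokens; the actual commit only builds the bool mask.
                forced = torch.full_like(logits, float("-inf"))
                forced[THINK_END] = logits[THINK_END]
                return forced
            return logits

    # One decode step per token: update state first, then shape the logits.
    proc = MaxThinkTokensSketch(max_think_tokens=2)
    outputs: list[int] = []
    for tok in [THINK_START, 7, 8, 9]:  # <think> plus three reasoning tokens
        outputs.append(tok)
        proc.update_state(outputs)
    masked = proc.apply(torch.zeros(200))
    assert masked[THINK_END].item() == 0.0  # only </think> stays reachable
    assert torch.isinf(masked[7]).item()    # everything else is suppressed

With mutation confined to update_state(), repeated apply() calls within one step can no longer double-count think tokens, which appears to be the point of the commit.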