
Commit 7c12a76

[Misc] Simplify the prefix caching logic on draft tokens (#20701)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent: cd587c9

File tree: 2 files changed, +10 -11 lines

vllm/v1/core/kv_cache_manager.py

Lines changed: 10 additions & 6 deletions
@@ -190,7 +190,6 @@ def allocate_slots(
         num_new_tokens: int,
         num_new_computed_tokens: int = 0,
         new_computed_blocks: Optional[KVCacheBlocks] = None,
-        num_draft_tokens: int = 0,
         num_lookahead_tokens: int = 0,
         delay_cache_blocks: bool = False,
     ) -> Optional[KVCacheBlocks]:
@@ -286,12 +285,17 @@ def allocate_slots(
         if not self.enable_caching or delay_cache_blocks:
             return KVCacheBlocks(new_blocks)

-        # Speculated tokens might be rejected in the future, so we does
-        # not cache any speculated tokens. We only cache blocks with
-        # generated (accepted) tokens.
+        # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
+        # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
+        # draft tokens that could be rejected). Therefore, we cap the number
+        # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
+        num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
+                                  request.num_tokens)
         self.coordinator.cache_blocks(
-            request, self.req_to_block_hashes[request.request_id],
-            num_computed_tokens + num_new_tokens - num_draft_tokens)
+            request,
+            self.req_to_block_hashes[request.request_id],
+            num_tokens_to_cache,
+        )

         return KVCacheBlocks(new_blocks)
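
To make the new rule concrete, here is a small standalone sketch (not vLLM code; the token counts are invented) of the capping computation introduced above. `request.num_tokens` counts only finalized tokens (prompt plus accepted output), so capping `num_computed_tokens + num_new_tokens` at it keeps speculative draft tokens out of the prefix cache.

# Illustrative sketch only; names mirror the diff, numbers are made up.
def tokens_to_cache(num_computed_tokens: int, num_new_tokens: int,
                    num_finalized_tokens: int) -> int:
    # Draft tokens sit beyond the request's finalized token count and may be
    # rejected later, so the cached length is capped at that count.
    return min(num_computed_tokens + num_new_tokens, num_finalized_tokens)

# 97 tokens already cached and 6 new slots scheduled, but only 100 tokens are
# finalized: the draft tokens past that count are never cached.
print(tokens_to_cache(97, 6, 100))  # 100
# No speculation in the batch: everything scheduled is finalized.
print(tokens_to_cache(90, 8, 100))  # 98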

vllm/v1/core/sched/scheduler.py

Lines changed: 0 additions & 5 deletions
@@ -241,15 +241,10 @@ def schedule(self) -> SchedulerOutput:
                 req_index += 1
                 continue

-            num_draft_tokens = max(
-                num_new_tokens + request.num_computed_tokens -
-                request.num_tokens, 0)
-
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens,
-                    num_draft_tokens=num_draft_tokens,
                     num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
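
The scheduler-side `num_draft_tokens` bookkeeping removed above is redundant with the new cap in the KV cache manager: subtracting `max(num_computed_tokens + num_new_tokens - request.num_tokens, 0)` from the total is the same as taking `min(num_computed_tokens + num_new_tokens, request.num_tokens)`. A quick self-contained check of that equivalence (illustrative values only, not vLLM code):

# Old rule: subtract the draft-token count the scheduler used to compute.
def old_cached_len(num_computed: int, num_new: int, num_finalized: int) -> int:
    num_draft = max(num_new + num_computed - num_finalized, 0)
    return num_computed + num_new - num_draft

# New rule: cap at the finalized token count inside allocate_slots().
def new_cached_len(num_computed: int, num_new: int, num_finalized: int) -> int:
    return min(num_computed + num_new, num_finalized)

# The rules agree with and without draft tokens in the scheduled step.
for computed, new, finalized in [(97, 6, 100), (90, 8, 100), (0, 100, 100)]:
    assert old_cached_len(computed, new, finalized) == \
           new_cached_len(computed, new, finalized)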
