Skip to content

Commit 1b15df2

Browse files
njhillNickLucche
andauthored
[BugFix] Fix handling of num_computed_tokens with connector (vllm-project#18232)
Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
1 parent 43b5f61 commit 1b15df2

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,17 @@ def get_num_new_matched_tokens(
209209
rounded_num_prompt_tokens = round_down(
210210
len(request.prompt_token_ids), self.block_size)
211211
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
212-
return count, count > 0
212+
if count > 0:
213+
return count, True
214+
215+
# NOTE: if count is 0 here, we have less than block_size
216+
# tokens to pull after subtracting the local prefix cache hit.
217+
# The remote only sends fully computed blocks, so there is
218+
# nothing to transfer but we still need to notify the
219+
# prefill worker so that the remote blocks are freed.
220+
if all(p in params for p in ("remote_engine_id", "remote_host",
221+
"remote_port")):
222+
self._reqs_need_recv[request.request_id] = (request, [])
213223

214224
# No remote prefill for this request.
215225
return 0, False
@@ -225,10 +235,6 @@ def update_state_after_alloc(self, request: "Request",
225235
num_external_tokens, params)
226236

227237
if params is not None and params.get("do_remote_prefill"):
228-
# NOTE(rob): if prompt < block_size, no remote blocks
229-
# since the remote only sends fully computed blocks, so
230-
# skip recving for this request. num_external_tokens
231-
# should be 0 if there are no remote blocks.
232238
if params.get("remote_block_ids"):
233239
if all(p in params for p in ("remote_engine_id", "remote_host",
234240
"remote_port")):

vllm/v1/core/sched/scheduler.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput:
345345
skipped_waiting_requests.appendleft(request)
346346
continue
347347

348+
num_external_computed_tokens = 0
349+
load_kv_async = False
350+
348351
# Get already-cached tokens.
349352
if num_prealloc_computed_tokens == 0:
350353
new_computed_blocks, num_native_computed_tokens = \
351354
self.kv_cache_manager.get_computed_blocks(
352355
request)
356+
357+
# Get externally-cached tokens if using a KVConnector.
358+
if self.connector is not None:
359+
num_external_computed_tokens, load_kv_async = (
360+
self.connector.get_num_new_matched_tokens(
361+
request, num_native_computed_tokens))
362+
363+
# Total computed tokens (local + external).
364+
num_computed_tokens = (num_native_computed_tokens +
365+
num_external_computed_tokens)
353366
else:
354367
# P/D: skip checking prefix cache if loaded from remote kvs.
355368
new_computed_blocks = KVCacheBlocks.create_empty()
356369
num_native_computed_tokens = 0
357370

358-
# Get externally-cached tokens if using a KVConnector.
359-
num_external_computed_tokens, load_kv_async = (
360-
(0, False) if self.connector is None else
361-
self.connector.get_num_new_matched_tokens(
362-
request, num_native_computed_tokens))
363-
364-
# Total computed tokens (local + external).
365-
num_computed_tokens = (num_native_computed_tokens +
366-
num_external_computed_tokens +
367-
num_prealloc_computed_tokens)
371+
# Total computed tokens (allocated in prior step).
372+
num_computed_tokens = num_prealloc_computed_tokens
368373

369374
encoder_inputs_to_schedule = None
370375
new_encoder_budget = encoder_budget
371376

372377
# P/D: loading remote KV, do not allocate for new work.
373378
if load_kv_async:
379+
assert num_external_computed_tokens > 0
374380
num_new_tokens = 0
375381
# Number of tokens to be scheduled.
376382
else:
@@ -411,7 +417,8 @@ def schedule(self) -> SchedulerOutput:
411417
# KVConnector: update internal state after allocation.
412418
# This information is used to determine if a load is
413419
# needed for this request.
414-
if self.connector is not None:
420+
if num_external_computed_tokens:
421+
assert self.connector is not None
415422
self.connector.update_state_after_alloc(
416423
request,
417424
new_computed_blocks + new_blocks,

0 commit comments

Comments
 (0)