Commit 71d1d75

[PD][Nixl] Remote consumer READ timeout for clearing request blocks (#20139)
Signed-off-by: NickLucche <nlucches@redhat.com>
1 parent 72d14d0 commit 71d1d75

File tree: 3 files changed, +115 -10 lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 74 additions & 4 deletions
@@ -9,10 +9,13 @@
 
 import pytest
 
+from vllm import LLM
+from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
     KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
     NixlConnectorWorker)
 from vllm.forward_context import ForwardContext
+from vllm.sampling_params import SamplingParams
 
 from .utils import create_request, create_scheduler, create_vllm_config
 
@@ -41,9 +44,9 @@ def test_basic_interface():
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
 
-    assert len(kv_connector_metadata.requests) == 1
-    assert request_id in kv_connector_metadata.requests
-    req_meta = kv_connector_metadata.requests[request_id]
+    assert len(kv_connector_metadata.reqs_to_recv) == 1
+    assert request_id in kv_connector_metadata.reqs_to_recv
+    req_meta = kv_connector_metadata.reqs_to_recv[request_id]
 
     for block_id, block in zip(
             req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
@@ -78,7 +81,7 @@ def test_prompt_less_than_block_size():
     kv_connector_metadata = scheduler_output.kv_connector_metadata
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.requests) == 0
+    assert len(kv_connector_metadata.reqs_to_recv) == 0
 
     # This request should be scheduled regularly.
     assert len(scheduler_output.scheduled_new_reqs) == 1
@@ -371,3 +374,70 @@ def test_concurrent_load_kv(
         if cnt_finished_reqs == total_reqs:
             return
     raise TimeoutError("Took too long to complete async handshake.")
+
+
+@patch(
+    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+    FakeNixlWrapper)
+def test_abort_timeout_on_prefiller(monkeypatch):
+    """
+    Test lifecycle of an aborted Remote Prefill request hitting the timeout.
+    -----> P
+            |  {process request}
+    <-\---  |  {result is NOT delivered, eg proxy is down}
+            |
+            |
+            |  {eventually free blocks}
+    """
+    model_name = "Qwen/Qwen3-0.6B"
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="NixlConnector",
+        kv_role="kv_both",
+    )
+    timeout = 6
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
+    monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
+    llm = LLM(
+        model=model_name,
+        enforce_eager=True,
+        gpu_memory_utilization=0.5,
+        kv_transfer_config=kv_transfer_config,
+    )
+    remote_prefill_opts = {
+        "do_remote_decode": True,
+        "do_remote_prefill": False,
+        "remote_engine_id": None,
+        "remote_block_ids": None,
+        "remote_host": None,
+        "remote_port": None,
+    }
+    # Simulate sidecar request
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=1,
+        extra_args={"kv_transfer_params": remote_prefill_opts})
+    scheduler = llm.llm_engine.engine_core.engine_core.scheduler
+    req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
+        0].req_to_blocks
+
+    padding = ("Just making this request a little longer so that we're sure "
+               "we're not hitting the small-request lower bound beneath which "
+               "we don't actually trigger the whole kv transfer, but rather "
+               "just recompute the blocks on D.")
+    _ = llm.generate([f"What is the capital of Japan? {padding}"],
+                     sampling_params)
+
+    # Request finished but not freed
+    assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks
+    # Some other request, 0 still not freed
+    _ = llm.generate([f"What is the capital of Italy? {padding}"],
+                     sampling_params)
+    assert '0' in req_to_blocks
+    assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks
+
+    # Wait for timeout and trigger another scheduler loop
+    time.sleep(timeout)
+    _ = llm.generate([f"What is the capital of France? {padding}"],
+                     sampling_params)
+    # Request-0 times out and is cleared!
+    assert '0' not in req_to_blocks
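
For reference, a minimal way to run just this new test locally (assuming a GPU-backed vLLM dev checkout with pytest installed; the node-id syntax below is standard pytest, not something this commit adds):

# Hypothetical local invocation of the new timeout test.
import sys

import pytest

TEST_NODE = ("tests/v1/kv_connector/unit/test_nixl_connector.py"
             "::test_abort_timeout_on_prefiller")

if __name__ == "__main__":
    sys.exit(pytest.main([TEST_NODE, "-v"]))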

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 32 additions & 5 deletions
@@ -79,15 +79,16 @@ class ReqMeta:
 class NixlConnectorMetadata(KVConnectorMetadata):
 
     def __init__(self):
-        self.requests: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
+        self.reqs_to_send: dict[ReqId, float] = {}
 
     def add_new_req(
         self,
         request_id: ReqId,
         local_block_ids: list[int],
         kv_transfer_params: dict[str, Any],
     ):
-        self.requests[request_id] = ReqMeta(
+        self.reqs_to_recv[request_id] = ReqMeta(
             local_block_ids=local_block_ids,
             remote_block_ids=kv_transfer_params["remote_block_ids"],
             remote_engine_id=kv_transfer_params["remote_engine_id"],
@@ -194,10 +195,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
-        # Requests that need to start recv.
+        # Requests that need to start recv/send.
         # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
+        # Reqs to send and their expiration time
+        self._reqs_need_send: dict[ReqId, float] = {}
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -284,6 +287,9 @@ def build_connector_meta(
         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
 
+        meta.reqs_to_send = self._reqs_need_send
+        self._reqs_need_send = {}
+
         return meta
 
     def request_finished(
@@ -325,6 +331,11 @@ def request_finished(
         # If prompt < block_size, no xfer so free blocks immediately.
         delay_free_blocks = len(computed_block_ids) > 0
 
+        if delay_free_blocks:
+            # Prefill request on remote. It will be read from D upon completion
+            self._reqs_need_send[request.request_id] = time.perf_counter(
+            ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
@@ -394,6 +405,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # In progress transfers.
         # [req_id -> list[handle]]
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
+        # Track the expiration time of requests that are waiting to be sent.
+        self._reqs_to_send: dict[ReqId, float] = {}
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -826,6 +839,16 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 "and %s requests done recving", self.tp_rank,
                 len(done_sending), len(done_recving))
 
+        # Handle timeout to avoid stranding blocks on remote.
+        now = time.perf_counter()
+        while self._reqs_to_send:
+            req_id, expires = next(iter(self._reqs_to_send.items()))
+            # Sorted dict, oldest requests are put first so we can exit early.
+            if now < expires:
+                break
+            del self._reqs_to_send[req_id]
+            done_sending.add(req_id)
+
         if self.world_size == 1:
             return done_sending, done_recving
 
@@ -857,7 +880,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
 
         all_done_sending: set[str] = set()
         for req_id in list(self._done_sending_count.keys()):
-            if self._done_sending_count[req_id] == self.world_size:
+            if self._done_sending_count[req_id] >= self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)
 
@@ -887,6 +910,7 @@ def _get_new_notifs(self) -> set[str]:
                     tp_ratio):
                 notified_req_ids.add(req_id)
                 del self.consumer_notification_counts_by_req[req_id]
+                del self._reqs_to_send[req_id]
         return notified_req_ids
 
     def _pop_done_transfers(
@@ -921,7 +945,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         Start loading by triggering non-blocking nixl_xfer.
         We check for these trnxs to complete in each step().
         """
-        for req_id, meta in metadata.requests.items():
+        for req_id, meta in metadata.reqs_to_recv.items():
             remote_engine_id = meta.remote_engine_id
             logger.debug(
                 "start_load_kv for request %s from remote engine %s. "
@@ -943,6 +967,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())
 
+        # Add to requests that are waiting to be read and track expiration.
+        self._reqs_to_send.update(metadata.reqs_to_send)
+
     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
             "Remote agent %s available, calling _read_blocks for req %s",

vllm/envs.py

Lines changed: 9 additions & 1 deletion
@@ -138,6 +138,7 @@
     VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
+    VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
 
 
 def get_default_cache_root():
@@ -953,7 +954,14 @@ def get_vllm_port() -> Optional[int]:
     # generations on machines < 100 for compressed-tensors
     # models
     "VLLM_USE_NVFP4_CT_EMULATIONS":
-    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0")))
+    lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))),
+
+    # Time (in seconds) after which the KV cache on the producer side is
+    # automatically cleared if no READ notification is received from the
+    # consumer. This is only applicable when using NixlConnector in a
+    # disaggregated decode-prefill setup.
+    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
 }
 
 # --8<-- [end:env-vars-definition]
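
As a usage sketch (the model name, timeout value, and single-engine setup below are illustrative assumptions mirroring the test added above), the new knob is read from the environment at engine startup, so it can be set before constructing an LLM that uses NixlConnector:

# Hypothetical deployment snippet: shorten the producer-side abort timeout.
import os

os.environ["VLLM_NIXL_ABORT_REQUEST_TIMEOUT"] = "30"  # seconds (default: 120)

from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen3-0.6B",  # placeholder model, same as in the test above
    enforce_eager=True,
    kv_transfer_config=KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    ),
)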
