Skip to content

Commit 1bcd15e

Browse files
authored
[BugFix][P/D] Fix for cases where _recving_transfers can be cleaned up when *all* transfer done (#19874)
Signed-off-by: Linkun Chen <github@lkchen.net>
1 parent 2ebff5b commit 1bcd15e

File tree

2 files changed

+179
-4
lines changed

2 files changed

+179
-4
lines changed

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,23 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import time
5+
import uuid
6+
from collections import defaultdict
7+
from typing import Optional
8+
from unittest.mock import patch
9+
10+
import pytest
11+
12+
try:
13+
from nixl._api import nixl_agent as NixlWrapper
14+
except ImportError:
15+
NixlWrapper = None
16+
417
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
5-
NixlConnectorMetadata)
18+
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
19+
NixlConnectorWorker)
20+
from vllm.forward_context import ForwardContext
621

722
from .utils import create_request, create_scheduler, create_vllm_config
823

@@ -72,3 +87,160 @@ def test_prompt_less_than_block_size():
7287

7388
# This request should be scheduled regularly.
7489
assert len(scheduler_output.scheduled_new_reqs) == 1
90+
91+
92+
class FakeNixlWrapper:
93+
"""Mock implementation of NixlWrapper for testing.
94+
95+
We don't inherit from NixlWrapper because NixlWrapper could be None.
96+
"""
97+
98+
AGENT_METADATA = b"fake_agent_metadata"
99+
REMOTE_AGENT_NAME = "remote_agent"
100+
101+
def __init__(self, agent_name: str, *args, **kwargs):
102+
self._cycles_before_xfer_done = 0
103+
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
104+
lambda: 0)
105+
106+
def get_reg_descs(self, caches_data, memory_type: str) -> list:
107+
return [str(uuid.uuid4()) for _ in caches_data]
108+
109+
def register_memory(self, descs) -> None:
110+
pass
111+
112+
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
113+
return [str(uuid.uuid4()) for _ in blocks_data]
114+
115+
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
116+
return uuid.uuid4().int
117+
118+
def get_agent_metadata(self) -> bytes:
119+
return self.AGENT_METADATA
120+
121+
def add_remote_agent(self, agent_metadata: bytes) -> str:
122+
return self.REMOTE_AGENT_NAME
123+
124+
def get_new_notifs(self) -> dict[str, list[bytes]]:
125+
# Used to collect done_sending, which we don't test yet.
126+
return {}
127+
128+
def check_xfer_state(self, handle: int) -> str:
129+
if self._check_xfer_state_cycles[
130+
handle] >= self._cycles_before_xfer_done:
131+
return "DONE"
132+
self._check_xfer_state_cycles[handle] += 1
133+
return "PROC"
134+
135+
def release_xfer_handle(self, handle: int) -> None:
136+
pass
137+
138+
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
139+
pass
140+
141+
def make_prepped_xfer(self,
142+
xfer_type: str,
143+
local_xfer_side_handle: int,
144+
local_block_descs_ids: list[int],
145+
remote_xfer_side_handle: int,
146+
remote_block_descs_ids: list[int],
147+
notif_msg: Optional[bytes] = None) -> int:
148+
return uuid.uuid4().int
149+
150+
def transfer(self, handle: int) -> str:
151+
return "PROC"
152+
153+
############################################################
154+
# Follow are for changing the behavior during testing.
155+
############################################################
156+
157+
def set_cycles_before_xfer_done(self, cycles: int):
158+
"""Set the number of cycles before a transfer is considered done."""
159+
self._cycles_before_xfer_done = cycles
160+
161+
162+
class FakeNixlConnectorWorker(NixlConnectorWorker):
163+
164+
REMOTE_ENGINE_ID = "remote_engine"
165+
166+
def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
167+
super().__init__(*args, **kwargs)
168+
self._hand_shake_latency = hand_shake_latency
169+
170+
def _nixl_handshake(self, host: str, port: int):
171+
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
172+
time.sleep(self._hand_shake_latency)
173+
# These should've been done in register_kv_caches(), called by
174+
# gpu_model_runner. Here we just hardcode some dummy values.
175+
self.slot_size_bytes = 4096
176+
self.block_len = self.slot_size_bytes * self.block_size
177+
self.num_blocks = 1
178+
self.dst_num_blocks[self.engine_id] = self.num_blocks
179+
180+
self.add_remote_agent(
181+
NixlAgentMetadata(
182+
engine_id=self.REMOTE_ENGINE_ID,
183+
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
184+
kv_caches_base_addr=[0],
185+
num_blocks=1,
186+
tp_size=1,
187+
block_len=self.block_len,
188+
attn_backend_name=self.backend_name,
189+
))
190+
191+
192+
@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
193+
@patch(
194+
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
195+
FakeNixlWrapper)
196+
def test_multi_xfer_one_engine(
197+
# dist_init is a fixture that initializes the distributed environment.
198+
dist_init):
199+
"""Test case where multiple xfers are initiated to the same engine.
200+
201+
This test triggers the connector to load remote KV for the same
202+
`request_id`. The transfer is not done immediately due to
203+
`set_cycles_before_xfer_done`, so there is a state where there are multiple
204+
transfer states for the same `request_id`, and `get_finished` should handle
205+
it correctly (wait for all transfers to be done).
206+
"""
207+
vllm_config = create_vllm_config()
208+
209+
request_id = "req_id"
210+
211+
# Test worker role in decode server.
212+
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
213+
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
214+
connector.engine_id,
215+
hand_shake_latency=0)
216+
assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
217+
connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
218+
for i in range(4):
219+
metadata = NixlConnectorMetadata()
220+
metadata.add_new_req(request_id=request_id,
221+
local_block_ids=[i + 1, i + 2, i + 3],
222+
kv_transfer_params={
223+
"remote_block_ids": [i + 4, i + 5, i + 6],
224+
"remote_engine_id":
225+
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
226+
"remote_host": "localhost",
227+
"remote_port": 1234,
228+
})
229+
connector.bind_connector_metadata(metadata)
230+
231+
dummy_ctx = ForwardContext(
232+
no_compile_layers={},
233+
attn_metadata={},
234+
virtual_engine=0,
235+
)
236+
_before_load = time.perf_counter()
237+
connector.start_load_kv(dummy_ctx)
238+
_after_load = time.perf_counter()
239+
assert _after_load - _before_load < 0.1, "start_load_kv took " \
240+
f"{_after_load - _before_load} seconds"
241+
242+
while True:
243+
_, done_recving = connector.get_finished(finished_req_ids=set())
244+
if len(done_recving) > 0:
245+
assert request_id in done_recving
246+
break

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -841,17 +841,20 @@ def _pop_done_transfers(
841841
"""
842842
done_req_ids: set[str] = set()
843843
for req_id, handles in list(transfers.items()):
844-
for handle, xfer_stime in handles:
844+
in_progress = False
845+
for handle, _xfer_stime in handles:
845846
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
846847
if xfer_state == "DONE":
847848
self.nixl_wrapper.release_xfer_handle(handle)
848-
done_req_ids.add(req_id)
849-
del transfers[req_id]
850849
elif xfer_state == "PROC":
850+
in_progress = True
851851
continue
852852
else:
853853
raise RuntimeError("Transfer failed with state %s",
854854
xfer_state)
855+
if not in_progress:
856+
done_req_ids.add(req_id)
857+
del transfers[req_id]
855858
return done_req_ids
856859

857860
def start_load_kv(self, metadata: NixlConnectorMetadata):

0 commit comments

Comments
 (0)