|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
3 | 3 |
|
| 4 | +import time |
| 5 | +import uuid |
| 6 | +from collections import defaultdict |
| 7 | +from typing import Optional |
| 8 | +from unittest.mock import patch |
| 9 | + |
| 10 | +import pytest |
| 11 | + |
| 12 | +try: |
| 13 | + from nixl._api import nixl_agent as NixlWrapper |
| 14 | +except ImportError: |
| 15 | + NixlWrapper = None |
| 16 | + |
4 | 17 | from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
5 |
| - NixlConnectorMetadata) |
| 18 | + KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, |
| 19 | + NixlConnectorWorker) |
| 20 | +from vllm.forward_context import ForwardContext |
6 | 21 |
|
7 | 22 | from .utils import create_request, create_scheduler, create_vllm_config
|
8 | 23 |
|
@@ -72,3 +87,160 @@ def test_prompt_less_than_block_size():
|
72 | 87 |
|
73 | 88 | # This request should be scheduled regularly.
|
74 | 89 | assert len(scheduler_output.scheduled_new_reqs) == 1
|
| 90 | + |
| 91 | + |
class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from NixlWrapper because NixlWrapper could be None.

    Transfer handles report "PROC" until they have been polled
    `_cycles_before_xfer_done` times via `check_xfer_state`, then "DONE".
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        # Number of check_xfer_state() polls before a handle reports "DONE".
        self._cycles_before_xfer_done = 0
        # Per-handle count of check_xfer_state() polls so far.
        # defaultdict(int) is the idiomatic spelling of a zero-default map.
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(int)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        # One opaque (unique) descriptor per cache entry.
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        # One opaque (unique) descriptor per block entry.
        return [str(uuid.uuid4()) for _ in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        # Opaque unique handle.
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        # Used to collect done_sending, which we don't test yet.
        return {}

    def check_xfer_state(self, handle: int) -> str:
        # Report "PROC" for the first `_cycles_before_xfer_done` polls of a
        # given handle, then "DONE" on every later poll.
        if self._check_xfer_state_cycles[
                handle] >= self._cycles_before_xfer_done:
            return "DONE"
        self._check_xfer_state_cycles[handle] += 1
        return "PROC"

    def release_xfer_handle(self, handle: int) -> None:
        pass

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        pass

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        # Opaque unique handle.
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        return "PROC"

    ############################################################
    # Following methods are for changing the behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles
| 161 | + |
class FakeNixlConnectorWorker(NixlConnectorWorker):
    """NixlConnectorWorker whose handshake is simulated in-process.

    The handshake sleeps for a configurable latency and then registers a
    hardcoded remote agent, bypassing the real zmq exchange.
    """

    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
        super().__init__(*args, **kwargs)
        self._hand_shake_latency = hand_shake_latency

    def _nixl_handshake(self, host: str, port: int):
        # Mimic slow _nixl_handshake, as well as bypass zmq communication.
        time.sleep(self._hand_shake_latency)

        # These should've been done in register_kv_caches(), called by
        # gpu_model_runner. Here we just hardcode some dummy values.
        slot_size = 4096
        self.slot_size_bytes = slot_size
        self.block_len = slot_size * self.block_size
        self.num_blocks = 1
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        remote_metadata = NixlAgentMetadata(
            engine_id=self.REMOTE_ENGINE_ID,
            agent_metadata=FakeNixlWrapper.AGENT_METADATA,
            kv_caches_base_addr=[0],
            num_blocks=1,
            tp_size=1,
            block_len=self.block_len,
            attn_backend_name=self.backend_name,
        )
        self.add_remote_agent(remote_metadata)
| 190 | + |
| 191 | + |
@pytest.mark.skipif(NixlWrapper is None, reason="nixl not installed")
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_multi_xfer_one_engine(
    # dist_init is a fixture that initializes the distributed environment.
    dist_init):
    """Test case where multiple xfers are initiated to the same engine.

    This test triggers the connector to load remote KV for the same
    `request_id`. The transfer is not done immediately due to
    `set_cycles_before_xfer_done`, so there is a state where there are multiple
    transfer states for the same `request_id`, and `get_finished` should handle
    it correctly (wait for all transfers to be done).
    """
    vllm_config = create_vllm_config()

    request_id = "req_id"

    # Test worker role in decode server.
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)
    assert isinstance(connector.connector_worker.nixl_wrapper, FakeNixlWrapper)
    # Each transfer handle needs 3 polls before it reports DONE, so several
    # in-flight transfers for the same request coexist.
    connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
    for i in range(4):
        metadata = NixlConnectorMetadata()
        metadata.add_new_req(request_id=request_id,
                             local_block_ids=[i + 1, i + 2, i + 3],
                             kv_transfer_params={
                                 "remote_block_ids": [i + 4, i + 5, i + 6],
                                 "remote_engine_id":
                                 FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                                 "remote_host": "localhost",
                                 "remote_port": 1234,
                             })
        connector.bind_connector_metadata(metadata)

        dummy_ctx = ForwardContext(
            no_compile_layers={},
            attn_metadata={},
            virtual_engine=0,
        )
        # start_load_kv must return quickly: the handshake and transfers are
        # meant to proceed asynchronously, not block the forward pass.
        _before_load = time.perf_counter()
        connector.start_load_kv(dummy_ctx)
        _after_load = time.perf_counter()
        assert _after_load - _before_load < 0.1, "start_load_kv took " \
            f"{_after_load - _before_load} seconds"

    # Poll until every transfer for the request is reported done. Bound the
    # polling with a deadline so a regression fails the test instead of
    # hanging the suite forever.
    deadline = time.monotonic() + 30
    while True:
        _, done_recving = connector.get_finished(finished_req_ids=set())
        if done_recving:
            assert request_id in done_recving
            break
        assert time.monotonic() < deadline, \
            "timed out waiting for KV transfers to finish"
0 commit comments