Skip to content

Commit 59dd311

Browse files
authored
[KVConnector] Keep KVTransferParams as a dict (#18033)
1 parent d066e52 commit 59dd311

File tree

7 files changed

+64
-157
lines changed

7 files changed

+64
-157
lines changed

tests/v1/kv_connector/unit/utils.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
from typing import Optional
2+
from typing import Any, Optional
33

44
import torch
55

66
from vllm import SamplingParams
77
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
88
ModelConfig, SchedulerConfig, VllmConfig)
9-
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
10-
NixlKVTransferParams)
119
from vllm.v1.core.sched.scheduler import Scheduler
1210
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
1311
KVCacheGroupSpec)
@@ -124,20 +122,20 @@ def create_request(
124122
) -> Request:
125123
"""Make dummy request for testing."""
126124

125+
kv_transfer_params: Optional[dict[str, Any]] = None
126+
127127
if do_remote_decode:
128128
assert not do_remote_prefill
129-
kv_transfer_params = NixlKVTransferParams(do_remote_prefill=False,
130-
do_remote_decode=True)
129+
kv_transfer_params = dict(do_remote_prefill=False,
130+
do_remote_decode=True)
131131
elif do_remote_prefill:
132-
kv_transfer_params = NixlKVTransferParams(
133-
do_remote_prefill=True,
134-
do_remote_decode=False,
135-
remote_engine_id="my-engine-id",
136-
remote_block_ids=list(range(num_remote_blocks)),
137-
remote_host="my-host",
138-
remote_port=1234)
139-
else:
140-
kv_transfer_params = None
132+
kv_transfer_params = dict(do_remote_prefill=True,
133+
do_remote_decode=False,
134+
remote_engine_id="my-engine-id",
135+
remote_block_ids=list(
136+
range(num_remote_blocks)),
137+
remote_host="my-host",
138+
remote_port=1234)
141139

142140
max_tokens = 1 if do_remote_decode else max_tokens
143141
sampling_params = SamplingParams(max_tokens=max_tokens)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
3-
KVConnectorBase_V1, KVConnectorRole, KVTransferParams)
3+
KVConnectorBase_V1, KVConnectorRole)
44

5-
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "KVTransferParams"]
5+
__all__ = ["KVConnectorRole", "KVConnectorBase_V1"]

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,6 @@ class KVConnectorRole(enum.Enum):
4848
WORKER = 1
4949

5050

51-
class KVTransferParams:
52-
"""
53-
Abstract KVTransferParams used to send KVTransfer
54-
parameters between instances of vLLM.
55-
56-
Specific instances of KVConnector customize this
57-
method for serializing / deserializing msgs sent
58-
via the HTTP protocol.
59-
"""
60-
61-
@staticmethod
62-
def from_raw_dict(
63-
raw_dict: Optional[dict[str,
64-
Any]]) -> Optional["KVTransferParams"]:
65-
return None
66-
67-
6851
@dataclass
6952
class KVConnectorMetadata:
7053
"""
@@ -75,7 +58,6 @@ class KVConnectorMetadata:
7558

7659

7760
class KVConnectorBase_V1(ABC):
78-
_KVTransferParams = KVTransferParams
7961

8062
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
8163
logger.warning(
@@ -213,13 +195,6 @@ def get_finished(
213195
# Scheduler-side methods
214196
# ==============================
215197

216-
def set_kv_transfer_params(self, request: "Request"):
217-
"""Parse raw KV Transfer params."""
218-
assert request.kv_transfer_params is None
219-
kv_transfer_params = self._KVTransferParams.from_raw_dict(
220-
request.raw_kv_transfer_params)
221-
request.kv_transfer_params = kv_transfer_params
222-
223198
@abstractmethod
224199
def get_num_new_matched_tokens(
225200
self,

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 34 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from vllm import envs
1717
from vllm.config import VllmConfig
1818
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
19-
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferParams)
19+
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
2020
from vllm.distributed.parallel_state import (
2121
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
2222
get_tp_group)
@@ -44,56 +44,6 @@
4444
NixlWrapper = None
4545

4646

47-
@dataclass
48-
class NixlKVTransferParams(KVTransferParams):
49-
50-
def __init__(
51-
self,
52-
do_remote_prefill: bool,
53-
do_remote_decode: bool,
54-
remote_block_ids: Optional[list[int]] = None,
55-
remote_host: Optional[str] = None,
56-
remote_port: Optional[int] = None,
57-
remote_engine_id: Optional[str] = None,
58-
):
59-
self.do_remote_prefill = do_remote_prefill
60-
self.do_remote_decode = do_remote_decode
61-
self.remote_block_ids = remote_block_ids
62-
self.remote_host = remote_host
63-
self.remote_port = remote_port
64-
self.remote_engine_id = remote_engine_id
65-
66-
@staticmethod
67-
def from_raw_dict(
68-
raw_dict: Optional[dict[str,
69-
Any]]) -> Optional["NixlKVTransferParams"]:
70-
71-
# If no raw transfer params passed, return None.
72-
if raw_dict is None:
73-
return None
74-
75-
# Validate the request is formatted properly.
76-
if (("do_remote_prefill" not in raw_dict)
77-
or ("do_remote_decode" not in raw_dict)
78-
or ("remote_block_ids" not in raw_dict)
79-
or ("remote_host" not in raw_dict)
80-
or ("remote_port" not in raw_dict)
81-
or ("remote_engine_id" not in raw_dict)):
82-
logger.warning(
83-
"Got invalid KVTransferParams: %s. This "
84-
"request will not utilize KVTransfer", raw_dict)
85-
return None
86-
87-
return NixlKVTransferParams(
88-
do_remote_prefill=raw_dict["do_remote_prefill"],
89-
do_remote_decode=raw_dict["do_remote_decode"],
90-
remote_block_ids=raw_dict["remote_block_ids"],
91-
remote_host=raw_dict["remote_host"],
92-
remote_port=raw_dict["remote_port"],
93-
remote_engine_id=raw_dict["remote_engine_id"],
94-
)
95-
96-
9747
class NixlAgentMetadata(
9848
msgspec.Struct,
9949
omit_defaults=True, # type: ignore[call-arg]
@@ -123,25 +73,18 @@ def add_new_req(
12373
self,
12474
request_id: str,
12575
local_block_ids: list[int],
126-
kv_transfer_params: NixlKVTransferParams,
76+
kv_transfer_params: dict[str, Any],
12777
):
128-
assert request_id not in self.requests
129-
assert kv_transfer_params.remote_block_ids is not None
130-
assert kv_transfer_params.remote_engine_id is not None
131-
assert kv_transfer_params.remote_host is not None
132-
assert kv_transfer_params.remote_port is not None
133-
13478
self.requests[request_id] = ReqMeta(
13579
local_block_ids=local_block_ids,
136-
remote_block_ids=kv_transfer_params.remote_block_ids,
137-
remote_engine_id=kv_transfer_params.remote_engine_id,
138-
remote_host=kv_transfer_params.remote_host,
139-
remote_port=kv_transfer_params.remote_port,
80+
remote_block_ids=kv_transfer_params["remote_block_ids"],
81+
remote_engine_id=kv_transfer_params["remote_engine_id"],
82+
remote_host=kv_transfer_params["remote_host"],
83+
remote_port=kv_transfer_params["remote_port"],
14084
)
14185

14286

14387
class NixlConnector(KVConnectorBase_V1):
144-
_KVTransferParams: type[NixlKVTransferParams] = NixlKVTransferParams
14588

14689
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
14790
assert vllm_config.kv_transfer_config is not None
@@ -253,52 +196,52 @@ def get_num_new_matched_tokens(
253196
asynchronously (between scheduler steps).
254197
"""
255198

199+
params = request.kv_transfer_params
256200
logger.debug(
257201
"NIXLConnector get_num_new_matched_tokens: "
258202
"num_computed_tokens=%s, kv_transfer_params=%s",
259-
num_computed_tokens, request.kv_transfer_params)
260-
261-
# No KVTransfer for this request.
262-
if request.kv_transfer_params is None:
263-
return 0, False
264-
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
203+
num_computed_tokens, params)
265204

266-
# Remote prefill: get all prompt blocks from remote.
267-
if request.kv_transfer_params.do_remote_prefill:
205+
if params is not None and params.get("do_remote_prefill"):
206+
# Remote prefill: get all prompt blocks from remote.
268207
assert num_computed_tokens % self.block_size == 0
269208
rounded_num_prompt_tokens = round_down(
270209
len(request.prompt_token_ids), self.block_size)
271210
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
272211
return count, count > 0
273212

213+
# No remote prefill for this request.
274214
return 0, False
275215

276216
def update_state_after_alloc(self, request: "Request",
277217
blocks: "KVCacheBlocks",
278218
num_external_tokens: int):
279219

220+
params = request.kv_transfer_params
280221
logger.debug(
281222
"NIXLConnector update_state_after_alloc: "
282223
"num_external_tokens=%s, kv_transfer_params=%s",
283-
num_external_tokens, request.kv_transfer_params)
224+
num_external_tokens, params)
284225

285-
if request.kv_transfer_params is None:
286-
return
287-
288-
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
289-
if request.kv_transfer_params.do_remote_prefill:
226+
if params is not None and params.get("do_remote_prefill"):
290227
# NOTE(rob): if prompt < block_size, no remote blocks
291228
# since the remote only sends fully computed blocks, so
292229
# skip recving for this request. num_external_tokens
293230
# should be 0 if there are no remote blocks.
294-
if request.kv_transfer_params.remote_block_ids:
295-
# Get unhashed blocks to pull from remote.
296-
self._reqs_need_recv[request.request_id] = (
297-
request, blocks.get_unhashed_block_ids())
231+
if params.get("remote_block_ids"):
232+
if all(p in params for p in ("remote_engine_id", "remote_host",
233+
"remote_port")):
234+
# Get unhashed blocks to pull from remote.
235+
self._reqs_need_recv[request.request_id] = (
236+
request, blocks.get_unhashed_block_ids())
237+
else:
238+
logger.warning(
239+
"Got invalid KVTransferParams: %s. This "
240+
"request will not utilize KVTransfer", params)
298241
else:
299242
assert num_external_tokens == 0
300243
# Only trigger 1 KV transfer per request.
301-
request.kv_transfer_params.do_remote_prefill = False
244+
params["do_remote_prefill"] = False
302245

303246
def build_connector_meta(
304247
self,
@@ -308,7 +251,7 @@ def build_connector_meta(
308251

309252
# Loop through scheduled reqs and convert to ReqMeta.
310253
for req_id, (req, block_ids) in self._reqs_need_recv.items():
311-
assert isinstance(req.kv_transfer_params, NixlKVTransferParams)
254+
assert req.kv_transfer_params is not None
312255
meta.add_new_req(
313256
request_id=req_id,
314257
local_block_ids=block_ids,
@@ -330,34 +273,30 @@ def request_finished(
330273
should be freed now or will be sent asynchronously and freed later.
331274
"""
332275

276+
params = request.kv_transfer_params
333277
logger.debug(
334-
"NIXLConnector request_finished, "
335-
"request_status=%s, kv_transfer_params=%s", request.status,
336-
request.kv_transfer_params)
337-
338-
if request.kv_transfer_params is None:
339-
return False, None
340-
assert isinstance(request.kv_transfer_params, NixlKVTransferParams)
278+
"NIXLConnector request_finished, request_status=%s, "
279+
"kv_transfer_params=%s", request.status, params)
341280

342-
if ((not request.kv_transfer_params.do_remote_decode)
343-
or (request.status != RequestStatus.FINISHED_LENGTH_CAPPED)):
281+
if (params is None or not params.get("do_remote_decode")
282+
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
344283
return False, None
345284

346285
# Get computed blocks.
347286
all_full = request.num_computed_tokens % self.block_size == 0
348-
computed_block_ids = (block_ids if all_full else block_ids[:-1])
287+
computed_block_ids = block_ids if all_full else block_ids[:-1]
349288

350289
# If prompt < block_size, no xfer so free blocks immediately.
351290
delay_free_blocks = len(computed_block_ids) > 0
352291

353-
return delay_free_blocks, NixlKVTransferParams(
292+
return delay_free_blocks, dict(
354293
do_remote_prefill=True,
355294
do_remote_decode=False,
356295
remote_block_ids=computed_block_ids,
357296
remote_engine_id=self.engine_id,
358297
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
359298
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
360-
).__dict__
299+
)
361300

362301

363302
class NixlConnectorWorker:

vllm/v1/core/sched/scheduler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from vllm.distributed.kv_transfer.kv_connector.factory import (
1313
KVConnectorFactory)
1414
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
15-
KVConnectorRole,
16-
KVTransferParams)
15+
KVConnectorRole)
1716
from vllm.logger import init_logger
1817
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
1918
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -931,8 +930,13 @@ def get_kv_connector(self) -> Optional[KVConnectorBase_V1]:
931930
return self.connector
932931

933932
def _connector_finished(
934-
self, request: Request) -> tuple[bool, Optional[KVTransferParams]]:
935-
"""Invoke the KV connector request_finished() method if applicable."""
933+
self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]:
934+
"""
935+
Invoke the KV connector request_finished() method if applicable.
936+
937+
Returns optional kv transfer parameters to be included with the
938+
request outputs.
939+
"""
936940
if self.connector is None:
937941
return False, None
938942
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)

vllm/v1/engine/core.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,10 @@ def add_request(self, request: EngineCoreRequest):
182182
# Start grammar compilation asynchronously
183183
self.structured_output_manager.grammar_init(req)
184184

185-
if req.raw_kv_transfer_params is not None:
186-
if (kv_connector := self.scheduler.get_kv_connector()):
187-
# Parse raw KV transfer params via connector.
188-
kv_connector.set_kv_transfer_params(req)
189-
else:
190-
logger.warning(
191-
"Got KVTransferParams, but no KVConnector found. "
192-
"Disabling KVTransfer for this request.")
185+
if req.kv_transfer_params is not None and (
186+
not self.scheduler.get_kv_connector()):
187+
logger.warning("Got kv_transfer_params, but no KVConnector found. "
188+
"Disabling KVTransfer for this request.")
193189

194190
self.scheduler.add_request(req)
195191

vllm/v1/request.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import enum
44
from typing import TYPE_CHECKING, Any, Optional, Union
55

6-
from vllm.distributed.kv_transfer.kv_connector.v1 import KVTransferParams
76
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
87
from vllm.sampling_params import SamplingParams
98
from vllm.utils import is_list_of
@@ -62,14 +61,10 @@ def __init__(
6261
self.num_encoder_inputs = len(self.mm_inputs)
6362
self.has_encoder_inputs = self.num_encoder_inputs > 0
6463

65-
# P/D: KV transfer parameters (raw and parsed).
66-
raw_params = (None if sampling_params.extra_args is None
67-
else sampling_params.extra_args.get(
68-
"kv_transfer_params", None))
69-
self.raw_kv_transfer_params: Optional[dict[str, Any]] = raw_params
70-
# Each connector parses the raw dictionary and sets this
71-
# attr the first time that the request is processed.
72-
self.kv_transfer_params: Optional[KVTransferParams] = None
64+
# P/D: Connector-specific KV transfer parameters.
65+
kv_params = (None if sampling_params.extra_args is None else
66+
sampling_params.extra_args.get("kv_transfer_params"))
67+
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
7368

7469
# Sanity check
7570
assert len(self.mm_inputs) == len(self.mm_positions)

0 commit comments

Comments
 (0)