Skip to content

Commit 1db4f47

Browse files
authored
[BugFix] Fix multi async save in MultiConnector (vllm-project#18246)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent d3d91b6 commit 1db4f47

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import copy
3+
from dataclasses import dataclass
34
from typing import TYPE_CHECKING, Any, Optional
45

56
import torch
@@ -21,9 +22,10 @@
2122
logger = init_logger(__name__)
2223

2324

24-
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
25-
KVConnectorMetadata):
26-
pass
25+
@dataclass
26+
class MultiKVConnectorMetadata(KVConnectorMetadata):
27+
metadata: tuple[KVConnectorMetadata, ...]
28+
extra_async_saves: Optional[dict[str, int]] = None
2729

2830

2931
class MultiConnector(KVConnectorBase_V1):
@@ -54,6 +56,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
5456
# Keeps track of *additional* remaining async saves (beyond 1) to be
5557
# finished per request. Not needed for async loads since we only allow
5658
# a single connector to load.
59+
# Propagated from scheduler to worker side via the connector metadata.
5760
self._extra_async_saves: dict[str, int] = {}
5861

5962
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@@ -66,7 +69,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
6669
def bind_connector_metadata(
6770
self, connector_metadata: KVConnectorMetadata) -> None:
6871
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
69-
for c, cm in zip(self._connectors, connector_metadata):
72+
if connector_metadata.extra_async_saves:
73+
self._extra_async_saves.update(
74+
connector_metadata.extra_async_saves)
75+
for c, cm in zip(self._connectors, connector_metadata.metadata):
7076
c.bind_connector_metadata(cm)
7177

7278
def clear_connector_metadata(self) -> None:
@@ -152,8 +158,13 @@ def update_state_after_alloc(self, request: "Request",
152158
def build_connector_meta(
153159
self,
154160
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
155-
return MultiKVConnectorMetadata(
156-
c.build_connector_meta(scheduler_output) for c in self._connectors)
161+
metadata = MultiKVConnectorMetadata(metadata=tuple(
162+
c.build_connector_meta(scheduler_output)
163+
for c in self._connectors))
164+
if self._extra_async_saves:
165+
metadata.extra_async_saves = self._extra_async_saves
166+
self._extra_async_saves = {}
167+
return metadata
157168

158169
def request_finished(
159170
self,

0 commit comments

Comments
 (0)