1
1
# SPDX-License-Identifier: Apache-2.0
2
2
import copy
3
+ from dataclasses import dataclass
3
4
from typing import TYPE_CHECKING , Any , Optional
4
5
5
6
import torch
21
22
logger = init_logger (__name__ )
22
23
23
24
24
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    """Connector metadata bundling the metadata of several connectors.

    Holds one metadata object per wrapped connector plus an optional
    mapping of extra async-save counts that travels with the metadata.
    """

    # One entry per wrapped connector; consumers pair entries with
    # connectors by position, so the order must match the connector list.
    metadata: tuple[KVConnectorMetadata, ...]
    # Extra async-save counts to propagate, or None when there are none.
    # NOTE(review): keys are presumably request ids -- confirm with caller.
    extra_async_saves: Optional[dict[str, int]] = None
27
29
28
30
29
31
class MultiConnector (KVConnectorBase_V1 ):
@@ -54,6 +56,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
54
56
# Keeps track of *additional* remaining async saves (beyond 1) to be
55
57
# finished per request. Not needed for async loads since we only allow
56
58
# a single connector to load.
59
+ # Propagated from scheduler to worker side via the connector metadata.
57
60
self ._extra_async_saves : dict [str , int ] = {}
58
61
59
62
def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
@@ -66,7 +69,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
66
69
def bind_connector_metadata(
        self, connector_metadata: KVConnectorMetadata) -> None:
    """Bind combined metadata to this connector and its sub-connectors.

    Args:
        connector_metadata: Must be a ``MultiKVConnectorMetadata``; its
            ``metadata`` entries are matched to the wrapped connectors
            by position.
    """
    assert isinstance(connector_metadata, MultiKVConnectorMetadata)
    # Merge any extra async-save counts carried by the metadata into the
    # local bookkeeping dict (existing keys are overwritten by update()).
    if connector_metadata.extra_async_saves:
        self._extra_async_saves.update(
            connector_metadata.extra_async_saves)
    # Hand each wrapped connector its own metadata entry.
    for c, cm in zip(self._connectors, connector_metadata.metadata):
        c.bind_connector_metadata(cm)
71
77
72
78
def clear_connector_metadata (self ) -> None :
@@ -152,8 +158,13 @@ def update_state_after_alloc(self, request: "Request",
152
158
def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
    """Build the combined metadata for all wrapped connectors.

    Args:
        scheduler_output: Forwarded unchanged to every wrapped
            connector's ``build_connector_meta``.

    Returns:
        A ``MultiKVConnectorMetadata`` whose ``metadata`` tuple has one
        entry per wrapped connector, in connector order. Any pending
        extra async-save counts are attached to it.
    """
    metadata = MultiKVConnectorMetadata(metadata=tuple(
        c.build_connector_meta(scheduler_output)
        for c in self._connectors))
    if self._extra_async_saves:
        # Hand the pending counts over to the metadata and start a fresh
        # dict, so the same counts are not attached again on a later call.
        metadata.extra_async_saves = self._extra_async_saves
        self._extra_async_saves = {}
    return metadata
157
168
158
169
def request_finished (
159
170
self ,
0 commit comments