Skip to content

Commit efbd791

Browse files
committed
initial noodling
Signed-off-by: Will Eaton <weaton@redhat.com>
1 parent 29fa5ca commit efbd791

File tree

13 files changed

+342
-141
lines changed

13 files changed

+342
-141
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,4 @@ shellcheck*/
202202

203203
# Ignore moe/marlin_moe gen code
204204
csrc/moe/marlin_moe_wna16/kernel_*
205+
uv.lock

vllm/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import textwrap
1111
import uuid
1212
import warnings
13-
from collections import Counter
13+
from collections import Counter, defaultdict
1414
from contextlib import contextmanager
1515
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
1616
replace)
@@ -33,6 +33,8 @@
3333
import vllm.envs as envs
3434
from vllm import version
3535
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
36+
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
37+
KVConnectorHandshakeMetadata)
3638
from vllm.logger import init_logger
3739
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
3840
QuantizationMethods,
@@ -1511,6 +1513,12 @@ class CacheConfig:
15111513
num_cpu_blocks: Optional[int] = field(default=None, init=False)
15121514
"""The number of blocks to allocate for CPU memory."""
15131515

1516+
transfer_handshake_metadata: dict[int, dict[int,
1517+
KVConnectorHandshakeMetadata]] = field(
1518+
default_factory=lambda: defaultdict(dict),
1519+
init=False)
1520+
"""Metadata for the KV connector handshake."""
1521+
15141522
def compute_hash(self) -> str:
15151523
"""
15161524
WARNING: Whenever a new field is added to this config,

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@
3232

3333
import enum
3434
from abc import ABC, abstractmethod
35-
from typing import TYPE_CHECKING, Any, Optional
35+
from typing import TYPE_CHECKING, Any, Callable, Optional
3636

3737
import torch
38+
import msgspec
39+
from pydantic_core import core_schema
3840

3941
from vllm.logger import init_logger
4042
from vllm.v1.core.sched.output import SchedulerOutput
@@ -62,18 +64,58 @@ class KVConnectorMetadata:
6264
Abstract Metadata used to communicate between the
6365
Scheduler KVConnector and Worker KVConnector.
6466
"""
65-
pass
67+
68+
def __init__(self):
69+
pass
70+
6671

72+
class KVConnectorHandshakeMetadata(
73+
msgspec.Struct,
74+
omit_defaults=True, # type: ignore[call-arg]
75+
# required for @cached_property.
76+
dict=True):
77+
"""
78+
Metadata optionally used for out-of-band connector handshake between P/D workers.
79+
"""
80+
connector_type: str = "base"
81+
82+
@classmethod
83+
def __get_pydantic_core_schema__(
84+
cls,
85+
_source_type: Any,
86+
_handler: Callable[[Any], core_schema.CoreSchema]
87+
) -> core_schema.CoreSchema:
88+
"""bridge msgspec.Struct with pydantic for schema generation"""
89+
return core_schema.no_info_after_validator_function(
90+
cls,
91+
core_schema.dict_schema()
92+
)
93+
94+
class KVConnectorTransferMetadata(
95+
msgspec.Struct,
96+
omit_defaults=True, # type: ignore[call-arg]
97+
dict=True):
98+
"""
99+
Wrapper for transfer handshake metadata sent between engine and utils.
100+
"""
101+
tensor_parallel_rank: int
102+
data_parallel_rank: int
103+
content: Optional[dict]
104+
67105

68106
class KVConnectorBase_V1(ABC):
69107

70-
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
108+
def __init__(self,
109+
vllm_config: "VllmConfig",
110+
role: KVConnectorRole):
71111
logger.warning(
72112
"Initializing KVConnectorBase_V1. This API is experimental and "
73113
"subject to change in the future as we iterate the design.")
74114
self._connector_metadata = KVConnectorMetadata()
75115
self._vllm_config = vllm_config
76116
self._role = role
117+
self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None
118+
77119

78120
@property
79121
def role(self) -> KVConnectorRole:
@@ -104,7 +146,7 @@ def clear_connector_metadata(self) -> None:
104146
"""
105147
self._connector_metadata = KVConnectorMetadata()
106148

107-
def _get_connector_metadata(self) -> KVConnectorMetadata:
149+
def get_connector_metadata(self) -> KVConnectorMetadata:
108150
"""Get the connector metadata.
109151
110152
This function should only be called inside the connector.
@@ -201,6 +243,31 @@ def get_finished(
201243
"""
202244
return None, None
203245

246+
def set_handshake_metadata(
247+
self, handshake_metadata: KVConnectorHandshakeMetadata) -> None:
248+
"""
249+
Set the handshake metadata for the connector.
250+
251+
This metadata is used for out-of-band connector handshake
252+
between P/D workers.
253+
254+
Args:
255+
handshake_metadata (KVConnectorHandshakeMetadata): the handshake
256+
metadata.
257+
"""
258+
self._handshake_metadata = handshake_metadata
259+
260+
261+
def get_handshake_metadata(
262+
self) -> Optional[KVConnectorHandshakeMetadata]:
263+
"""
264+
Get the handshake metadata for the connector.
265+
266+
Returns:
267+
Optional[KVConnectorHandshakeMetadata]: the handshake metadata, or None if it has not been set.
268+
"""
269+
return self._handshake_metadata
270+
204271
# ==============================
205272
# Scheduler-side methods
206273
# ==============================

0 commit comments

Comments
 (0)