Skip to content

Draft: WIP NixlConnector allow configurable handshake backend +HTTP #19447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
efbd791
initial noodling
wseaton Jun 2, 2025
9dd2c1c
attempt to background agent registration
wseaton Jun 10, 2025
a0c03cd
more simple immediate retry
wseaton Jun 11, 2025
cf616b5
fix bad merge
wseaton Jun 11, 2025
505f586
implement nicks suggestions
wseaton Jun 16, 2025
a1a0918
change retry logic and move to scheduler; simply background handshake…
wseaton Jun 16, 2025
518a59f
fixup by changing types at callsite
wseaton Jun 16, 2025
6426e3f
fixup popping requests
wseaton Jun 16, 2025
29885af
debug logging
wseaton Jun 16, 2025
504bba7
working checkpoint
wseaton Jun 17, 2025
dd49c96
fix unreachable
wseaton Jun 17, 2025
f65c1b3
remove uv.lock
wseaton Jun 17, 2025
efd655b
remove unused
wseaton Jun 17, 2025
dba3835
flip protocol; fix scheduling order bug
wseaton Jun 18, 2025
85855d1
fix bug in case of no kvconnectorgroup
wseaton Jun 18, 2025
1fc1af4
actually use handshake timeout; simplify route
wseaton Jun 18, 2025
9d8c15c
revert back to working het TP logic
wseaton Jun 19, 2025
fbf0630
Merge branch 'main' into kv-xfer-updates
wseaton Jun 26, 2025
69af91c
Resolve merge conflicts and remove deprecated pending handshake logic
wseaton Jun 26, 2025
91af1cd
Merge remote-tracking branch 'origin/main' into kv-xfer-updates
wseaton Jun 27, 2025
83ec83a
move nixl sidechannel to own entrypoint
wseaton Jun 27, 2025
9a86b37
allow configurable handshake strategy
wseaton Jun 30, 2025
cc887f5
pre-commit fixes
wseaton Jun 30, 2025
b960355
satisfy mypy
wseaton Jun 30, 2025
e0e88b2
add comments; retrigger fastcheck
wseaton Jun 30, 2025
34e3e55
add type hint in test
wseaton Jun 30, 2025
2325470
revert api change
wseaton Jun 30, 2025
430ad03
fix usage
wseaton Jun 30, 2025
4f30966
remove dead code
wseaton Jun 30, 2025
54e088e
fix type declaration
wseaton Jun 30, 2025
ead79d6
revert spurious changes to multi-connector
wseaton Jun 30, 2025
1a6ac2c
remove dead code for threading
wseaton Jul 1, 2025
f940b6d
remove dead engine code
wseaton Jul 1, 2025
7768f21
fix optional ordering
wseaton Jul 1, 2025
3568422
naming, fix engine default param issue; pr feedback
wseaton Jul 9, 2025
332723a
use built in util
wseaton Jul 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,4 @@ shellcheck*/

# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
uv.lock
11 changes: 11 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
Expand Down Expand Up @@ -1511,6 +1513,10 @@ class CacheConfig:
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""

transfer_handshake_metadata: Optional[dict[int, dict[
int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False)
"""Metadata for the KV connector handshake."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down Expand Up @@ -4504,6 +4510,11 @@ def __post_init__(self):
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

if (self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance):
from collections import defaultdict
self.cache_config.transfer_handshake_metadata = defaultdict(dict)

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
Expand Down
156 changes: 122 additions & 34 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
Expand All @@ -8,15 +7,9 @@
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
that exist in the remote KV cache
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.

Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
Expand All @@ -25,16 +18,16 @@

save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done

get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

import msgspec
import torch
from pydantic_core import core_schema

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
Expand All @@ -49,6 +42,64 @@
logger = init_logger(__name__)


@dataclass
class KVTransferFinishedResult:
    """Outcome of a KV-transfer ``get_finished`` call.

    Bundles the request-id sets for completed sends, completed receives,
    and requests still blocked on an out-of-band handshake.
    """

    finished_sending: set[str]
    finished_recving: set[str]
    pending_handshake: set[str]

    def has_any_finished(self) -> bool:
        """Return True if any set (finished or pending) is non-empty."""
        return bool(self.finished_sending) or bool(
            self.finished_recving) or bool(self.pending_handshake)

    def is_empty(self) -> bool:
        """Return True when nothing finished and nothing is pending."""
        return not self.has_any_finished()

    def get_all_finished_req_ids(self) -> set[str]:
        """Return every request ID that completed sending or receiving."""
        return self.finished_sending | self.finished_recving

    def merge(self,
              other: 'KVTransferFinishedResult') -> 'KVTransferFinishedResult':
        """Return a new result holding the union of both results' sets."""
        return KVTransferFinishedResult(
            finished_sending=self.finished_sending | other.finished_sending,
            finished_recving=self.finished_recving | other.finished_recving,
            pending_handshake=self.pending_handshake
            | other.pending_handshake)

    @classmethod
    def empty(cls) -> 'KVTransferFinishedResult':
        """Build a result with all three sets empty."""
        return cls(finished_sending=set(),
                   finished_recving=set(),
                   pending_handshake=set())

    @classmethod
    def from_tuple(
        cls, result_tuple: tuple[set[str], set[str], set[str]]
    ) -> 'KVTransferFinishedResult':
        """Build from the legacy (sending, recving, pending) tuple format."""
        sending, recving, pending = result_tuple
        return cls(finished_sending=sending,
                   finished_recving=recving,
                   pending_handshake=pending)

    def to_tuple(self) -> tuple[set[str], set[str], set[str]]:
        """Convert to the legacy (sending, recving, pending) tuple format."""
        return (self.finished_sending, self.finished_recving,
                self.pending_handshake)


class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
Expand All @@ -65,6 +116,39 @@ class KVConnectorMetadata:
pass


class KVConnectorHandshakeMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # dict=True gives instances a mutable __dict__, which is required
        # if @cached_property is ever used on a subclass.
        dict=True):
    """
    Metadata optionally used for an out-of-band connector handshake between
    P/D workers.

    Subclasses add connector-specific fields; `connector_type` identifies
    which connector implementation produced the metadata so the receiving
    side can interpret it.
    """
    # Identifier of the concrete connector implementation ("base" default).
    connector_type: str = "base"

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: Callable[[Any],
                                                   core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Bridge msgspec.Struct with pydantic for schema generation.

        Pydantic validates incoming data as a plain dict and then calls
        ``cls`` on it to construct the struct.
        """
        return core_schema.no_info_after_validator_function(
            cls, core_schema.dict_schema())


class KVConnectorTransferMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        dict=True):
    """
    Wrapper for transfer handshake metadata sent between engine and utils.

    Tags a handshake payload with the worker's parallel ranks so the
    receiver can index it per (TP rank, DP rank).
    """
    # Tensor-parallel rank of the worker this metadata belongs to.
    tensor_parallel_rank: int
    # Data-parallel rank of the worker this metadata belongs to.
    data_parallel_rank: int
    # Handshake payload as a plain dict; presumably None means the worker
    # has nothing to report — TODO confirm against the sender.
    content: Optional[dict]


class KVConnectorBase_V1(ABC):

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
Expand All @@ -74,6 +158,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self._connector_metadata = KVConnectorMetadata()
self._vllm_config = vllm_config
self._role = role
self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None

@property
def role(self) -> KVConnectorRole:
Expand Down Expand Up @@ -104,7 +189,7 @@ def clear_connector_metadata(self) -> None:
"""
self._connector_metadata = KVConnectorMetadata()

def _get_connector_metadata(self) -> KVConnectorMetadata:
def get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.

This function should only be called inside the connector.
Expand Down Expand Up @@ -185,21 +270,37 @@ def wait_for_save(self):
"""
pass

def get_finished(self,
                 finished_req_ids: set[str]) -> KVTransferFinishedResult:
    """
    Notifies worker-side connector ids of requests that have
    finished generating tokens.

    Args:
        finished_req_ids: ids of requests the scheduler has finished
            generating tokens for in this step.

    Returns:
        KVTransferFinishedResult containing sets of finished sending,
        finished receiving, and pending handshake request IDs.
        The finished saves/sends req ids must belong to a set provided in a
        call to this method (this call or a prior one).
    """
    # Base implementation performs no asynchronous transfers, so nothing
    # is ever in flight: report all three sets as empty.
    return KVTransferFinishedResult.empty()

def get_pending_handshake_req_ids(self) -> Optional[set[str]]:
    """
    Report request IDs still waiting on handshake completion.

    Returns:
        The set of request IDs blocked on a handshake, or None when the
        connector does not track handshakes.
    """
    # Base connectors perform no out-of-band handshake.
    return None

def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]:
    """
    Expose the handshake metadata recorded for this connector.

    Returns:
        The stored KVConnectorHandshakeMetadata, or None when none has
        been set.
    """
    return self._handshake_metadata

# ==============================
# Scheduler-side methods
Expand All @@ -225,8 +326,7 @@ def get_num_new_matched_tokens(
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
asynchronously (between scheduler steps).
"""
pass

Expand All @@ -236,18 +336,6 @@ def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.

If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.

Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass

Expand Down
36 changes: 27 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole,
KVTransferFinishedResult)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -102,21 +103,27 @@ def wait_for_save(self):
for c in self._connectors:
c.wait_for_save()

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self,
finished_req_ids: set[str]) -> KVTransferFinishedResult:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
pending_handshake: set[str] = set()

for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
if not recving and not sending:
result = c.get_finished(finished_req_ids)
if result.is_empty():
continue

# Aggregate finished recving request ids.
finished_recving.update(recving or ())
finished_recving.update(result.finished_recving)

# Aggregate pending handshake request ids.
pending_handshake.update(result.pending_handshake)

# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
for req_id in result.finished_sending:
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
Expand All @@ -127,7 +134,18 @@ def get_finished(
else:
self._extra_async_saves[req_id] = extra_pending - 1

return finished_sending or None, finished_recving or None
return KVTransferFinishedResult(finished_sending=finished_sending,
finished_recving=finished_recving,
pending_handshake=pending_handshake)

def get_pending_handshake_req_ids(self) -> Optional[set[str]]:
    """Union of pending-handshake request IDs across all child connectors."""
    waiting: set[str] = set()
    for connector in self._connectors:
        connector_ids = connector.get_pending_handshake_req_ids()
        if connector_ids:
            waiting |= connector_ids
    # Preserve the "None when nothing is pending" contract of the base API.
    return waiting or None

# ==============================
# Scheduler-side methods
Expand Down
Loading
Loading