[NIXL] vllm v0 nixl integration #16677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: wants to merge 7 commits into base: main. Changes from all commits are shown below.
39 changes: 37 additions & 2 deletions vllm/config.py
@@ -3091,6 +3091,9 @@ class KVTransferConfig(BaseModel):
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None

# Whether to use NIXL prepped xfer for KV cache transfer.
use_prepped_xfer: bool = False

# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
@@ -3100,7 +3103,7 @@ class KVTransferConfig(BaseModel):
kv_buffer_size: float = 1e9

# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'both'.
# are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_role: Optional[str] = None

# The rank of this vLLM instance in the KV cache transfer. Typical value:
@@ -3121,6 +3124,13 @@ class KVTransferConfig(BaseModel):
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}

# These fields do not need to be set by the user; they are set by the connector.
kv_producers_parallel_size: Optional[int] = None
kv_producers_tensor_parallel_size: Optional[int] = None
kv_producers_pipeline_parallel_size: Optional[int] = None
kv_consumers_tensor_parallel_size: Optional[int] = None
kv_consumers_pipeline_parallel_size: Optional[int] = None

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -3155,16 +3165,29 @@ def model_post_init(self, __context: Any) -> None:
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")

if self.kv_connector is not None and self.kv_role is None:
if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")

if self.use_prepped_xfer is False:
logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.")
self.use_prepped_xfer = True


@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]

@property
def need_kv_parallel_group(self) -> bool:
# For database-based connectors, vLLM does not need to create a
# parallel group; in that case the kv parallel size will be 1.
if self.kv_connector == "DynamoNixlConnector":
return False
return self.kv_connector is not None and self.kv_parallel_size > 1

@property
def is_kv_producer(self) -> bool:
return self.kv_connector is not None and \
@@ -3178,6 +3201,18 @@ def is_kv_consumer(self) -> bool:
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)

@property
def tensor_parallel_multiplier(self) -> int:
return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size

@property
def kv_consumers_parallel_size(self) -> int:
return self.kv_parallel_size - self.kv_producers_parallel_size

@property
def kv_world_size(self) -> int:
return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier


class CompilationLevel:
# constants for the levels of the compilation process
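The parallel-size bookkeeping added to KVTransferConfig above is easiest to read with concrete numbers. The following standalone sketch mirrors the arithmetic of the new tensor_parallel_multiplier, kv_consumers_parallel_size and kv_world_size properties; it does not import vLLM, and the class name and sizes are illustrative assumptions (a 2-instance producer pool at TP=2 feeding consumers at TP=4):

from dataclasses import dataclass

# Minimal stand-in mirroring the arithmetic of the new KVTransferConfig
# properties. Field names follow the diff; the values are illustrative.
@dataclass
class XferSizesSketch:
    kv_parallel_size: int = 6               # producer + consumer instances
    kv_producers_parallel_size: int = 2     # prefill (producer) instances
    kv_producers_tensor_parallel_size: int = 2
    kv_consumers_tensor_parallel_size: int = 4

    @property
    def tensor_parallel_multiplier(self) -> int:
        # Ratio of consumer to producer tensor parallelism.
        return (self.kv_consumers_tensor_parallel_size
                // self.kv_producers_tensor_parallel_size)

    @property
    def kv_consumers_parallel_size(self) -> int:
        return self.kv_parallel_size - self.kv_producers_parallel_size

    @property
    def kv_world_size(self) -> int:
        return (self.kv_producers_parallel_size
                + self.kv_consumers_parallel_size
                * self.tensor_parallel_multiplier)

sizes = XferSizesSketch()
assert sizes.tensor_parallel_multiplier == 2   # 4 // 2
assert sizes.kv_consumers_parallel_size == 4   # 6 - 2
assert sizes.kv_world_size == 10               # 2 + 4 * 2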
10 changes: 9 additions & 1 deletion vllm/core/block/cpu_gpu_block_allocator.py
@@ -6,6 +6,7 @@
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.core.event_manager import KVCacheEventManager
from vllm.platforms import current_platform
from vllm.utils import Device

@@ -28,6 +29,7 @@ def create(
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
event_manager: Optional[KVCacheEventManager] = None,
) -> DeviceAwareBlockAllocator:
"""Creates a CpuGpuBlockAllocator instance with the specified
configuration.
@@ -64,6 +66,7 @@ def create(
cpu_block_ids = block_ids[num_gpu_blocks:]

if allocator_type == "naive":
assert event_manager is None, "Event API not supported with naive allocator."
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
@@ -82,23 +85,27 @@ def create(
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
event_manager=event_manager,
)

cpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
event_manager=event_manager,
)
else:
raise ValueError(f"Unknown allocator type {allocator_type=}")

return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
event_manager=event_manager,
)

def __init__(self, cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator):
gpu_block_allocator: BlockAllocator,
event_manager: Optional[KVCacheEventManager] = None,):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
@@ -108,6 +115,7 @@ def __init__(self, cpu_block_allocator: BlockAllocator,
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
}
self.event_manager = event_manager

self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None
9 changes: 5 additions & 4 deletions vllm/core/block/naive_block.py
@@ -2,7 +2,7 @@

from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union

import heapq
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
@@ -38,7 +38,7 @@ def __init__(
if block_ids is None:
block_ids = range(num_blocks)

self._free_block_indices: Deque[BlockId] = deque(block_ids)
self._free_block_indices: List[BlockId] = list(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks

@@ -134,7 +134,8 @@ def _allocate_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()

block_id = self._free_block_indices.popleft()
block_id = heapq.heappop(self._free_block_indices)
# TODO: figure out why block_id is sometimes None
self._refcounter.incr(block_id)
return block_id

@@ -148,7 +149,7 @@ def _free_block_id(self, block: Union[Block, BlockId]) -> None:

refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
heapq.heappush(self._free_block_indices, block_id)

def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id
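The change in NaiveBlockAllocator from a deque to a heap-ordered free list alters which free block id is handed out next: a deque reuses ids in insertion order, while heappop always returns the smallest id currently free (the diff builds the list straight from block_ids, which is already sorted in the default range(num_blocks) case, so no explicit heapify is needed). A minimal sketch of the two behaviours, with illustrative ids:

import heapq
from collections import deque

free_ids = [7, 3, 5]

# Previous behaviour: FIFO reuse, whatever entered the deque first wins.
dq = deque(free_ids)
assert dq.popleft() == 7

# New behaviour: heap order, the smallest free id is always allocated next,
# and heappush keeps the invariant when a block is freed again.
heap = list(free_ids)
heapq.heapify(heap)
assert heapq.heappop(heap) == 3
heapq.heappush(heap, 1)
assert heapq.heappop(heap) == 1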
24 changes: 20 additions & 4 deletions vllm/core/block/prefix_caching_block.py
@@ -4,7 +4,7 @@
from bisect import bisect_left
from os.path import commonprefix
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
Tuple)
Tuple, TYPE_CHECKING)

from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
@@ -23,6 +23,9 @@
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1

if TYPE_CHECKING:
from vllm.core.event_manager import KVCacheEventManager

logger = init_logger(__name__)


@@ -80,6 +83,7 @@ def __init__(
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
event_manager: Optional["KVCacheEventManager"] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
@@ -131,6 +135,9 @@ def __init__(

self.metric_data = CacheMetricData()

self.event_manager = event_manager

# Implements Block.Factory.
def _create_block(
self,
prev_block: Optional[Block],
@@ -337,6 +344,9 @@ def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]:
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id

if self.event_manager:
self.event_manager.enqueue_removed_event(content_hash_to_evict)

self._cached_blocks.pop(content_hash_to_evict)

self._refcounter.incr(block_id)
@@ -513,6 +523,10 @@ def promote_to_immutable_block(self, block: Block) -> BlockId:
# Mark this block as touched so that it can be marked as
# computed after the entire batch of sequences are scheduled.
self._touched_blocks.add(block.block_id)

if self.event_manager:
self.event_manager.enqueue_stored_event(block.prev_block, block)

return block.block_id

# Reuse the cached content hash
@@ -579,9 +593,11 @@ def mark_blocks_as_accessed(self, block_ids: List[int],

def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
# Mark all touched blocks as computed.
for block_id in self._touched_blocks:
self._block_tracker[block_id].computed = True
self._touched_blocks.clear()
for block_id in block_ids:
if block_id in self._touched_blocks:
logger.debug("Mark block as computed: %s", block_id)
self._block_tracker[block_id].computed = True
self._touched_blocks.remove(block_id)

def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:
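PrefixCachingBlockAllocator only needs two hooks from the event manager: enqueue_stored_event when a block is promoted to an immutable, cached block, and enqueue_removed_event when a cached block is evicted to satisfy a new allocation. The real KVCacheEventManager in this PR appears to publish through a C API (it is constructed with a lib_path from VLLM_KV_CAPI_PATH in block_manager.py below); the stub here is a hypothetical in-memory stand-in, written only to illustrate the contract the call sites above rely on (the method signatures follow the diff, everything else is an assumption):

from typing import Any, List, Optional, Tuple

class RecordingKVCacheEventManager:
    """Hypothetical stand-in for KVCacheEventManager that just records events."""

    def __init__(self) -> None:
        self.events: List[Tuple[str, Any]] = []

    def enqueue_stored_event(self, prev_block: Optional[Any], block: Any) -> None:
        # Called from promote_to_immutable_block() once a block's content
        # hash is registered in the prefix cache; content_hash is assumed to
        # be the cache key, as in PrefixCachingBlock.
        self.events.append(("stored", block.content_hash))

    def enqueue_removed_event(self, content_hash: Any) -> None:
        # Called from _maybe_allocate_evicted_block_id() when a cached block
        # is evicted so its id can be reused.
        self.events.append(("removed", content_hash))

Passing such an object as event_manager through CpuGpuBlockAllocator.create() would surface cache store/evict activity without changing allocator behaviour.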
29 changes: 28 additions & 1 deletion vllm/core/block_manager.py
@@ -10,7 +10,10 @@
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.event_manager import KVCacheEventManager
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE,
VLLM_WORKER_ID)
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device

@@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):

def __init__(
self,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
@@ -91,11 +95,29 @@ def __init__(

self.watermark_blocks = int(watermark * num_gpu_blocks)

kv_event_manager_params = [
VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE,
VLLM_KV_COMPONENT
]
set_kv_event_manager_params = len(
[param for param in kv_event_manager_params if param is not None])

if set_kv_event_manager_params == len(kv_event_manager_params):
self.event_manager = KVCacheEventManager(
namespace=VLLM_KV_NAMESPACE,
component=VLLM_KV_COMPONENT,
worker_id=VLLM_WORKER_ID,
lib_path=VLLM_KV_CAPI_PATH,
kv_block_size=block_size)
else:
self.event_manager = None

self.block_allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching" if enable_caching else "naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
event_manager=self.event_manager,
)

self.block_tables: Dict[SeqId, BlockTable] = {}
@@ -108,7 +130,8 @@ def __init__(

def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
num_lookahead_slots: int = 0,
is_remote_decode: bool = False) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.

@@ -121,6 +144,10 @@ def can_allocate(self,
num_lookahead_slots=num_lookahead_slots,
)

# if remote decode, we need to allocate twice as many blocks for staging
if is_remote_decode:
num_required_blocks *= 2

if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
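SelfAttnBlockSpaceManager.__init__ above only builds a KVCacheEventManager when all four environment variables are set, so KV cache event publishing is opt-in purely through the environment. A hedged sketch of that all-or-nothing gate from the deployment side; the variable names come from the diff, while the values and the idea of setting them before vLLM is imported are assumptions for illustration:

import os

# Placeholder values: the real ones depend on the Dynamo deployment.
os.environ.setdefault("VLLM_WORKER_ID", "0")
os.environ.setdefault("VLLM_KV_CAPI_PATH", "/opt/dynamo/lib/libkv_capi.so")
os.environ.setdefault("VLLM_KV_NAMESPACE", "dynamo")
os.environ.setdefault("VLLM_KV_COMPONENT", "vllm-worker")

# Same all-or-nothing check as in SelfAttnBlockSpaceManager.__init__:
# the event manager is created only if every parameter is present.
required = ("VLLM_WORKER_ID", "VLLM_KV_CAPI_PATH",
            "VLLM_KV_NAMESPACE", "VLLM_KV_COMPONENT")
params = [os.environ.get(name) for name in required]
events_enabled = all(param is not None for param in params)
print("KV cache events enabled:", events_enabled)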