Skip to content

[V1] Large Block_size solution #21123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions vllm/v1/core/block_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Block hash types for KV cache prefix caching."""
from typing import Any, NamedTuple, Optional


class BlockHash(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None


class BlockHashWithGroupId(NamedTuple):
"""Block hash with KV cache group ID for multi-group cache management."""
# The hash value for the contents (e.g., token_ids) of a block without
# group ID. The value is the same for blocks representing the same tokens
# but for different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int

def get_hash_value(self) -> int:
"""Get the hash value of the block."""
return self.block_hash.hash_value
9 changes: 4 additions & 5 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request

from vllm.v1.core.block_hash import BlockHash, BlockHashWithGroupId
logger = init_logger(__name__)


Expand Down Expand Up @@ -115,8 +114,8 @@ def cache_full_blocks(
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
Expand Down
113 changes: 60 additions & 53 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""

import copy
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, Callable, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

from vllm.v1.core.block_hash import BlockHash, BlockHashWithGroupId
logger = init_logger(__name__)


class BlockHash(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None


class BlockHashWithGroupId(NamedTuple):
# The hash value for the contents (e.g., token_ids) of a block without group
# ID. The value is the same for blocks representing the same tokens but for
# different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int

def get_hash_value(self) -> int:
return self.block_hash.hash_value


# The hash seed for the first block of any prefix block sequence.
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
Expand All @@ -55,7 +29,7 @@ def get_hash_value(self) -> int:
# a random seed if PYTHONHASHSEED is not set.
#
# The function `init_none_hash` initializes this variable globally.
NONE_HASH: int
NONE_HASH: int = 0 # Default value, will be overridden by init_none_hash


def init_none_hash(hash_fn: Callable):
Expand All @@ -77,8 +51,8 @@ class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the max recent N requests.

Args:
max_recent_requests: The number of the max recent requests to aggregate.
Defaults to 1000.
max_recent_requests: The number of the max recent requests to
aggregate. Defaults to 1000.
"""

def __init__(self, max_recent_requests: int = 1000):
Expand Down Expand Up @@ -197,8 +171,8 @@ class FreeKVCacheBlockQueue:
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.

The queue is ordered by block ID in the beginning. When a block is allocated
and then freed, it will be appended back with the eviction order:
The queue is ordered by block ID in the beginning. When a block is
allocated and then freed, it will be appended back with the eviction order:
1. The least recent used block is at the front (LRU).
2. If two blocks have the same last accessed time (allocated by the
same sequence), the one with more hash tokens (the tail of a block
Expand Down Expand Up @@ -747,7 +721,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
Returns:
The generated KVCacheConfig
"""

page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
available_memory, page_size)
Expand All @@ -762,7 +736,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
for layer_name in kv_cache_spec
]

kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
Expand Down Expand Up @@ -949,6 +923,49 @@ def _get_kv_cache_config_attention_free() -> KVCacheConfig:
return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])


def _get_kv_cache_config_optimal_block_size(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""Use optimal block size for hybrid models.

Args:
vllm_config: The vLLM configuration.
kv_cache_spec: KV cache specifications for each cache type.
available_memory: Available memory in bytes.

Returns:
KV cache configuration with optimal block size.
"""
try:
# Import here to avoid circular dependency
from vllm.v1.core.kv_cache_coordinator import (
HybridKVCacheCoordinator)

optimal_block_size = HybridKVCacheCoordinator.calculate_optimal_block_size(
kv_cache_spec)

# Update specs with optimal size.
updated_specs = {}
for name, spec in kv_cache_spec.items():
# The optimal block size is applied to all specs to ensure uniformity.
new_spec = copy.deepcopy(spec)
new_spec.block_size = optimal_block_size
updated_specs[name] = new_spec

# Use existing logic.
return _get_kv_cache_config_uniform_page_size(vllm_config, updated_specs,
available_memory)
except Exception as e:
logger.warning(
"Failed to calculate optimal block size: %s. "
"Falling back to uniform page size logic.",
e,
exc_info=True)
return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec,
available_memory)


def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
This function tries to convert the KV cache specs to one type if the model
Expand Down Expand Up @@ -977,11 +994,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec)
for spec in kv_cache_spec.values())
if has_full_attention and (has_sliding_window
or has_chunked_local_attention):
if has_full_attention and has_sliding_window:
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
Expand All @@ -992,15 +1005,6 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size,
)

if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
Expand All @@ -1024,6 +1028,7 @@ def get_kv_cache_config(
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)

if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

Expand All @@ -1046,8 +1051,10 @@ def get_kv_cache_config(
kv_cache_spec,
available_memory)

raise NotImplementedError

else:
return _get_kv_cache_config_optimal_block_size(vllm_config,
kv_cache_spec,
available_memory)

def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
"""
Expand Down