
Commit f38a044

[Misc] Add BlockHash and BlockHashWithGroupId types
Introduces BlockHash and BlockHashWithGroupId NamedTuple classes for KV cache prefix caching, including support for token IDs, extra keys, and group IDs, to facilitate multi-group cache management and reduce hash collisions.

Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
1 parent dd572c0 commit f38a044

3 files changed: +95 −57 lines changed

vllm/v1/core/block_hash.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+"""Block hash types for KV cache prefix caching."""
+from typing import Any, NamedTuple, Optional
+
+
+class BlockHash(NamedTuple):
+    """Hash value of a block (int), the token IDs in the block, and extra keys.
+
+    We keep a tuple of token IDs and extra keys to reduce the likelihood of
+    hash collisions when the hash value is the same. By using SHA256 however,
+    hash collisions are practically impossible.
+    """
+    # Hash value of the block in an integer.
+    hash_value: int
+    # Token IDs in the block.
+    token_ids: tuple[int, ...]
+    # Extra keys for the block.
+    extra_keys: Optional[Any] = None
+
+
+class BlockHashWithGroupId(NamedTuple):
+    """Block hash with KV cache group ID for multi-group cache management."""
+    # The hash value for the contents (e.g., token_ids) of a block without
+    # group ID. The value is the same for blocks representing the same tokens
+    # but for different groups.
+    block_hash: BlockHash
+    # The KV cache group ID.
+    group_id: int
+
+    def get_hash_value(self) -> int:
+        """Get the hash value of the block."""
+        return self.block_hash.hash_value
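For reference, a minimal usage sketch of the two types above (the values are illustrative; in vLLM the hash comes from helpers such as hash_block_tokens, not hand-written integers):

from vllm.v1.core.block_hash import BlockHash, BlockHashWithGroupId

# Hypothetical values, for illustration only.
block_hash = BlockHash(
    hash_value=0x9F3A52C1,             # illustrative integer hash
    token_ids=(101, 2023, 2003, 102),  # token IDs stored in the block
    extra_keys=None,                   # optional extra keys (None here)
)

# The same content hash can be keyed per KV cache group.
keyed = BlockHashWithGroupId(block_hash=block_hash, group_id=0)
assert keyed.get_hash_value() == block_hash.hash_value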

vllm/v1/core/block_pool.py

Lines changed: 4 additions & 5 deletions
@@ -7,12 +7,11 @@
 from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
                                         BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
-from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
-                                         FreeKVCacheBlockQueue, KVCacheBlock,
+from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens)
 from vllm.v1.request import Request
-
+from vllm.v1.core.block_hash import BlockHash, BlockHashWithGroupId
 logger = init_logger(__name__)
 
 
@@ -115,8 +114,8 @@ def cache_full_blocks(
         request: The request to cache the blocks.
         blocks: All blocks in the request.
         block_hashes: Block hashes of the blocks in the request. Note that
-            this list may be shorter than the blocks list. In this case the
-            missed block hash will be computed in this function.
+            this list may be shorter than the blocks list. In this case the
+            missed block hash will be computed in this function.
         num_cached_blocks: The number of blocks that are already cached.
         num_full_blocks: The number of blocks that are full and should
             be cached after this function.

vllm/v1/core/kv_cache_utils.py

Lines changed: 60 additions & 52 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """KV-Cache Utilities."""
 
+import copy
 import os
 from collections import defaultdict, deque
 from collections.abc import Iterable, Sequence
@@ -11,42 +12,16 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
-from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
-                                        FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
-
+from vllm.v1.core.kv_cache_coordinator import HybridKVCacheCoordinator
+from vllm.v1.core.block_hash import BlockHash, BlockHashWithGroupId
 logger = init_logger(__name__)
 
 
-class BlockHash(NamedTuple):
-    """Hash value of a block (int), the token IDs in the block, and extra keys.
-    We keep a tuple of token IDs and extra keys to reduce the likelihood of
-    hash collisions when the hash value is the same. By using SHA256 however,
-    hash collisions are practically impossible.
-    """
-    # Hash value of the block in an integer.
-    hash_value: int
-    # Token IDs in the block.
-    token_ids: tuple[int, ...]
-    # Extra keys for the block.
-    extra_keys: Optional[Any] = None
-
-
-class BlockHashWithGroupId(NamedTuple):
-    # The hash value for the contents (e.g., token_ids) of a block without group
-    # ID. The value is the same for blocks representing the same tokens but for
-    # different groups.
-    block_hash: BlockHash
-    # The KV cache group ID.
-    group_id: int
-
-    def get_hash_value(self) -> int:
-        return self.block_hash.hash_value
-
-
 # The hash seed for the first block of any prefix block sequence.
 #
 # We use a random value to avoid hash collisions or PYTHONHASHSEED environment
@@ -55,7 +30,7 @@ def get_hash_value(self) -> int:
 # a random seed if PYTHONHASHSEED is not set.
 #
 # The function `init_none_hash` initializes this variable globally.
-NONE_HASH: int
+NONE_HASH: int = 0  # Default value, will be overridden by init_none_hash
 
 
 def init_none_hash(hash_fn: Callable):
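Only the signature of init_none_hash appears in this diff. The body below is a sketch of the initialization pattern the surrounding comments describe (respect PYTHONHASHSEED when set, otherwise draw a random seed); it is an assumption for illustration, not the actual vLLM code:

import os
from typing import Callable

NONE_HASH: int = 0  # default; overridden by init_none_hash


def init_none_hash(hash_fn: Callable):
    # Sketch: use PYTHONHASHSEED for reproducibility if set, otherwise a
    # random value so unrelated deployments don't share the same seed hash.
    global NONE_HASH
    seed = os.getenv("PYTHONHASHSEED")
    if seed is None:
        NONE_HASH = int.from_bytes(os.urandom(8), "big")
    else:
        NONE_HASH = hash_fn(seed)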
@@ -77,8 +52,8 @@ class PrefixCachingMetrics:
     """Metrics for prefix caching with a hit rate of the max recent N requests.
 
     Args:
-        max_recent_requests: The number of the max recent requests to aggregate.
-            Defaults to 1000.
+        max_recent_requests: The number of the max recent requests to
+            aggregate. Defaults to 1000.
     """
 
     def __init__(self, max_recent_requests: int = 1000):
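The reflowed docstring describes a hit rate over the most recent N requests. A self-contained sketch of that kind of sliding-window aggregation (not the actual PrefixCachingMetrics implementation, which tracks more state):

from collections import deque


class RecentHitRateSketch:
    """Aggregate prefix-cache hit rate over the most recent N requests."""

    def __init__(self, max_recent_requests: int = 1000):
        # Each entry is (queries, hits) for one finished request.
        self._window = deque(maxlen=max_recent_requests)

    def observe(self, queries: int, hits: int) -> None:
        self._window.append((queries, hits))  # old requests roll off

    @property
    def hit_rate(self) -> float:
        total_queries = sum(q for q, _ in self._window)
        total_hits = sum(h for _, h in self._window)
        return total_hits / total_queries if total_queries else 0.0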
@@ -197,8 +172,8 @@ class FreeKVCacheBlockQueue:
     manipulating the linked list. Instead, this class manipulates the
     prev_free_block and next_free_block attributes of the given blocks.
 
-    The queue is ordered by block ID in the beginning. When a block is allocated
-    and then freed, it will be appended back with the eviction order:
+    The queue is ordered by block ID in the beginning. When a block is
+    allocated and then freed, it will be appended back with the eviction order:
     1. The least recent used block is at the front (LRU).
     2. If two blocks have the same last accessed time (allocated by the
        same sequence), the one with more hash tokens (the tail of a block
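The eviction order in the reflowed docstring (LRU first, then more hashed tokens first on ties) can be read as a sort key. The block attributes below are hypothetical names for illustration; the class itself uses a doubly linked list rather than sorting:

def eviction_key_sketch(block):
    # Smaller key = evicted earlier: older last-access time first; on a
    # tie, the block with more hashed tokens (tail of the chain) first.
    return (block.last_accessed, -block.num_hashed_tokens)

# blocks_in_eviction_order = sorted(free_blocks, key=eviction_key_sketch)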
@@ -747,7 +722,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     Returns:
         The generated KVCacheConfig
     """
-
+
     page_size = get_uniform_page_size(kv_cache_spec)
     num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
                                 available_memory, page_size)
@@ -762,7 +737,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
         for layer_name in kv_cache_spec
     ]
-
+
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,
         kv_cache_tensors=kv_cache_tensors,
@@ -949,6 +924,49 @@ def _get_kv_cache_config_attention_free() -> KVCacheConfig:
     return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])
 
 
+def _get_kv_cache_config_optimal_block_size(
+        vllm_config: VllmConfig,
+        kv_cache_spec: dict[str, KVCacheSpec],
+        available_memory: int) -> KVCacheConfig:
+    """Use optimal block size for hybrid models.
+
+    Args:
+        vllm_config: The vLLM configuration.
+        kv_cache_spec: KV cache specifications for each cache type.
+        available_memory: Available memory in bytes.
+
+    Returns:
+        KV cache configuration with optimal block size.
+    """
+    try:
+        # Import here to avoid circular dependency
+        from vllm.v1.core.kv_cache_coordinator import (
+            HybridKVCacheCoordinator)
+
+        optimal_block_size = HybridKVCacheCoordinator.calculate_optimal_block_size(
+            kv_cache_spec)
+
+        # Update specs with optimal size.
+        updated_specs = {}
+        for name, spec in kv_cache_spec.items():
+            # The optimal block size is applied to all specs to ensure uniformity.
+            new_spec = copy.deepcopy(spec)
+            new_spec.block_size = optimal_block_size
+            updated_specs[name] = new_spec
+
+        # Use existing logic.
+        return _get_kv_cache_config_uniform_page_size(vllm_config, updated_specs,
+                                                      available_memory)
+    except Exception as e:
+        logger.warning(
+            "Failed to calculate optimal block size: %s. "
+            "Falling back to uniform page size logic.",
+            e,
+            exc_info=True)
+        return _get_kv_cache_config_uniform_page_size(vllm_config, kv_cache_spec,
+                                                      available_memory)
+
+
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
     This function tries to convert the KV cache specs to one type if the model
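HybridKVCacheCoordinator.calculate_optimal_block_size is not shown in this commit. Since the result is applied uniformly to every spec before reusing the uniform-page-size path, one plausible reading is a least common multiple of the per-spec block sizes; the helper below is purely an assumption for illustration:

import math
from functools import reduce


def calculate_optimal_block_size_sketch(block_sizes: list[int]) -> int:
    # Assumption: a size that every spec's native block size divides
    # evenly keeps pages uniform across hybrid KV cache groups.
    return reduce(math.lcm, block_sizes, 1)


# e.g. full attention at block size 16 plus sliding window at 48 -> 48
assert calculate_optimal_block_size_sketch([16, 48]) == 48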
@@ -977,11 +995,7 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
         isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
     has_sliding_window = any(
         isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
-    has_chunked_local_attention = any(
-        isinstance(spec, ChunkedLocalAttentionSpec)
-        for spec in kv_cache_spec.values())
-    if has_full_attention and (has_sliding_window
-                               or has_chunked_local_attention):
+    if has_full_attention and has_sliding_window:
         for layer_name, spec in kv_cache_spec.items():
             if isinstance(spec, SlidingWindowSpec):
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -992,15 +1006,6 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
                     use_mla=spec.use_mla,
                     sliding_window=spec.sliding_window,
                 )
-            elif isinstance(spec, ChunkedLocalAttentionSpec):
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=spec.block_size,
-                    num_kv_heads=spec.num_kv_heads,
-                    head_size=spec.head_size,
-                    dtype=spec.dtype,
-                    use_mla=spec.use_mla,
-                    attention_chunk_size=spec.attention_chunk_size,
-                )
 
     if is_hybrid(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
@@ -1024,6 +1029,7 @@ def get_kv_cache_config(
         The generated KVCacheConfigs
     """
     check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)
 
@@ -1046,8 +1052,10 @@ def get_kv_cache_config(
                                                        kv_cache_spec,
                                                        available_memory)
 
-    raise NotImplementedError
-
+    else:
+        return _get_kv_cache_config_optimal_block_size(vllm_config,
+                                                       kv_cache_spec,
+                                                       available_memory)
 
 def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
     """
