Skip to content

Commit 1121cf6

Browse files
Fixed overly large block_size in hybrid attention models
Added calculate_optimal_block_size() to compute the minimal block size using an aggregate constraint across all attention layers, and updated get_kv_cache_config() to invoke the new optimization for heterogeneous (SSM + attention) models while leaving uniform models unchanged. Signed-off-by: WorldExplored srreyansh.sethi@gmail.com Signed-off-by: nadathurv work.vnadathur@gmail.com Co-Authored-By: Srreyansh Sethi <107075589+WorldExplored@users.noreply.github.com>
1 parent f148c44 commit 1121cf6

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,20 @@ def find_longest_cache_hit(
383383
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
384384
return hit_blocks, hit_length
385385

386+
def calculate_optimal_block_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
    """Compute the smallest block_size satisfying the aggregate constraint
    across all attention layers of a hybrid (SSM + attention) model.

    For a non-hybrid model this is a no-op: the existing attention
    block_size is returned (or 16 when there are no attention layers).
    The result is always a multiple of 16 and at least 16.
    """
    attn_layers: list[AttentionSpec] = []
    mamba_layers: list[MambaSpec] = []
    for spec in kv_cache_spec.values():
        if isinstance(spec, AttentionSpec):
            attn_layers.append(spec)
        if isinstance(spec, MambaSpec):
            mamba_layers.append(spec)

    # Uniform (non-hybrid) model: nothing to optimize.
    if not attn_layers or not mamba_layers:
        if attn_layers:
            return attn_layers[0].block_size
        return 16

    largest_mamba_state = max(spec.state_size_bytes for spec in mamba_layers)
    # Per-token KV footprint of the cheapest attention layer.
    cheapest_per_token = min(
        spec.page_size_bytes / spec.block_size for spec in attn_layers)

    # Block size needed so that the attention layers' pages for one block
    # jointly cover the largest Mamba state.
    needed = largest_mamba_state / (cheapest_per_token * len(attn_layers))
    # Round up to a multiple of 16, with a floor of 16.
    aligned = int(math.ceil(needed / 16) * 16)
    return max(16, aligned)
386400

387401
def get_kv_cache_coordinator(
388402
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,

vllm/v1/core/kv_cache_utils.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
"""KV-Cache Utilities."""
44

5+
import copy
56
import os
67
from collections import defaultdict, deque
78
from collections.abc import Iterable, Sequence
@@ -16,6 +17,7 @@
1617
KVCacheTensor, SlidingWindowSpec)
1718
from vllm.v1.metrics.stats import PrefixCacheStats
1819
from vllm.v1.request import Request
20+
from vllm.v1.core.kv_cache_coordinator import calculate_optimal_block_size
1921

2022
logger = init_logger(__name__)
2123

@@ -35,9 +37,9 @@ class BlockHash(NamedTuple):
3537

3638

3739
class BlockHashWithGroupId(NamedTuple):
38-
# The hash value for the contents (e.g., token_ids) of a block without group
39-
# ID. The value is the same for blocks representing the same tokens but for
40-
# different groups.
40+
# The hash value for the contents (e.g., token_ids) of a block without
41+
# group ID. The value is the same for blocks representing the same tokens
42+
# but for different groups.
4143
block_hash: BlockHash
4244
# The KV cache group ID.
4345
group_id: int
@@ -54,7 +56,7 @@ def get_hash_value(self) -> int:
5456
# a random seed if PYTHONHASHSEED is not set.
5557
#
5658
# The function `init_none_hash` initializes this variable globally.
57-
NONE_HASH: int
59+
NONE_HASH: int = 0 # Default value, will be overridden by init_none_hash
5860

5961

6062
def init_none_hash(hash_fn: Callable):
@@ -76,8 +78,8 @@ class PrefixCachingMetrics:
7678
"""Metrics for prefix caching with a hit rate of the max recent N requests.
7779
7880
Args:
79-
max_recent_requests: The number of the max recent requests to aggregate.
80-
Defaults to 1000.
81+
max_recent_requests: The number of the max recent requests to
82+
aggregate. Defaults to 1000.
8183
"""
8284

8385
def __init__(self, max_recent_requests: int = 1000):
@@ -196,8 +198,8 @@ class FreeKVCacheBlockQueue:
196198
manipulating the linked list. Instead, this class manipulates the
197199
prev_free_block and next_free_block attributes of the given blocks.
198200
199-
The queue is ordered by block ID in the beginning. When a block is allocated
200-
and then freed, it will be appended back with the eviction order:
201+
The queue is ordered by block ID in the beginning. When a block is
202+
allocated and then freed, it will be appended back with the eviction order:
201203
1. The least recent used block is at the front (LRU).
202204
2. If two blocks have the same last accessed time (allocated by the
203205
same sequence), the one with more hash tokens (the tail of a block
@@ -891,6 +893,24 @@ def _get_kv_cache_config_uniform_page_size(
891893
return kv_cache_config
892894

893895

896+
def _get_kv_cache_config_optimal_block_size(vllm_config, kv_cache_spec, available_memory):
    """Build a KV cache config for hybrid models using the optimal block size.

    Recomputes the block size via calculate_optimal_block_size(), applies it
    to every spec that carries a ``block_size`` attribute (attention specs),
    and delegates the rest of the work to the existing uniform-page-size
    path. Specs without a ``block_size`` (e.g. Mamba) are passed through
    untouched; the input mapping is never mutated.
    """
    block_size = calculate_optimal_block_size(kv_cache_spec)

    def _with_block_size(spec):
        # Only attention-style specs expose block_size; leave others as-is.
        if not hasattr(spec, 'block_size'):
            return spec
        resized = copy.deepcopy(spec)
        resized.block_size = block_size
        return resized

    resized_specs = {
        name: _with_block_size(spec)
        for name, spec in kv_cache_spec.items()
    }

    # Reuse the existing uniform-page-size allocation logic.
    return _get_kv_cache_config_uniform_page_size(vllm_config, resized_specs,
                                                  available_memory)
912+
913+
894914
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
895915
"""
896916
This function tries to convert the KV cache specs to one type if the model
@@ -973,7 +993,10 @@ def get_kv_cache_config(
973993
available_memory)
974994

975995
raise NotImplementedError
976-
996+
else:
997+
return _get_kv_cache_config_optimal_block_size(vllm_config,
998+
kv_cache_spec,
999+
available_memory)
9771000

9781001
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
9791002
"""

0 commit comments

Comments
 (0)