diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index de72e60434a..c144141959f 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -383,6 +383,20 @@ def find_longest_cache_hit( hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn return hit_blocks, hit_length + def calculate_optimal_block_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: + """Calculate optimal block_size using aggregate constraint.""" + attention_specs = [s for s in kv_cache_spec.values() if isinstance(s, AttentionSpec)] + mamba_specs = [s for s in kv_cache_spec.values() if isinstance(s, MambaSpec)] + + if not (attention_specs and mamba_specs): + return attention_specs[0].block_size if attention_specs else 16 + + max_mamba_state = max(s.page_size_bytes for s in mamba_specs) + num_attention_layers = len(attention_specs) + min_per_token_bytes = min(s.page_size_bytes / s.block_size for s in attention_specs) + + required = max_mamba_state / (min_per_token_bytes * num_attention_layers) + return max(16, int(math.ceil(required / 16) * 16)) # Align to 16 def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 6067a127e97..d5af33f2ce8 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """KV-Cache Utilities.""" +import copy import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence @@ -16,6 +17,7 @@ KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request +from vllm.v1.core.kv_cache_coordinator import calculate_optimal_block_size logger = init_logger(__name__) @@ -35,9 +37,9 @@ class BlockHash(NamedTuple): class BlockHashWithGroupId(NamedTuple): - # The hash value for the contents (e.g., token_ids) of a block without group - # ID. The value is the same for blocks representing the same tokens but for - # different groups. + # The hash value for the contents (e.g., token_ids) of a block without + # group ID. The value is the same for blocks representing the same tokens + # but for different groups. block_hash: BlockHash # The KV cache group ID. group_id: int @@ -54,7 +56,7 @@ def get_hash_value(self) -> int: # a random seed if PYTHONHASHSEED is not set. # # The function `init_none_hash` initializes this variable globally. -NONE_HASH: int +NONE_HASH: int = 0 # Default value, will be overridden by init_none_hash def init_none_hash(hash_fn: Callable): @@ -76,8 +78,8 @@ class PrefixCachingMetrics: """Metrics for prefix caching with a hit rate of the max recent N requests. Args: - max_recent_requests: The number of the max recent requests to aggregate. - Defaults to 1000. + max_recent_requests: The number of the max recent requests to + aggregate. Defaults to 1000. """ def __init__(self, max_recent_requests: int = 1000): @@ -196,8 +198,8 @@ class FreeKVCacheBlockQueue: manipulating the linked list. Instead, this class manipulates the prev_free_block and next_free_block attributes of the given blocks. - The queue is ordered by block ID in the beginning. When a block is allocated - and then freed, it will be appended back with the eviction order: + The queue is ordered by block ID in the beginning. When a block is + allocated and then freed, it will be appended back with the eviction order: 1. The least recent used block is at the front (LRU). 2. If two blocks have the same last accessed time (allocated by the same sequence), the one with more hash tokens (the tail of a block @@ -902,10 +904,26 @@ def _get_kv_cache_config_uniform_page_size( return kv_cache_config +def _get_kv_cache_config_optimal_block_size(vllm_config, kv_cache_spec, available_memory): + """Use optimal block size for hybrid models.""" + optimal_block_size = calculate_optimal_block_size(kv_cache_spec) + + # Update specs with optimal size + updated_specs = {} + for name, spec in kv_cache_spec.items(): + # The optimal block size is applied to all specs to ensure uniformity. + new_spec = copy.deepcopy(spec) + new_spec.block_size = optimal_block_size + updated_specs[name] = new_spec + + # Use existing logic + return _get_kv_cache_config_uniform_page_size(vllm_config, updated_specs, available_memory) + def _get_kv_cache_config_attention_free() -> KVCacheConfig: return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[]) + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ This function tries to convert the KV cache specs to one type if the model @@ -991,8 +1009,10 @@ def get_kv_cache_config( kv_cache_spec, available_memory) - raise NotImplementedError - + else: + return _get_kv_cache_config_optimal_block_size(vllm_config, + kv_cache_spec, + available_memory) def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): """