class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
    """
    Coordinator used when prefix caching is turned off or not supported.

    Unlike UnitaryKVCacheCoordinator and HybridKVCacheCoordinator, this
    variant works with any number of KV cache groups (zero included) and
    provides none of the prefix-caching functionality: the prefix-related
    queries below always answer "nothing shared, nothing hit".
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, caching_hash_fn: Callable,
                 enable_kv_cache_events: bool):
        """
        Initialise the base coordinator with caching forced off.

        Args:
            kv_cache_config: Forwarded unchanged to the base class.
            max_model_len: Forwarded unchanged to the base class.
            use_eagle: Forwarded unchanged to the base class.
            caching_hash_fn: Forwarded unchanged to the base class.
            enable_kv_cache_events: Forwarded unchanged to the base class.
        """
        # The literal False fills the fourth positional slot of the base
        # constructor (presumably its enable_caching flag -- confirm
        # against KVCacheCoordinator.__init__).
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            False,
            caching_hash_fn,
            enable_kv_cache_events,
        )
        # Record the manager count once; both overrides below size their
        # results from this stored value.
        self.num_single_type_manager = len(self.single_type_managers)

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> list[int]:
        """
        Report zero common-prefix blocks for every single-type manager.

        Both arguments are accepted for interface compatibility and are
        not read.

        Returns:
            A list of zeros with one entry per single-type manager.
        """
        return [0 for _ in range(self.num_single_type_manager)]

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Report that nothing was found in the cache.

        Both arguments are accepted for interface compatibility and are
        not read.

        Returns:
            A pair ``(blocks, 0)`` where ``blocks`` holds one distinct
            empty list per single-type manager and ``0`` is the hit length.
        """
        # Build a separate list object for each manager so that callers
        # mutating one entry cannot affect another.
        no_hits: list[list[KVCacheBlock]] = [
            [] for _ in range(self.num_single_type_manager)
        ]
        empty_per_manager: tuple[list[KVCacheBlock], ...] = tuple(no_hits)
        return empty_per_manager, 0
This is the @@ -359,6 +388,10 @@ def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, caching_hash_fn: Callable, enable_kv_cache_events: bool) -> KVCacheCoordinator: + if not enable_caching: + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, + use_eagle, caching_hash_fn, + enable_kv_cache_events) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching,