From e5d630946fb49895454870733862f0ce4afe6185 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Wed, 9 Jul 2025 02:15:09 +0000
Subject: [PATCH 1/4] skip strict HybridKVCacheCoordinator verification if
 prefix caching is disabled

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/v1/core/kv_cache_coordinator.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 38de00625e3..6a186dfd15e 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -221,7 +221,8 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        if enable_caching:
+            self.verify_and_split_kv_cache_groups()

     def verify_and_split_kv_cache_groups(self) -> None:
         """
@@ -307,6 +308,9 @@ def find_longest_cache_hit(
             - A list of the cache hit blocks for each single type manager.
             - The number of tokens of the longest cache hit.
         """
+        assert self.enable_caching, (
+            "find_longest_cache_hit can't be used if prefix caching is disabled"
+        )
         # First, find the longest cache hit for full attention.
         hit_blocks_full_attn = (
             self.full_attention_manager_cls.find_longest_cache_hit(

From 6323189883eb306093078ee55661ee83e32405e0 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Fri, 11 Jul 2025 16:14:10 +0900
Subject: [PATCH 2/4] Revert "skip strict HybridKVCacheCoordinator
 verification if prefix caching is disabled"

This reverts commit e5d630946fb49895454870733862f0ce4afe6185.

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/v1/core/kv_cache_coordinator.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 6a186dfd15e..38de00625e3 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -221,8 +221,7 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        if enable_caching:
-            self.verify_and_split_kv_cache_groups()
+        self.verify_and_split_kv_cache_groups()

     def verify_and_split_kv_cache_groups(self) -> None:
         """
@@ -308,9 +307,6 @@ def find_longest_cache_hit(
             - A list of the cache hit blocks for each single type manager.
             - The number of tokens of the longest cache hit.
         """
-        assert self.enable_caching, (
-            "find_longest_cache_hit can't be used if prefix caching is disabled"
-        )
         # First, find the longest cache hit for full attention.
         hit_blocks_full_attn = (
             self.full_attention_manager_cls.find_longest_cache_hit(

From 76e0c05443bd81543f13a17c964e0c3718652ea1 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Fri, 11 Jul 2025 16:27:48 +0900
Subject: [PATCH 3/4] introduce kv cache coordinator for settings where prefix
 caching is not used

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/v1/core/kv_cache_coordinator.py | 32 ++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 38de00625e3..c05690031ee 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -171,6 +171,34 @@ def find_longest_cache_hit(
         pass


+class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
+    """
+    KV cache coordinator to use if prefix caching is disabled or unsupported.
+    In contrast to UnitaryKVCacheCoordinator and HybridKVCacheCoordinator,
+    supports arbitrary numbers of KV cache groups (including 0 groups).
+    Does not implement any features related to prefix caching.
+    """
+
+    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
+                 use_eagle: bool, caching_hash_fn: Callable,
+                 enable_kv_cache_events: bool):
+        super().__init__(kv_cache_config, max_model_len, use_eagle, False,
+                         caching_hash_fn, enable_kv_cache_events)
+
+    def get_num_common_prefix_blocks(self, request_id: str,
+                                     num_running_requests: int) -> list[int]:
+        return [0] * len(self.single_type_managers)
+
+    def find_longest_cache_hit(
+        self,
+        block_hashes: list[BlockHash],
+        max_cache_hit_length: int,
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
+        blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [] for _ in range(len(self.single_type_managers)))
+        return blocks, 0
+
+
 class UnitaryKVCacheCoordinator(KVCacheCoordinator):
     """
     KV cache coordinator for models with only one KV cache group. This is the
@@ -359,6 +387,10 @@ def get_kv_cache_coordinator(
         kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
         enable_caching: bool, caching_hash_fn: Callable,
         enable_kv_cache_events: bool) -> KVCacheCoordinator:
+    if not enable_caching:
+        return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
+                                               use_eagle, caching_hash_fn,
+                                               enable_kv_cache_events)
     if len(kv_cache_config.kv_cache_groups) == 1:
         return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
                                          use_eagle, enable_caching,
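[Note] To make the behavior introduced in PATCH 3/4 concrete, here is a
minimal, self-contained Python sketch. It is not vLLM code: the stub
KVCacheConfig and the trimmed constructor/factory signatures are assumptions
made for illustration (the real classes also take max_model_len, use_eagle,
caching_hash_fn, and enable_kv_cache_events), while the class name, method
names, and return values mirror the diff above.

    from typing import Any


    class KVCacheConfig:
        """Stub for vLLM's KVCacheConfig; only the field the sketch needs."""

        def __init__(self, kv_cache_groups: list[Any]) -> None:
            self.kv_cache_groups = kv_cache_groups


    class KVCacheCoordinatorNoPrefixCache:
        """No-op coordinator: never reports a prefix-cache hit."""

        def __init__(self, kv_cache_config: KVCacheConfig) -> None:
            # One single-type manager per KV cache group; zero groups is valid.
            self.single_type_managers = list(kv_cache_config.kv_cache_groups)

        def get_num_common_prefix_blocks(self, request_id: str,
                                         num_running_requests: int) -> list[int]:
            # Without prefix caching there are no shared prefix blocks.
            return [0] * len(self.single_type_managers)

        def find_longest_cache_hit(
                self, block_hashes: list[Any],
                max_cache_hit_length: int) -> tuple[tuple[list[Any], ...], int]:
            # An empty hit-block list per manager and a hit length of 0 tokens.
            blocks: tuple[list[Any], ...] = tuple(
                [] for _ in range(len(self.single_type_managers)))
            return blocks, 0


    def get_kv_cache_coordinator(kv_cache_config: KVCacheConfig,
                                 enable_caching: bool) -> Any:
        # Dispatch mirrors the patch: the prefix-caching-free coordinator is
        # selected before the group-count check, so it also covers the 0-group
        # case that the Unitary/Hybrid coordinators do not support.
        if not enable_caching:
            return KVCacheCoordinatorNoPrefixCache(kv_cache_config)
        raise NotImplementedError("Unitary/Hybrid coordinators elided here")


    # With caching disabled, every lookup degenerates to "no hit".
    coordinator = get_kv_cache_coordinator(KVCacheConfig(kv_cache_groups=[]),
                                           enable_caching=False)
    assert coordinator.find_longest_cache_hit([], max_cache_hit_length=16) == ((), 0)
    assert coordinator.get_num_common_prefix_blocks("req-0", 1) == []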
From 8d1aa20074fbaacdd664f8251253407de2c4f866 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Fri, 11 Jul 2025 19:11:24 +0900
Subject: [PATCH 4/4] cache number of single type managers

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/v1/core/kv_cache_coordinator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index c05690031ee..de72e60434a 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -184,10 +184,11 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                  enable_kv_cache_events: bool):
         super().__init__(kv_cache_config, max_model_len, use_eagle, False,
                          caching_hash_fn, enable_kv_cache_events)
+        self.num_single_type_manager = len(self.single_type_managers)

     def get_num_common_prefix_blocks(self, request_id: str,
                                      num_running_requests: int) -> list[int]:
-        return [0] * len(self.single_type_managers)
+        return [0] * self.num_single_type_manager

     def find_longest_cache_hit(
         self,
@@ -195,7 +196,7 @@ def find_longest_cache_hit(
         max_cache_hit_length: int,
     ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         blocks: tuple[list[KVCacheBlock], ...] = tuple(
-            [] for _ in range(len(self.single_type_managers)))
+            [] for _ in range(self.num_single_type_manager))
         return blocks, 0

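[Note] PATCH 4/4 is a small follow-up: len(self.single_type_managers) is
invariant after construction, so it is computed once in __init__ and reused,
presumably because both methods sit on the scheduler's per-request path. In
terms of the sketch above (the same hedging applies; the stub config and
trimmed signature remain assumptions):

    class KVCacheCoordinatorNoPrefixCache:
        def __init__(self, kv_cache_config: "KVCacheConfig") -> None:
            self.single_type_managers = list(kv_cache_config.kv_cache_groups)
            # The manager count never changes after construction, so cache it
            # instead of recomputing len() on every call.
            self.num_single_type_manager = len(self.single_type_managers)

        def get_num_common_prefix_blocks(self, request_id: str,
                                         num_running_requests: int) -> list[int]:
            return [0] * self.num_single_type_manager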