Skip to content

Commit cdd10f1

Browse files
nopperl and b8zhong
authored and committed
[V1] Hybrid allocator without prefix caching (vllm-project#20661)
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com> Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
1 parent aecccab commit cdd10f1

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -171,6 +171,35 @@ def find_longest_cache_hit(
171171
pass
172172

173173

174+
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
    """
    Coordinator for the case where prefix caching is disabled or unsupported.

    Unlike UnitaryKVCacheCoordinator and HybridKVCacheCoordinator, this
    coordinator works with any number of KV cache groups (zero included).
    None of the prefix-caching features are implemented here: every
    prefix-related query reports "no hit".
    """

    def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                 use_eagle: bool, caching_hash_fn: Callable,
                 enable_kv_cache_events: bool):
        """
        Set up the base coordinator with its fourth positional argument
        fixed to False, then record how many single-type managers it built.
        """
        # NOTE(review): the literal False sits in the slot where the factory
        # passes `enable_caching` to the other coordinators — confirm against
        # the base-class signature.
        super().__init__(kv_cache_config, max_model_len, use_eagle, False,
                         caching_hash_fn, enable_kv_cache_events)
        # Cached once so the per-request methods below need not call len().
        self.num_single_type_manager = len(self.single_type_managers)

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> list[int]:
        """
        Return a zero for every single-type manager.

        Both arguments are accepted for interface compatibility and ignored.
        """
        return [0 for _ in range(self.num_single_type_manager)]

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Report an empty cache hit.

        Returns a tuple holding one fresh empty list per single-type manager,
        together with a hit length of 0. Both arguments are ignored.
        """
        manager_count = self.num_single_type_manager
        # A distinct list per manager: callers must not observe aliasing.
        no_hits: tuple[list[KVCacheBlock], ...] = tuple(
            list() for _ in range(manager_count))
        return no_hits, 0
201+
202+
174203
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
175204
"""
176205
KV cache coordinator for models with only one KV cache group. This is the
@@ -359,6 +388,10 @@ def get_kv_cache_coordinator(
359388
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
360389
enable_caching: bool, caching_hash_fn: Callable,
361390
enable_kv_cache_events: bool) -> KVCacheCoordinator:
391+
if not enable_caching:
392+
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
393+
use_eagle, caching_hash_fn,
394+
enable_kv_cache_events)
362395
if len(kv_cache_config.kv_cache_groups) == 1:
363396
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
364397
use_eagle, enable_caching,

0 commit comments

Comments
 (0)