
Commit b764c9d

Support for attention free models
Forces 0 KV cache groups to disable the KV cache in attention-free models.

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent a99b9f7 commit b764c9d

4 files changed: +30 −3 lines
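
Taken together, the four diffs below short-circuit KV cache sizing whenever model_config.is_attention_free is set: the executor reports zero available memory and an empty per-layer spec, the utils layer maps the empty spec to a config with 0 KV cache groups, and the coordinator and manager learn to tolerate that config. A rough end-to-end sketch with a hypothetical driver function; only the [0] / [{}] / 0-group shapes come from this commit, the rest is illustrative glue:

    # Hypothetical driver illustrating the attention-free init path.
    def init_kv_cache_sketch(is_attention_free: bool) -> dict:
        available_memory = [0] if is_attention_free else [8 << 30]   # executor
        kv_cache_specs = [{}] if is_attention_free else [{"layers.0": object()}]

        if not kv_cache_specs[0]:
            # Empty spec: skip the memory check, return a 0-group config.
            assert available_memory == [0]  # no cache memory was profiled
            return {"num_blocks": 1, "kv_cache_tensors": [], "kv_cache_groups": []}
        raise NotImplementedError("uniform / hybrid sizing paths elided")

    assert init_kv_cache_sketch(True)["kv_cache_groups"] == []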

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 3 additions & 1 deletion

@@ -250,7 +250,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        # Attention-free models are initialized with 0 kv_cache_groups.
+        if len(self.kv_cache_config.kv_cache_groups) > 0:
+            self.verify_and_split_kv_cache_groups()
 
     def verify_and_split_kv_cache_groups(self) -> None:
         """

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def __init__(
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
 
         self.block_size: Optional[int] = None
-        if self.enable_caching:
+        if self.enable_caching and len(self.kv_cache_config.kv_cache_groups) > 0:
             assert len(
                 set(g.kv_cache_spec.block_size
                     for g in kv_cache_config.kv_cache_groups)
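
The manager-side effect is simply that prefix-caching setup is bypassed, so block_size stays None for attention-free models. A sketch of the guarded branch, assuming each group exposes kv_cache_spec.block_size:

    from typing import Optional

    def resolve_block_size(enable_caching: bool,
                           kv_cache_groups: list) -> Optional[int]:
        block_size: Optional[int] = None
        # With 0 groups the set below would be empty and the assert would
        # fail, hence the extra len(...) > 0 condition in the diff above.
        if enable_caching and len(kv_cache_groups) > 0:
            sizes = {g.kv_cache_spec.block_size for g in kv_cache_groups}
            assert len(sizes) == 1, "prefix caching requires one block size"
            block_size = sizes.pop()
        return block_size

    assert resolve_block_size(enable_caching=True, kv_cache_groups=[]) is None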

vllm/v1/core/kv_cache_utils.py

Lines changed: 20 additions & 1 deletion

@@ -563,6 +563,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         ValueError: If there is not enough memory available for the KV cache.
     """
 
+    # No need to check for available memory if the model is attention-free.
+    if vllm_config.model_config.is_attention_free:
+        return
+
     if available_memory <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
@@ -748,6 +752,12 @@ def is_kv_cache_page_size_uniform(
     page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
     return len(page_sizes) == 1
 
+def is_kv_cache_type_attention_free(
+        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
+
+    # kv_cache_spec is an empty dict for attention-free models.
+    if not kv_cache_spec:
+        return True
 
 def _get_kv_cache_config_uniform_page_size(
         vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
@@ -891,6 +901,11 @@ def _get_kv_cache_config_uniform_page_size(
     return kv_cache_config
 
 
+def _get_kv_cache_config_attention_free() -> KVCacheConfig:
+    return KVCacheConfig(num_blocks=1,
+                         kv_cache_tensors=[],
+                         kv_cache_groups=[])
+
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
     This function tries to convert the KV cache specs to one type if the model
@@ -957,7 +972,11 @@ def get_kv_cache_config(
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)
 
-    if is_kv_cache_type_uniform(kv_cache_spec):
+    if is_kv_cache_type_attention_free(kv_cache_spec):
+        # This returns a kv_cache config with 0 kv_cache groups and 1 block
+        # to allow the KV cache manager to handle attention-free models.
+        return _get_kv_cache_config_attention_free()
+    elif is_kv_cache_type_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
         # each layer.
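
Putting the utils-side pieces together: a self-contained sketch of the new dispatch, with KVCacheConfig reduced to the three fields this diff constructs. The helper body is written as an explicit boolean expression, equivalent to the truthiness check in the diff above:

    from dataclasses import dataclass, field

    @dataclass
    class KVCacheConfig:  # reduced to the fields constructed in this diff
        num_blocks: int
        kv_cache_tensors: list = field(default_factory=list)
        kv_cache_groups: list = field(default_factory=list)

    def is_kv_cache_type_attention_free(kv_cache_spec: dict) -> bool:
        # An empty spec dict is the marker for attention-free models.
        return not kv_cache_spec

    def _get_kv_cache_config_attention_free() -> KVCacheConfig:
        # 1 placeholder block, 0 groups: lets the KV cache manager run
        # without ever allocating real cache memory.
        return KVCacheConfig(num_blocks=1)

    def get_kv_cache_config(kv_cache_spec: dict) -> KVCacheConfig:
        if is_kv_cache_type_attention_free(kv_cache_spec):
            return _get_kv_cache_config_attention_free()
        raise NotImplementedError("uniform / hybrid sizing paths elided")

    assert get_kv_cache_config({}).kv_cache_groups == []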

vllm/v1/executor/abstract.py

Lines changed: 6 additions & 0 deletions

@@ -73,10 +73,16 @@ def register_failure_callback(self, callback: FailureCallback):
         pass
 
     def determine_available_memory(self) -> list[int]:  # in bytes
+        if self.vllm_config.model_config.is_attention_free:
+            return [0]
+
         output = self.collective_rpc("determine_available_memory")
         return output
 
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
+        if self.vllm_config.model_config.is_attention_free:
+            return [{}]
+
         output = self.collective_rpc("get_kv_cache_spec")
         return output
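
As a usage-level check, the two executor methods can be exercised without any workers; a sketch assuming an executor object that carries the attention-free flag (the collective_rpc fan-out is elided):

    # Sketch of the executor short-circuit: attention-free models never
    # trigger the collective_rpc profiling pass on the workers.
    class ExecutorSketch:
        def __init__(self, is_attention_free: bool) -> None:
            self.is_attention_free = is_attention_free

        def determine_available_memory(self) -> list:  # in bytes
            if self.is_attention_free:
                return [0]  # no KV cache memory is needed or profiled
            raise NotImplementedError("collective_rpc path elided")

        def get_kv_cache_specs(self) -> list:
            if self.is_attention_free:
                return [{}]  # empty spec marks the model attention-free
            raise NotImplementedError("collective_rpc path elided")

    executor = ExecutorSketch(is_attention_free=True)
    assert executor.determine_available_memory() == [0]
    assert executor.get_kv_cache_specs() == [{}]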
