Commit 4b0a9f9

Support for attention free models
Forces 0 KV cache groups to disable the KV cache in attention-free models.

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent 8020e98 · commit 4b0a9f9

4 files changed: 30 additions, 3 deletions

vllm/v1/core/kv_cache_coordinator.py

Lines changed: 3 additions & 1 deletion
@@ -221,7 +221,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        # attention free models are initialized with 0 kv_cache_groups
+        if len(self.kv_cache_config.kv_cache_groups) > 0:
+            self.verify_and_split_kv_cache_groups()
 
     def verify_and_split_kv_cache_groups(self) -> None:
         """

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ def __init__(
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
 
         self.block_size: Optional[int] = None
-        if self.enable_caching:
+        if self.enable_caching and len(self.kv_cache_config.kv_cache_groups) > 0:
             assert len(
                 set(g.kv_cache_spec.block_size
                     for g in kv_cache_config.kv_cache_groups)
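
With no KV cache groups there is no block size to verify, so self.block_size simply stays None. A small standalone illustration of the changed condition (the helper name and the dict shape of a group are invented for the example):

from typing import Optional

def resolve_block_size(enable_caching: bool, kv_cache_groups: list) -> Optional[int]:
    block_size: Optional[int] = None
    # Mirrors the patched condition: skip entirely when there are no groups.
    if enable_caching and len(kv_cache_groups) > 0:
        sizes = {g["block_size"] for g in kv_cache_groups}
        assert len(sizes) == 1, "prefix caching expects a single block size"
        block_size = sizes.pop()
    return block_size

print(resolve_block_size(True, []))                    # None (attention-free model)
print(resolve_block_size(True, [{"block_size": 16}]))  # 16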

vllm/v1/core/kv_cache_utils.py

Lines changed: 20 additions & 1 deletion
@@ -551,6 +551,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         ValueError: If there is not enough memory available for the KV cache.
     """
 
+    # No need to check for available memory if the model is attention-free
+    if vllm_config.model_config.is_attention_free:
+        return
+
     if available_memory <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
@@ -736,6 +740,12 @@ def is_kv_cache_page_size_uniform(
     page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
     return len(page_sizes) == 1
 
+def is_kv_cache_type_attention_free(
+        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
+
+    # kv_cache_spec is an empty dict for attention-free models
+    if not kv_cache_spec:
+        return True
 
 def _get_kv_cache_config_uniform_page_size(
     vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
@@ -879,6 +889,11 @@ def _get_kv_cache_config_uniform_page_size(
     return kv_cache_config
 
 
+def _get_kv_cache_config_attention_free() -> KVCacheConfig:
+    return KVCacheConfig(num_blocks=1,
+                         kv_cache_tensors=[],
+                         kv_cache_groups=[])
+
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
     This function tries to convert the KV cache specs to one type if the model
@@ -945,7 +960,11 @@ def get_kv_cache_config(
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)
 
-    if is_kv_cache_type_uniform(kv_cache_spec):
+    if is_kv_cache_type_attention_free(kv_cache_spec):
+        # This returns a kv_cache config with 0 kv cache groups and 1 block
+        # to allow for the KVCache manager to handle attention-free models.
+        return _get_kv_cache_config_attention_free()
+    elif is_kv_cache_type_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
         # each layer.
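
Putting the three kv_cache_utils changes together: an empty spec dict marks the model as attention-free, and the config builder then returns 0 groups with 1 placeholder block. The sketch below exercises that dispatch path with plain dicts; FakeKVCacheConfig approximates the real type, and the detection helper is condensed to a single boolean expression.

from dataclasses import dataclass, field

@dataclass
class FakeKVCacheConfig:
    num_blocks: int
    kv_cache_tensors: list = field(default_factory=list)
    kv_cache_groups: list = field(default_factory=list)

def is_attention_free(kv_cache_spec: dict) -> bool:
    # An attention-free model registers no KV-cache-bearing layers,
    # so its spec dict is empty.
    return not kv_cache_spec

def get_config(kv_cache_spec: dict) -> FakeKVCacheConfig:
    if is_attention_free(kv_cache_spec):
        # 0 groups and 1 placeholder block: enough for the KV cache
        # manager to run without allocating real cache memory.
        return FakeKVCacheConfig(num_blocks=1)
    raise NotImplementedError("uniform / hybrid paths elided")

print(get_config({}))  # FakeKVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])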

vllm/v1/executor/abstract.py

Lines changed: 6 additions & 0 deletions
@@ -73,10 +73,16 @@ def register_failure_callback(self, callback: FailureCallback):
         pass
 
     def determine_available_memory(self) -> list[int]:  # in bytes
+        if self.vllm_config.model_config.is_attention_free:
+            return [0]
+
         output = self.collective_rpc("determine_available_memory")
         return output
 
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
+        if self.vllm_config.model_config.is_attention_free:
+            return [{}]
+
         output = self.collective_rpc("get_kv_cache_spec")
         return output
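
A stub executor shows the short-circuit: for an attention-free model both queries are answered locally and never fan out to the workers. The classes below are stand-ins for illustration, not the real Executor.

class FakeModelConfig:
    def __init__(self, is_attention_free: bool):
        self.is_attention_free = is_attention_free

class FakeExecutor:
    def __init__(self, is_attention_free: bool):
        self.model_config = FakeModelConfig(is_attention_free)

    def collective_rpc(self, method: str) -> list:
        raise RuntimeError("would fan out to the workers")

    def determine_available_memory(self) -> list[int]:  # in bytes
        if self.model_config.is_attention_free:
            return [0]   # no KV cache memory is needed at all
        return self.collective_rpc("determine_available_memory")

    def get_kv_cache_specs(self) -> list[dict]:
        if self.model_config.is_attention_free:
            return [{}]  # one empty spec per worker
        return self.collective_rpc("get_kv_cache_spec")

ex = FakeExecutor(is_attention_free=True)
print(ex.determine_available_memory(), ex.get_kv_cache_specs())  # [0] [{}]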
