[v1][core] Support for attention free models #20811


Merged
4 changes: 3 additions & 1 deletion vllm/v1/core/kv_cache_coordinator.py
@@ -388,7 +388,9 @@ def get_kv_cache_coordinator(
         kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
         enable_caching: bool, caching_hash_fn: Callable,
         enable_kv_cache_events: bool) -> KVCacheCoordinator:
-    if not enable_caching:
+    if not enable_caching or len(kv_cache_config.kv_cache_groups) == 0:
+        # We instantiate this coordinator also for attention free models that
Collaborator:
Do we need this, given that prefix caching is already disabled here for models that don't use the last pooling method?
https://github.com/maxdebayser/vllm/blob/221f013922c0c118b682d294755e69990b2c43ed/vllm/config.py#L4505

Contributor Author:

Without this check, though, you would not be able to disable attention for models that are not of the pooling type, because prefix caching is enabled by default for all models except pooling ones.

See below:

vllm/vllm/engine/arg_utils.py

Lines 1620 to 1630 in 38efa28

def _set_default_args_v1(self, usage_context: UsageContext,
                         model_config: ModelConfig) -> None:
    """Set Default Arguments for V1 Engine."""
    # V1 always uses chunked prefills and prefix caching
    # for non-pooling tasks.
    # For pooling tasks the default is False
    if model_config.runner_type != "pooling":
        self.enable_chunked_prefill = True
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True
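
As a minimal standalone sketch of the defaulting behaviour quoted above (simplified stand-in names, not the real vLLM classes), a non-pooling model that happens to be attention free still gets prefix caching, and therefore enable_caching, turned on by default:

# Sketch only: SketchModelConfig and default_enable_prefix_caching are
# illustrative stand-ins, not vLLM APIs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchModelConfig:
    runner_type: str        # e.g. "generate" or "pooling"
    is_attention_free: bool


def default_enable_prefix_caching(model_config: SketchModelConfig,
                                  requested: Optional[bool]) -> bool:
    # Mirrors the quoted logic: for non-pooling models, an unset value
    # defaults to True, regardless of whether the model has attention.
    if model_config.runner_type != "pooling" and requested is None:
        return True
    return bool(requested)


if __name__ == "__main__":
    mamba_like = SketchModelConfig(runner_type="generate", is_attention_free=True)
    # Prints True, which is why the extra kv_cache_groups check is needed.
    print(default_enable_prefix_caching(mamba_like, None))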

Contributor Author (@christian-pinto, Jul 14, 2025):

Perhaps the safest thing to do is to disable prefix caching in VllmConfig.__post_init__ right away for any attention-free model; then, yes, we could just rely on enable_caching as you suggest.
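
For example, a rough sketch of that idea with simplified stand-in classes (the attribute names here are assumptions, not the actual VllmConfig implementation):

# Sketch only: the Sketch* classes below are illustrative stand-ins.
from dataclasses import dataclass, field


@dataclass
class SketchModelConfig:
    is_attention_free: bool = False


@dataclass
class SketchCacheConfig:
    enable_prefix_caching: bool = True


@dataclass
class SketchVllmConfig:
    model_config: SketchModelConfig = field(default_factory=SketchModelConfig)
    cache_config: SketchCacheConfig = field(default_factory=SketchCacheConfig)

    def __post_init__(self) -> None:
        if self.model_config.is_attention_free:
            # No KV cache means there is nothing to prefix-cache.
            self.cache_config.enable_prefix_caching = False


cfg = SketchVllmConfig(model_config=SketchModelConfig(is_attention_free=True))
assert cfg.cache_config.enable_prefix_caching is False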

Collaborator:
remove this comment?

+        # have 0 kv_cache_groups
         return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
                                                use_eagle, caching_hash_fn,
                                                enable_kv_cache_events)
2 changes: 1 addition & 1 deletion vllm/v1/core/kv_cache_manager.py
@@ -89,7 +89,7 @@ def __init__(
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None

         self.block_size: Optional[int] = None
-        if self.enable_caching:
+        if self.enable_caching and len(kv_cache_config.kv_cache_groups) > 0:
Collaborator:
Same as above: this isn't needed if enable_caching is false for attention-free models.

             assert len(
                 set(g.kv_cache_spec.block_size
                     for g in kv_cache_config.kv_cache_groups)
21 changes: 20 additions & 1 deletion vllm/v1/core/kv_cache_utils.py
@@ -563,6 +563,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         ValueError: If there is not enough memory available for the KV cache.
     """

+    # No need to check for available memory if the model is attention free
+    if vllm_config.model_config.is_attention_free:
+        return
+
     if available_memory <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
@@ -749,6 +753,13 @@ def is_kv_cache_page_size_uniform(
     return len(page_sizes) == 1


+def is_kv_cache_type_attention_free(
+        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
+
+    # kv_cache_spec is an empty dict for attention free models
+    return not kv_cache_spec
+
+
 def _get_kv_cache_config_uniform_page_size(
         vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
         available_memory: int) -> KVCacheConfig:
@@ -891,6 +902,10 @@ def _get_kv_cache_config_uniform_page_size(
     return kv_cache_config


+def _get_kv_cache_config_attention_free() -> KVCacheConfig:
+    return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])
+
+
 def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
     """
     This function tries to convert the KV cache specs to one type if the model
@@ -957,7 +972,11 @@ def get_kv_cache_config(
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)

-    if is_kv_cache_type_uniform(kv_cache_spec):
+    if is_kv_cache_type_attention_free(kv_cache_spec):
+        # This returns a kv_cache config with 0 kv_cache groups and 1 block
+        # to allow the KVCache manager to handle attention free models.
+        return _get_kv_cache_config_attention_free()
+    elif is_kv_cache_type_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
         # each layer.
2 changes: 2 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2590,6 +2590,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             KVCacheSpec: A dictionary mapping layer names to their KV cache
             format. Layers that do not need KV cache are not included.
         """
+        if self.vllm_config.model_config.is_attention_free:
+            return {}

         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
3 changes: 3 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -209,6 +209,9 @@ def determine_available_memory(self) -> int:
         You may limit the usage of GPU memory
         by adjusting the `gpu_memory_utilization` parameter.
         """
+        if self.vllm_config.model_config.is_attention_free:
+            return 0
+
Collaborator:
remove this?

         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
         GiB = lambda b: b / GiB_bytes
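
Taken together, the changes above can be summarised with a hedged, self-contained sketch (simplified stand-in types; only the identifiers that appear in the diff are taken from vLLM): an empty KV cache spec produces a config with one placeholder block and zero kv_cache_groups, which in turn selects the coordinator without prefix caching.

# Sketch only: Sketch* types and sketch_* helpers are illustrative stand-ins,
# not the real vLLM classes.
from dataclasses import dataclass, field


@dataclass
class SketchKVCacheConfig:
    num_blocks: int
    kv_cache_tensors: list = field(default_factory=list)
    kv_cache_groups: list = field(default_factory=list)


def sketch_get_kv_cache_config(kv_cache_spec: dict) -> SketchKVCacheConfig:
    # An empty spec marks the model as attention free (see
    # is_kv_cache_type_attention_free above): one placeholder block, no groups.
    if not kv_cache_spec:
        return SketchKVCacheConfig(num_blocks=1)
    raise NotImplementedError("only the attention-free branch is sketched")


def sketch_pick_coordinator(cfg: SketchKVCacheConfig, enable_caching: bool) -> str:
    # Mirrors the condition added to get_kv_cache_coordinator above.
    if not enable_caching or len(cfg.kv_cache_groups) == 0:
        return "KVCacheCoordinatorNoPrefixCache"
    return "a prefix-caching coordinator (not sketched)"


cfg = sketch_get_kv_cache_config({})  # attention-free model: empty spec
print(cfg.num_blocks, len(cfg.kv_cache_groups))       # 1 0
print(sketch_pick_coordinator(cfg, enable_caching=True))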