[v1][core] Support for attention free models #20811

Merged
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -4710,6 +4710,15 @@ def __post_init__(self):
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")

if self.model_config.is_attention_free:
# If the model is not of pooling type and it is attention free,
# we make sure chunked prefill and prefix_caching are
# disabled so that the correct KVCacheCoordinator
# is loaded.
disable_chunked_prefill_reasons.append(
"This is an attention free model, "
"disabling chunked prefill and prefix caching.")

if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons:
logger.info(reason)
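
(Illustration, not part of the diff: a toy sketch of the effect the new comment describes. The assumption, taken from the kv_cache_coordinator.py change below, is that enable_caching=False is what routes attention-free models to the no-prefix-cache coordinator; the helper function here is hypothetical.)

    # Toy sketch; pick_coordinator_name is hypothetical, the class name mirrors this PR.
    def pick_coordinator_name(enable_caching: bool) -> str:
        if not enable_caching:
            # Attention-free models end up here: 0 kv_cache_groups, no prefix cache.
            return "KVCacheCoordinatorNoPrefixCache"
        return "prefix-caching coordinator"

    assert pick_coordinator_name(enable_caching=False) == "KVCacheCoordinatorNoPrefixCache"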
2 changes: 2 additions & 0 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -389,6 +389,8 @@ def get_kv_cache_coordinator(
enable_caching: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching:
# We instantiate this coordinator also for attention free models that
Collaborator:

Do we need this, given that prefix caching is disabled here for models that don't use the "last" pooling method?
https://github.com/maxdebayser/vllm/blob/221f013922c0c118b682d294755e69990b2c43ed/vllm/config.py#L4505

Contributor Author:

Without this check, though, you would not be able to run attention-free models that are not of the pooling type, as prefix caching is enabled by default for all models except pooling ones.

See below:

vllm/vllm/engine/arg_utils.py

Lines 1620 to 1630 in 38efa28

def _set_default_args_v1(self, usage_context: UsageContext,
                         model_config: ModelConfig) -> None:
    """Set Default Arguments for V1 Engine."""
    # V1 always uses chunked prefills and prefix caching
    # for non-pooling tasks.
    # For pooling tasks the default is False
    if model_config.runner_type != "pooling":
        self.enable_chunked_prefill = True
        if self.enable_prefix_caching is None:
            self.enable_prefix_caching = True
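
A small, self-contained illustration of that default (FakeModelConfig is a hypothetical stand-in for ModelConfig; the branch logic copies the snippet above):

    # Hypothetical stand-in: attention free, but not a pooling model.
    class FakeModelConfig:
        runner_type = "generate"
        is_attention_free = True

    enable_chunked_prefill = None
    enable_prefix_caching = None

    # Same default logic as _set_default_args_v1 above.
    if FakeModelConfig.runner_type != "pooling":
        enable_chunked_prefill = True
        if enable_prefix_caching is None:
            enable_prefix_caching = True

    # Without an attention-free guard somewhere, prefix caching is now on
    # for a model that has no KV cache at all.
    assert enable_prefix_caching is True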

Contributor Author (@christian-pinto, Jul 14, 2025):

Perhaps the safest thing to do is to disable prefix caching in VllmConfig.__post_init__ right away for any attention-free model, and then, yes, we could just rely on enable_caching as you suggest.
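
A rough sketch of that proposal, assuming the config attributes are reachable as scheduler_config.enable_chunked_prefill and cache_config.enable_prefix_caching; this is the idea, not the final patch:

    # Sketch of a possible guard inside VllmConfig.__post_init__:
    if self.model_config.is_attention_free:
        # No KV cache, so neither chunked prefill nor prefix caching applies.
        self.scheduler_config.enable_chunked_prefill = False
        self.cache_config.enable_prefix_caching = False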

Collaborator:

remove this comment?

# have 0 kv_cache_groups
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
use_eagle, caching_hash_fn,
enable_kv_cache_events)
6 changes: 4 additions & 2 deletions vllm/v1/core/kv_cache_manager.py
@@ -78,7 +78,9 @@ def __init__(
) -> None:
self.max_model_len = max_model_len

self.enable_caching = enable_caching
self.enable_caching = (enable_caching
if len(kv_cache_config.kv_cache_groups) > 0
else False)
Collaborator:

Suggested change: replace

    self.enable_caching = (enable_caching
                           if len(kv_cache_config.kv_cache_groups) > 0
                           else False)

with:

    if len(kv_cache_config.kv_cache_groups) == 0:
        # Attention free models don't have kv cache, thus don't need prefix caching.
        enable_caching = False
    self.enable_caching = enable_caching

I think this structure is clearer for readers.

self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
@@ -101,7 +103,7 @@ def __init__(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=enable_caching,
enable_caching=self.enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
21 changes: 20 additions & 1 deletion vllm/v1/core/kv_cache_utils.py
@@ -563,6 +563,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
ValueError: If there is not enough memory available for the KV cache.
"""

# No need to check for available memory if the kv_cache_spec is empty
if not kv_cache_spec:
return

if available_memory <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
@@ -749,6 +753,13 @@ def is_kv_cache_page_size_uniform(
return len(page_sizes) == 1


def is_kv_cache_type_attention_free(
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:

# kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec


def _get_kv_cache_config_uniform_page_size(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
@@ -891,6 +902,10 @@ def _get_kv_cache_config_uniform_page_size(
return kv_cache_config


def _get_kv_cache_config_attention_free() -> KVCacheConfig:
return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])


def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
This function tries to convert the KV cache specs to one type if the model
@@ -957,7 +972,11 @@ def get_kv_cache_config(
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

if is_kv_cache_type_uniform(kv_cache_spec):
if is_kv_cache_type_attention_free(kv_cache_spec):
# This returns a kv_cache config with 0 kv_cache groups and 1 block
# to allow for the KVCache manager to handle attention free models.
return _get_kv_cache_config_attention_free()
elif is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
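
Taken together, a minimal illustration of the attention-free path, assuming the two helpers added in this file are importable:

    # An attention-free model reports an empty spec dict.
    kv_cache_spec = {}
    assert is_kv_cache_type_attention_free(kv_cache_spec)

    # The resulting config has a single placeholder block and nothing to size.
    cfg = _get_kv_cache_config_attention_free()
    assert cfg.num_blocks == 1
    assert cfg.kv_cache_tensors == [] and cfg.kv_cache_groups == []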
6 changes: 5 additions & 1 deletion vllm/v1/engine/core.py
@@ -139,7 +139,11 @@ def _initialize_kv_caches(

# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
available_gpu_memory = self.model_executor.determine_available_memory()
check_available_memory = not(len(kv_cache_specs) == 1 and not kv_cache_specs[0])
available_gpu_memory = [0]
if check_available_memory:
available_gpu_memory = (
self.model_executor.determine_available_memory())
Collaborator:

Suggested change: replace

    check_available_memory = not(len(kv_cache_specs) == 1 and not kv_cache_specs[0])
    available_gpu_memory = [0]
    if check_available_memory:
        available_gpu_memory = (
            self.model_executor.determine_available_memory())

with:

    has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
    if has_kv_cache:
        available_gpu_memory = self.model_executor.determine_available_memory()
    else:
        # Attention free models don't need memory for kv cache
        available_gpu_memory = [0] * len(kv_cache_specs)

I feel that the condition is not correct. len(kv_cache_specs) can be larger than 1 when TP / PP is enabled.
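
A tiny illustration of that concern, under the assumption that kv_cache_specs holds one spec dict per worker: with TP=2 an attention-free model reports two empty dicts, so the len(...) == 1 test misses it while the any(...) form does not.

    kv_cache_specs = [{}, {}]  # e.g. TP=2, attention free: two empty specs

    old_check = not (len(kv_cache_specs) == 1 and not kv_cache_specs[0])
    has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)

    assert old_check is True      # old condition would still profile GPU memory
    assert has_kv_cache is False  # suggested condition correctly skips profiling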


assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -209,6 +209,7 @@ def determine_available_memory(self) -> int:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""

Collaborator:

remove this?

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
GiB = lambda b: b / GiB_bytes