Skip to content

[Attention] Make local attention backend agnostic #21093

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Jul 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Make local attention backend agnostic now that #20466 has landed so we can turn on llama4 iRoPE for FlashInfer on Blackwell

Test Plan

Ruler eval

Test Result

This PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m lm_eval --model vllm --model_args pretrained=/home/lwilkinson/local_models/meta-llama--Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=2,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=16384 --tasks ruler --limit 1 --batch_size auto --output_path ./test_fixed_flashinfer.json
...
|Groups|Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------|------:|------|------|-----:|---|-----:|---|------|
|ruler |      1|none  |      |  4096|↑  |0.9231|±  |   N/A|
VLLM_ATTENTION_BACKEND=FLASH_ATTN python -m lm_eval --model vllm --model_args pretrained=/home/lwilkinson/local_models/meta-llama--Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=2,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=16384 --tasks ruler --limit 100 --batch_size auto --outp
ut_path ./test_fixed_flashinfer.json
...
|Groups|Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------|------:|------|------|-----:|---|-----:|---|------|
|ruler |      1|none  |      |  4096|↑  |0.9527|±  |   N/A|

Main

VLLM_ATTENTION_BACKEND=FLASHINFER python -m lm_eval --model vllm --model_args pretrained=/home/lwilkinson/local_models/meta-llama--Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=2,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=16384 --tasks ruler --limit 100 --batch_size auto --output_path ./test_fixed_flashinfer.json
...
|Groups|Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------|------:|------|------|-----:|---|-----:|---|------|
|ruler |      1|none  |      |  4096|↑  |0.6024|±  |   N/A|
VLLM_ATTENTION_BACKEND=FLASH_ATTN python -m lm_eval --model vllm --model_args pretrained=/home/lwilkinson/local_models/meta-llama--Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=2,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=16384 --tasks ruler --limit 100 --batch_size auto --output_path ./test_fixed_flashinfer.json
...
|Groups|Version|Filter|n-shot|Metric|   |Value |   |Stderr|
|------|------:|------|------|-----:|---|-----:|---|------|
|ruler |      1|none  |      |  4096|↑  |0.9527|±  |   N/A|

(Optional) Documentation Update

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

fix

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 17, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the local attention mechanism to be backend-agnostic, which is a great improvement for code maintainability and extensibility. The approach of centralizing the virtual batch creation logic in vllm/v1/attention/backends/utils.py and introducing a ChunkedLocalAttentionSpec is well-designed.

I've found one critical issue in the implementation of the new make_local_attention_virtual_batches function where the returned CommonAttentionMetadata has inconsistent dimensions, which would likely cause runtime errors or incorrect attention calculations. I've provided a suggestion to fix this. Once that's addressed, the changes should be in good shape.

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@LucasWilkinson LucasWilkinson marked this pull request as ready for review July 17, 2025 12:42
@LucasWilkinson LucasWilkinson added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 17, 2025
@LucasWilkinson LucasWilkinson requested a review from mgoin July 17, 2025 17:42
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What a satisfying clean up! LGTM

@mgoin mgoin merged commit 89cab4d into vllm-project:main Jul 18, 2025
80 checks passed
WorldExplored pushed a commit to nadathurv/vllm that referenced this pull request Jul 19, 2025
Signed-off-by: WorldExplored <srreyansh.sethi@gmail.com>
hj-mistral pushed a commit to hj-mistral/vllm that referenced this pull request Jul 19, 2025
ChunkedLocalAttentionSpec):
common_attn_metadata = make_local_attention_virtual_batches(
kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
common_attn_metadata, self.cache_config.block_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucasWilkinson Hybrid kv cache manager is not compatible with kv connectors and kv events. When these things are enabled, we'll fall back to one kv cache group with FullAttentionSpec, and mark FullAttentionSpec.attention_chunk_size. I think there will be some bug here. You can test it by launching model with --disable-hybrid-kv-cache-manager

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the heads up! Back in office tomorrow; will take a look 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants