[Attention] Refactor attention metadata builder interface #20466


Open · LucasWilkinson wants to merge 46 commits into main from lwilkinson/attn-refactor

Conversation

@LucasWilkinson (Collaborator) commented on Jul 4, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Refactor the attention metadata builders to break their dependency on the runner. This has a few distinct advantages:

  1. Easier to unit-test (no need to mock an entire GPU runner).
  2. Easier to benchmark (no need to mock an entire GPU runner).
  3. Some attention schemes are fairly generic and can be implemented by manipulating a sufficiently descriptive CommonAttentionMetadata, saving the metadata builders from having to re-implement them, i.e. making them backend-agnostic (e.g. local attention and the microbatch slicing in #18415).
  4. Allows for more optimized metadata building; e.g. FlashInfer prefers host tensors in some cases (the integration notes "# FlashInfer 0.2 encourages passing host tensors"), and we have noticed D2H transfers in the Blackwell V1 FlashInfer integration that could be avoided with the new CommonAttentionMetadata since it carries host copies of these tensors.

NOTE: Mocking the vllm_config can still be cumbersome, so there is more work that can be done here, but this PR is a good first step.
NOTE: This PR is primarily about creating a backend-agnostic interface for "Microbatch slicing" (#18415).

NOTE: This changes how the attention metadata is constructed for speculative decoding (namely EAGLE and MTP); more of the input preparation is now done host-side and then copied to the device.
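
To make the new interface concrete, below is a minimal sketch of the direction described above. The field and method names are paraphrased from this PR's description and review discussion and are assumptions rather than the exact vLLM definitions:

```python
# Illustrative sketch only -- names/shapes are assumptions, not the exact
# definitions introduced by this PR.
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch


@dataclass
class CommonAttentionMetadata:
    """Backend-agnostic per-step inputs, with host copies alongside the
    device tensors so builders (e.g. FlashInfer) can avoid D2H transfers."""
    query_start_loc: torch.Tensor       # (num_reqs + 1,) on device
    query_start_loc_cpu: torch.Tensor   # host copy
    seq_lens: torch.Tensor              # (num_reqs,) on device
    seq_lens_cpu: torch.Tensor          # host copy
    block_table_tensor: torch.Tensor    # block table for this KV-cache group
    slot_mapping: torch.Tensor          # (num_actual_tokens,)
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


M = TypeVar("M")  # backend-specific attention metadata type


class AttentionMetadataBuilder(Generic[M]):
    """Builds backend-specific metadata from CommonAttentionMetadata alone,
    with no reference back to the GPU model runner."""

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata) -> M:
        raise NotImplementedError
```

Because build() only sees this structure, builders can be unit-tested and benchmarked by constructing a CommonAttentionMetadata directly instead of mocking a GPU model runner.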

Test Plan

  • Test correctness with all attention backends
  • Test that we do not regress Eagle Perf
  • Test MTP (using python -m pytest tests/spec_decode/e2e/test_mtp_correctness.py)

Test Result


lm_eval checks

H100 - GSM8K (meta-llama/Meta-Llama-3-8B-Instruct)

| Backend | Main Branch | Attn-Refactor |
|---|---|---|
| FlashAttention + EAGLE3 [1] | 0.7551 / 0.7566 | 0.7551 / 0.7566 |
| FlashAttention [2] | 0.7544 / 0.7559 | 0.7544 / 0.7559 |
| FlashInfer [3] | 0.7544 / 0.7559 | 0.7544 / 0.7559 |
| Triton [4] | 0.7566 / 0.7597 | 0.7582 / 0.7604 |
| FlexAttention [5] | 0.7036 / 0.7043 | 0.7028 / 0.7013 |

H100 - MLA Backends - GSM8K (deepseek-ai/DeepSeek-V2-Lite-Chat)

| Backend | Main Branch | Attn-Refactor |
|---|---|---|
| FlashMLA [6] | 0.6657 / 0.6581 | 0.6679 / 0.6581 |
| TritonMLA [7] | 0.6649 / 0.6581 | 0.6702 / 0.6619 |

H100 - Mamba Backends - HumanEval (mistralai/Mamba-Codestral-7B-v0.1)

| Task | Main Branch | Attn-Refactor |
|---|---|---|
| Mamba HumanEval [8] | 0.4024 | 0.4024 |

ROCm MI300X - GSM8K (meta-llama/Meta-Llama-3-8B-Instruct)

| Backend | Attn-Refactor |
|---|---|
| Triton [9] | 0.7521 / 0.7544 |
| AITER Flash Attention [10] | 0.7528 / 0.7544 |
| FlashAttention [11] | 0.7521 / 0.7544 |

ROCm MI300X - MLA Backends - GSM8K (deepseek-ai/DeepSeek-V2-Lite-Chat)

| Backend | Attn-Refactor |
|---|---|
| TritonMLA [12] | 0.6641 / 0.6535 |
| AITER MLA [13] | 0.6550 / 0.6482 |

Eagle Perf Regression Check

Test Environment

  • CPU: AMD EPYC 9654 96-Core Processor
  • GPU: NVIDIA H100 80GB HBM3
  • Model: meta-llama/Meta-Llama-3-8B-Instruct with EAGLE3 Speculative Decoding
  • Input/Output: 1024/128 tokens, 3 iterations per configuration

Commands Used

Server Command:

```bash
canhazgpu run -g 1 vllm serve \
    meta-llama/Meta-Llama-3-8B-Instruct \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --speculative-config '{
        "model":"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens":3,
        "method":"eagle3",
        "draft_tensor_parallel_size":1
    }' \
    --port 8000 \
    --trust-remote-code
```

Benchmark Command:

```bash
vllm bench serve \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --dataset-name random \
    --num-prompts 100 \
    --request-rate {1,5,10,20} \
    --port 8000 \
    --endpoint-type vllm \
    --trust-remote-code \
    --disable-tqdm
```

Results Summary

| Request Rate | Branch | Throughput (req/s) | TPOT (ms/token) |
|---|---|---|---|
| 1 req/s | attn-refactor | 0.92 ± 0.000 | 8.28 ± 0.025 |
| 1 req/s | main | 0.92 ± 0.000 | 8.44 ± 0.102 |
| 5 req/s | attn-refactor | 4.47 ± 0.003 | 8.43 ± 0.026 |
| 5 req/s | main | 4.46 ± 0.006 | 8.78 ± 0.017 |
| 10 req/s | attn-refactor | 8.42 ± 0.012 | 9.06 ± 0.147 |
| 10 req/s | main | 8.41 ± 0.009 | 9.28 ± 0.116 |
| 20 req/s | attn-refactor | 14.89 ± 0.012 | 10.37 ± 0.114 |
| 20 req/s | main | 14.84 ± 0.041 | 10.59 ± 0.101 |

MTP

E2E Test (4xB200, R1)

```bash
vllm serve deepseek-ai/DeepSeek-R1 --host 127.0.0.1 --trust-remote-code --tensor-parallel-size 4 --enable-reasoning --reasoning-parser deepseek_r1 --gpu-memory-utilization 0.97 --disable-log-requests --max-model-len 65536 --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}'
lm_eval --model local-completions --model_args "base_url=http://localhost:8000/v1/completions,model=deepseek-ai/DeepSeek-R1,num_concurrent=1" --tasks gsm8k --num_fewshot 5 --batch_size 1
```

Confirmed V1 engine (startup log):

INFO 07-08 21:48:32 [core.py:69] Initializing a V1 LLM engine (v0.9.2rc2.dev61+ge95532bd1) with config: model='deepseek-ai/DeepSeek-R1', speculative_config=SpeculativeConfig(method='deepseek_mtp', model='deepseek-ai/DeepSeek-R1', num_spec_tokens=1), tokenizer='deepseek-ai/DeepSeek-R1', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=65536, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend='deepseek_r1'), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=deepseek-ai/DeepSeek-R1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
| Metric | Main Branch | attn-refactor Branch |
|---|---|---|
| GSM8K Flexible-extract | 95.30% ± 0.58% | 95.30% ± 0.58% |
| GSM8K Strict-match | 95.15% ± 0.59% | 95.15% ± 0.59% |
| Draft Acceptance Rate | 88.4% - 97.6% | 90.9% - 95.0% |

(Optional) Documentation Update

Footnotes (commands)

  1. VLLM_ATTENTION_BACKEND=FLASH_ATTN lm_eval \
      --model vllm \
      --model_args '{
        "pretrained": "meta-llama/Meta-Llama-3-8B-Instruct",
        "speculative_config": {
          "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
          "num_speculative_tokens": 3,
          "method": "eagle3",
          "draft_tensor_parallel_size": 1
        }
      }' \
      --tasks gsm8k --batch_size auto
    
  2. VLLM_ATTENTION_BACKEND=FLASH_ATTN lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  3. VLLM_ATTENTION_BACKEND=FLASHINFER_VLLM_V1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  4. VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  5. VLLM_ATTENTION_BACKEND=FLEX_ATTENTION lm_eval \
      --model vllm \
      --model_args '{
        "pretrained": "meta-llama/Meta-Llama-3-8B-Instruct",
        "gpu_memory_utilization": 0.8
      }' \
      --tasks gsm8k --batch_size auto
    
  6. VLLM_ATTENTION_BACKEND=FLASHMLA lm_eval \
      --model vllm \
      --model_args '{
        "pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat",
        "trust_remote_code": true
      }' \
      --tasks gsm8k --batch_size auto
    
  7. VLLM_ATTENTION_BACKEND=TRITON_MLA lm_eval \
      --model vllm \
      --model_args '{
        "pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat",
        "trust_remote_code": true
      }' \
      --tasks gsm8k --batch_size auto
    
  8. HF_ALLOW_CODE_EVAL=1 VLLM_USE_V1=1 lm_eval \
      --model vllm \
      --model_args '{
        "pretrained": "mistralai/Mamba-Codestral-7B-v0.1",
        "enable_prefix_caching": false,
        "enforce_eager": true
      }' \
      --tasks humaneval --batch_size auto \
      --confirm_run_unsafe_code
    
  9. VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  10. VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1 \
    VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MHA=1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  11. VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "meta-llama/Meta-Llama-3-8B-Instruct"}' \
      --tasks gsm8k --batch_size auto
    
  12. VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_MLA lm_eval \
      --model vllm \
      --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true}' \
      --tasks gsm8k --batch_size auto
    
  13. VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=ROCM_AITER_MLA \
    VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_MLA=1 lm_eval \
      --model vllm \
      --model_args '{"pretrained": "deepseek-ai/DeepSeek-V2-Lite-Chat", "trust_remote_code": true, "block_size": 1}' \
      --tasks gsm8k --batch_size auto
    


github-actions bot commented Jul 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which executes a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added performance Performance-related issues rocm Related to AMD ROCm speculative-decoding v1 labels Jul 4, 2025

mergify bot commented Jul 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 4, 2025
@gemini-code-assist bot left a comment

Summary of Changes

Hello @LucasWilkinson, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on refactoring the attention metadata builder interface to enhance testability, benchmarkability, and code reusability. It introduces a common metadata structure and includes benchmarking and testing scripts to ensure performance and correctness of the changes. The goal is to provide a more flexible and optimized foundation for attention mechanisms within the vLLM framework.

Highlights

  • Refactor Attention Metadata Builder Interface: This PR refactors the attention metadata builder interface to remove the dependency on the runner object. This change aims to improve unit testing, benchmarking, and code reusability.
  • Introduce CommonAttentionMetadata: Introduces a CommonAttentionMetadata class to provide a backend-agnostic interface for attention metadata, facilitating optimizations and generic implementations of attention schemes.
  • Add Benchmarking Script: Adds a benchmarking script (benchmark_v1_backends.py) to evaluate the performance of different attention backends with various batch configurations.
  • Add Test Cases: Adds test cases (test_attention_backends.py) to verify the correctness of attention backends against a reference implementation using scaled_dot_product_attention.
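
The reference check mentioned in the last highlight can be sketched roughly as follows (hypothetical names; not the PR's exact test code, though the tolerances mirror the values visible in the review below):

```python
# Rough sketch of a correctness check against PyTorch SDPA (assumed names).
import torch
import torch.nn.functional as F


def reference_attention(q, k, v, causal=True):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)


def check_backend_output(backend_out, q, k, v, backend_name="flash_attention"):
    atol = 5e-3
    if backend_name == "flex_attention":
        atol = 5e-1  # FlexAttention shows larger deviations (see review note)
    ref = reference_attention(q, k, v)
    torch.testing.assert_close(backend_out, ref, atol=atol, rtol=1e-2)
```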

@gemini-code-assist bot left a comment

Code Review

The code changes introduce a refactoring of the attention metadata builder interface to remove the dependency on GPUModelRunner. The changes improve testability and modularity. I've identified a few critical issues and areas for improvement, particularly in a new utility function and the benchmark script.


```python
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
```

Collaborator:

I'm doubtful about whether block_table_tensor and slot_mapping should be put into common_attn_metadata.

For models with sliding window + full attention, the block_table_tensor for the two types of layers is different, as they need different numbers of slots. These two types of layers are put into different kv_cache_groups and thus now have different attention backends and different BlockTables.

@LucasWilkinson (Collaborator, Author) commented on Jul 4, 2025:

Ya, this PR kinda redefines CommonAttentionMetadata from "common inputs across KV-cache groups" to a "common interface for AttentionMetadataBuilder.build that we implement backend-agnostic attention schemes/features on". But as a result there will be a different CommonAttentionMetadata for each KV-cache group (this seems fine since they all reference the same underlying tensors anyways).

Basically, the idea is that AttentionMetadataBuilder.build transforms CommonAttentionMetadata to a backend-specific metadata and CommonAttentionMetadata should be the minimal amount of things to do this.

The reason for adding slot_mapping and block_table_tensor to the CommonAttentionMetadata is two-fold:

  1. [This is the main motivation] There are some attention-related things that manipulate the block_table_tensor and/or slot_mapping in a way that is common across backends; e.g. local attention:

```python
def make_local_attention_virtual_batches(
    attn_chunk_size: int,
    query_start_loc_np: np.ndarray,
    seq_lens_np: np.ndarray,
    block_table: torch.Tensor,
    block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
    actual_batch_size = seq_lens_np.shape[0]
    # Handle if we are starting in the middle of a local attention block,
    # we assume q_seqlens > 0 (for all elements), for each batch idx we compute
    # the number of tokens that are not in the first local attention block and
    # then we can simply use a cdiv for the rest.
    # For example if we have:
    #   attn_chunk_size = 4
    #   q_seqlens = [4, 10, 5]
    #   k_seqlens = [6, 17, 9]
    # Then we would get:
    #   new_tokens_in_first_block = [2, 1, 4]
    #   local_blocks = [2, 4, 2]
    q_tokens_in_first_block = np.minimum(
        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
        q_seqlens).astype(np.int32)
    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
                            attn_chunk_size)
    # Once we know the number of local blocks we can compute the request spans
    # for each batch idx, we can figure out the number of "virtual" requests we
    # have to make,
    # For the above example we would get:
    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
    #
    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
    # (TODO: max a utility to share this code with _prepare_inputs)
    # arange step 1. [2, 4, 2] -> [2, 6, 8]
    cu_num_blocks = np.cumsum(local_blocks)
    virtual_batches = cu_num_blocks[-1]
    # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
    block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
    # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
    arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
    # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
    rarange = np.repeat(local_blocks, local_blocks) - arange - 1
    # Then we can compute the seqlens_q_local, handling the fact that the
    # first and last blocks could be partial
    seqlens_q_local = \
        np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
    # set the first block since this may be a partial block
    seqlens_q_local[arange == 0] = q_tokens_in_first_block
    # set the remaining blocks
    seqlens_q_local[arange > 0] = np.minimum(
        seqlens_q_local - attn_chunk_size * (arange - 1),
        attn_chunk_size)[arange > 0]
    # convert from q_seqlens to cu_seqlens_q
    cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
        .astype(np.int32)
    # compute the seqlens_k_local,
    # basically a full local attention block for all but the last block in each
    # batch
    # For our example this will be:
    #   seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
    seqlens_k_local = np.full(cu_num_blocks[-1],
                              attn_chunk_size,
                              dtype=np.int32)
    seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
    k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
        (rarange * attn_chunk_size +
         np.repeat(tokens_in_last_block, local_blocks))
    # For the example the local attention blocks start at:
    #                          _b0_  _____b1_____  _b2_
    #   k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
    block_starts = k_seqstarts_absolute // block_size
    assert attn_chunk_size % block_size == 0, \
        f"attn_chunk_size {attn_chunk_size} is not " \
        f"divisible by block_size {block_size}"
    pages_per_local_batch = attn_chunk_size // block_size
    # Create a block_table for the local attention blocks
    # For out example if we have a block-table like (assuming block_size=2):
    #   block_table = [
    #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],   < batch 0
    #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],   < batch 1
    #     [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],   < batch 2
    #   ]
    # Then for the local batches we would want a block-table like
    #   block_table_local = [
    #     [ 0,  1 ],  < local-batch 0, (batch 0, starting from k[0])
    #     [ 2,  3 ],  < local-batch 1, (batch 0, starting from k[4])
    #     [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
    #     [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
    #     [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
    #     [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
    #     [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
    #     [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
    #   ]
    block_indices = np.broadcast_to(
        np.arange(pages_per_local_batch, dtype=np.int32),
        (virtual_batches, pages_per_local_batch)) \
        + np.expand_dims(block_starts, axis=1)
    block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
    batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
                              local_blocks * pages_per_local_batch)
    block_table_local = block_table[batch_indices, block_indices]\
        .view(virtual_batches, -1)
    return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
        block_table_local
```
as well as attention slicing for micro-batching (dual-batch overlap, [WIP] Two batch overlap #18415). With this refactor, these could be pulled out of AttentionMetadataBuilder and instead operate on CommonAttentionMetadata before being passed into AttentionMetadataBuilder.build, making these features backend-independent.

  2. Break the dependency on the model runner in the AttentionMetadataBuilders; that dependency makes unit-testing challenging since it's hard to know which parts need to be mocked.

I was thinking about breaking block_table_tensor and slot_mapping into a separate data structure (adding a third input to AttentionMetadataBuilder.build) to preserve the semantics of CommonAttentionMetadata being "common inputs across KV-cache groups", but I'm not convinced this is needed since all the CommonAttentionMetadata instances reference the same underlying tensors anyway.

I could see arguments for passing BlockTable to AttentionMetadataBuilder.build; it's just that it makes things like local attention and microbatch slicing (mentioned in 1) more difficult since instead of just slicing/manipulating the tensors directly, they would have to construct a new sliced/manipulated BlockTable which would be messy.
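
To illustrate point 1 with code: the hypothetical helper below (not code from this PR; it reuses the CommonAttentionMetadata sketch from the description above) shows how a backend-agnostic transform such as microbatch slicing could operate purely on the common metadata before any builder runs:

```python
# Hypothetical sketch, assuming the CommonAttentionMetadata dataclass from the
# earlier sketch is in scope; recomputing max_query_len per slice is omitted.
def split_common_metadata(cm, req_split: int, tok_split: int):
    """Split the batch into two microbatch views at a request boundary.
    `tok_split` must equal cm.query_start_loc_cpu[req_split]."""

    def _slice(r0, r1, t0, t1):
        return CommonAttentionMetadata(
            query_start_loc=cm.query_start_loc[r0:r1 + 1] - t0,
            query_start_loc_cpu=cm.query_start_loc_cpu[r0:r1 + 1] - t0,
            seq_lens=cm.seq_lens[r0:r1],
            seq_lens_cpu=cm.seq_lens_cpu[r0:r1],
            block_table_tensor=cm.block_table_tensor[r0:r1],
            slot_mapping=cm.slot_mapping[t0:t1],
            num_reqs=r1 - r0,
            num_actual_tokens=t1 - t0,
            max_query_len=cm.max_query_len,
        )

    first = _slice(0, req_split, 0, tok_split)
    second = _slice(req_split, cm.num_reqs, tok_split, cm.num_actual_tokens)
    return first, second


# Any backend's builder can then consume the slices unchanged, e.g.:
#   meta_a = builder.build(0, first)
#   meta_b = builder.build(0, second)
```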

Collaborator:

Thanks for the detailed reply. It makes sense to me.

Collaborator:

Would it make sense to have a 2-level structure here, say LayerAttentionMetadata that includes CommonAttentionMetadata as a member, and LayerAttentionMetadata takes the role of CommonAttentionMetadata in this PR while CommonAttentionMetadata maintains its role from before this PR?

@LucasWilkinson (Collaborator, Author):

We could do that for sure; I'm just not convinced there's enough value to justify the extra layer of indirection. It would just mean more '.'s in the metadata builders, e.g. 'common_attn_metadata.query_start_locs' would become 'layer_group_attn_metadata.common_attn_metadata.query_start_locs' (side note: we would probably want to name it LayerGroupAttentionMetadata, since it's common across all layers in the group).
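
For concreteness, a purely illustrative sketch of the two-level structure being discussed (not what this PR implements; it assumes the CommonAttentionMetadata sketch above is in scope, and the names are placeholders):

```python
from dataclasses import dataclass

import torch


@dataclass
class LayerGroupAttentionMetadata:
    # Shared across KV-cache groups (the pre-PR meaning of "common").
    common_attn_metadata: "CommonAttentionMetadata"
    # Per-KV-cache-group, since e.g. sliding-window and full-attention layers
    # need different numbers of slots and therefore different block tables.
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor

# Builders would then write
#   layer_group_attn_metadata.common_attn_metadata.query_start_loc
# which is the extra level of indirection weighed above.
```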

Collaborator:

Yeah, I mean the member could be named common, or we could make it a superclass if we're worried about access. I do still think there's value in separating these concepts, but I'll defer to you for the final decision.

@LucasWilkinson (Collaborator, Author):

Opinion noted; I don't totally disagree, but for the simplicity of the (upcoming) backend-agnostic local-attention and ubatch-slicing interfaces, which return new CommonAttentionMetadata, I'll leave it the way it is (since those will modify the "common" part anyways).

I think now that the hard part is done we can always refactor this in the future with a simple find-and-replace if we come across a compelling use case for this separation.

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/attn-refactor branch from e3e1a8f to 7f0d422 Compare July 7, 2025 05:23
@mergify mergify bot removed the needs-rebase label Jul 7, 2025
@LucasWilkinson LucasWilkinson changed the title [WIP][Attention] Refactor attention metadata builder interface [Attention] Refactor attention metadata builder interface Jul 9, 2025
@LucasWilkinson LucasWilkinson marked this pull request as ready for review July 9, 2025 04:29
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/attn-refactor branch from 84b2e0e to e796669 Compare July 9, 2025 05:22
@SageMoore (Contributor) left a comment:

Looks great, @LucasWilkinson. I'm not ramped up enough on spec decode to have an opinion there, but the attention backend, block table, and gpu model runner changes all look reasonable to me.

@ProExpertProg (Collaborator) left a comment:

Really nice refactor, and it certainly looks painful in places, thanks for doing this! Just had a few nits/comments.



```python
atol = 5e-3

if backend_name == "flex_attention":
    atol = 5e-1  # TODO: figuure out why flex_attention has such large
```

@ProExpertProg (Collaborator) left a comment:

Thanks for addressing all of the comments!

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 10, 2025

mergify bot commented Jul 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 11, 2025
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/attn-refactor branch from 38f7c4e to 4c56bb0 Compare July 11, 2025 03:45
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/attn-refactor branch from 6605f5d to f6b4d45 Compare July 16, 2025 02:29
@mergify mergify bot removed the needs-rebase label Jul 16, 2025