[Attention] Refactor attention metadata builder interface #20466

Merged: 46 commits, Jul 17, 2025

Commits (46)
85a3cf0
refactor backends wip
LucasWilkinson Jul 4, 2025
bdcde66
move slot table computation into block_table
LucasWilkinson Jul 4, 2025
92e4539
eagle passing
LucasWilkinson Jul 4, 2025
73319c8
optimize eagle
LucasWilkinson Jul 5, 2025
12730dd
correctness
LucasWilkinson Jul 5, 2025
45b368f
fast build
LucasWilkinson Jul 6, 2025
4925972
cleanup
LucasWilkinson Jul 6, 2025
bbbf537
cleanup
LucasWilkinson Jul 7, 2025
e5d8d51
cleanup
LucasWilkinson Jul 7, 2025
90f076a
refactor triton
LucasWilkinson Jul 7, 2025
5d79a0e
more refactors
LucasWilkinson Jul 7, 2025
e8ca38c
fix flex attention warning
LucasWilkinson Jul 7, 2025
bdf8971
cleanup
LucasWilkinson Jul 7, 2025
62f6ead
cleanup eagle tests
LucasWilkinson Jul 7, 2025
7cdda4d
refactors
LucasWilkinson Jul 7, 2025
02e1c56
fa passing
LucasWilkinson Jul 8, 2025
c32445d
first pass backend tests working
LucasWilkinson Jul 8, 2025
4f3c10f
get tests to pass
LucasWilkinson Jul 8, 2025
373caaf
punt benchmark to followup pr
LucasWilkinson Jul 8, 2025
9f246d1
minor cleanups
LucasWilkinson Jul 9, 2025
63dbfe1
revert cpu metadata refactor
LucasWilkinson Jul 9, 2025
e879dec
refactor cpu_attn
LucasWilkinson Jul 9, 2025
2df935a
review comments
LucasWilkinson Jul 9, 2025
5a62e1c
review comments
LucasWilkinson Jul 9, 2025
d269709
review comments
LucasWilkinson Jul 9, 2025
fa4e470
get triton tests to pass
LucasWilkinson Jul 9, 2025
1ac41f1
review comments
LucasWilkinson Jul 9, 2025
a6399f5
update
LucasWilkinson Jul 11, 2025
4c76fa9
fix rebase error
LucasWilkinson Jul 11, 2025
7365687
more lint fixes
LucasWilkinson Jul 11, 2025
6cdd4c5
fix rebase error
LucasWilkinson Jul 11, 2025
9a020c5
fix rebase error
LucasWilkinson Jul 11, 2025
fcdc674
review comments
LucasWilkinson Jul 11, 2025
807de5e
undo format
LucasWilkinson Jul 11, 2025
c76e3da
fix pre-commit
LucasWilkinson Jul 11, 2025
6699fc8
remove unrelated format
LucasWilkinson Jul 14, 2025
b72b323
review comments
LucasWilkinson Jul 14, 2025
f89ac61
gate pinned memory
LucasWilkinson Jul 14, 2025
8778441
fix estimate
LucasWilkinson Jul 14, 2025
a1070d2
remove duplicates
LucasWilkinson Jul 15, 2025
32132f0
format
LucasWilkinson Jul 15, 2025
1291d2f
format
LucasWilkinson Jul 15, 2025
34d3386
fix FI prefill
LucasWilkinson Jul 16, 2025
b2dac6a
format
LucasWilkinson Jul 16, 2025
026cef6
format2
LucasWilkinson Jul 16, 2025
0748f22
minor cleanup
LucasWilkinson Jul 16, 2025
466 changes: 466 additions & 0 deletions tests/v1/attention/test_attention_backends.py

Large diffs are not rendered by default.

229 changes: 229 additions & 0 deletions tests/v1/attention/utils.py
@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
from typing import Union

import pytest
import torch

from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                         LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Specification for a batch configuration (workload shape only)."""
    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    @property
    def batch_size(self):
        return len(self.seq_lens)

    def __post_init__(self):
        assert len(self.seq_lens) == len(self.query_lens)

    def compute_num_tokens(self):
        return sum(self.query_lens)


def create_common_attn_metadata(
        batch_spec: BatchSpec,
        block_size: int,
        device: torch.device,
        max_block_idx: int = 1000) -> CommonAttentionMetadata:
    """Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
    # Create query start locations
    query_start_loc = torch.zeros(batch_spec.batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
                                       dtype=torch.int32,
                                       device=device).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = batch_spec.compute_num_tokens()

    # Create sequence lengths
    seq_lens = torch.tensor(batch_spec.seq_lens,
                            dtype=torch.int32,
                            device=device)
    seq_lens_cpu = seq_lens.cpu()

    # Create computed tokens (context length for each sequence)
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Create block table (random for testing)
    max_blocks = max(batch_spec.seq_lens) // block_size + 1
    block_table_tensor = torch.randint(0,
                                       max_block_idx,
                                       (batch_spec.batch_size, max_blocks),
                                       dtype=torch.int32,
                                       device=device)

    # Create slot mapping
    slot_mapping = torch.randint(0,
                                 max_block_idx, (num_tokens, ),
                                 dtype=torch.int64,
                                 device=device)

    # Calculate max query length
    max_query_len = max(batch_spec.query_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )


def get_attention_backend(backend_name: _Backend):
    """Set up attention backend classes for testing.

    Args:
        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)

    Returns:
        Tuple of (backend_builder_class, backend_impl_class)
    """
    backend_map = {
        _Backend.FLASH_ATTN_VLLM_V1:
        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
        _Backend.FLASHINFER_VLLM_V1:
        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION:
        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN_VLLM_V1:
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
    }

    if backend_name not in backend_map:
        raise ValueError(f"Unknown backend: {backend_name}")

    backend_class_name = backend_map[backend_name]

    try:
        backend_class = resolve_obj_by_qualname(backend_class_name)
        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
    except ImportError as e:
        pytest.skip(f"{backend_name} not available: {e}")


def create_standard_kv_cache_spec(
        vllm_config: VllmConfig) -> FullAttentionSpec:
    """Create a FullAttentionSpec from ModelParams only."""
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config),
        head_size=vllm_config.model_config.get_head_size(),
        dtype=vllm_config.model_config.dtype,
        use_mla=vllm_config.model_config.use_mla,
        sliding_window=vllm_config.model_config.get_sliding_window(),
    )


def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
                       tensor_parallel_size: int = 1,
                       max_model_len: int = 1024,
                       dtype: Union[ModelDType, torch.dtype] = "auto",
                       block_size: int = 16,
                       max_num_seqs: int = 256,
                       max_num_batched_tokens: int = 8192,
                       add_mock_model_methods: bool = True) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults."""

    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    # (these may be set during initialization normally)
    cache_config.num_gpu_blocks = 1000
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size, )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Add mock methods to satisfy backends that need them
        # This is a workaround because tests don't build full, real models,
        # but some backends expect to query the model for layer-specific
        # parameters
        import types
        model_config.get_num_layers = types.MethodType(lambda self: 1,
                                                       model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config)
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config)
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / model_config.get_head_size()**0.5,
            model_config)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


def create_dummy_kv_cache(block_size: int,
                          num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing."""
    kv_cache = torch.randn(
        num_blocks,
        2,  # K and V
        block_size,
        num_kv_heads,
        head_size,
        dtype=dtype,
        device=device)
    return kv_cache
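
For context, a minimal usage sketch (not part of this PR's diff) of how the helpers above compose in a backend test. The CUDA device, block size, example sequence/query lengths, and the FLASH_ATTN_VLLM_V1 backend choice are illustrative assumptions; the builder constructor signature matches the one exercised in tests/v1/spec_decode/test_eagle.py below.

# Illustrative sketch only; device, lengths, and backend choice are assumptions.
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)

device = torch.device("cuda")
vllm_config = create_vllm_config(block_size=16)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

# Two requests: total (context + query) lengths 32 and 40, with 4 and 8
# new query tokens respectively.
batch_spec = BatchSpec(seq_lens=[32, 40], query_lens=[4, 8], name="mixed")
common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                   block_size=16,
                                                   device=device)

# Resolve the builder class for one backend and construct it with the
# interface this PR introduces (kv_cache_spec, vllm_config, device).
builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN_VLLM_V1)
builder = builder_cls(kv_cache_spec=kv_cache_spec,
                      vllm_config=vllm_config,
                      device=device)
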
68 changes: 44 additions & 24 deletions tests/v1/spec_decode/test_eagle.py
@@ -6,6 +6,10 @@
import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
@@ -64,13 +68,19 @@ def test_prepare_inputs():
    """
    device = torch.device(current_platform.device_type)

    # a = 4, b = 7, c = 5
    # q1 = 4, q2 = 7, q3 = 5
    # n1 = 1, n2 = 3, n3 = 2

    # Cumulative lengths: [0, 4, 11, 16]
    cu_target_query_lens = torch.tensor([0, 4, 11, 16],
                                        dtype=torch.int32,
                                        device=device)
    batch_spec = BatchSpec(
        seq_lens=[4, 7, 5],
        query_lens=[4, 7, 5],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Rejected tokens per request: [1, 3, 2]
    num_rejected_tokens = torch.tensor([1, 3, 2],
@@ -104,15 +114,13 @@ def test_prepare_inputs():
        ],
        dtype=torch.int32,
        device=device)
    proposer = _create_proposer("eagle", 1)

    # n1 + n2 + n3 - a - b -c
    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
    ).item()
    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, num_rejected_tokens.cpu())

    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
        cu_target_query_lens, num_rejected_tokens, num_tokens)

    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
    assert torch.equal(updated_metadata.query_start_loc,
                       expected_cu_num_tokens)
    assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
    assert torch.equal(token_indices, expected_token_indices)
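
For context, a small self-contained sketch (not part of the diff) of the bookkeeping test_prepare_inputs verifies, assuming the surviving tokens after rejection are the first q_i - n_i tokens of each request. The concrete lengths mirror the comments in the test above; only torch is required.

# Sketch of the expected prepare_inputs bookkeeping; values mirror the test.
import torch

query_lens = torch.tensor([4, 7, 5], dtype=torch.int32)    # q1, q2, q3
num_rejected = torch.tensor([1, 3, 2], dtype=torch.int32)  # n1, n2, n3

query_start_loc = torch.zeros(4, dtype=torch.int32)
query_start_loc[1:] = query_lens.cumsum(0)                 # [0, 4, 11, 16]

kept = query_lens - num_rejected                           # [3, 4, 3]
new_query_start_loc = torch.zeros(4, dtype=torch.int32)
new_query_start_loc[1:] = kept.cumsum(0)                   # [0, 3, 7, 10]

# Surviving token positions in the flattened token buffer.
token_indices = torch.cat([
    torch.arange(int(query_start_loc[i]),
                 int(query_start_loc[i]) + int(kept[i]))
    for i in range(3)
])                                       # [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]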

@@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens):
    seq_len_2 = 3
    total_tokens = seq_len_1 + seq_len_2
    vocab_size = 100
    seq_lens = [seq_len_1, seq_len_2]

    # Create proposer first so we can use its actual hidden_size
    proposer = _create_proposer("eagle", num_speculative_tokens)
@@ -270,9 +279,16 @@ def create_deterministic_logits(token_ids):
    proposer.attn_layer_names = ["layer.0"]

    # Create input tensors
    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
                                 dtype=torch.int32,
                                 device=device)
    batch_spec = BatchSpec(
        seq_lens=seq_lens,
        query_lens=seq_lens,
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
@@ -284,25 +300,29 @@ def create_deterministic_logits(token_ids):
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    target_slot_mapping = torch.randint(0,
                                        100, (total_tokens, ),
                                        device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    block_table = torch.randint(0, 10, (batch_size, 10), device=device)

    sampling_metadata = mock.MagicMock()

    # Call the method under test
    attn_metadata_builder_cls, _ = get_attention_backend(
        _Backend.FLASH_ATTN_VLLM_V1)
    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        vllm_config=proposer.vllm_config,
        device=device,
    )

    # Mock runner for attention metadata building
    proposer.runner = mock.MagicMock()
    proposer.runner.attn_metadata_builders = [attn_metadata_builder]

    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              target_slot_mapping=target_slot_mapping,
                              next_token_ids=next_token_ids,
                              cu_num_tokens=cu_num_tokens,
                              block_table=block_table,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    assert result.shape == (batch_size, num_speculative_tokens)