Commit 76b4944

[Attention] Refactor attention metadata builder interface (#20466)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 28a6d54 commit 76b4944

File tree

18 files changed (+1447, −778 lines)


tests/v1/attention/test_attention_backends.py

Lines changed: 466 additions & 0 deletions
Large diffs are not rendered by default.

tests/v1/attention/utils.py

Lines changed: 229 additions & 0 deletions
@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for attention-related v1 tests."""
from dataclasses import dataclass
from typing import Union

import pytest
import torch

from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                         LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.platforms import _Backend
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec


@dataclass
class BatchSpec:
    """Specification for a batch configuration (workload shape only)."""
    seq_lens: list[int]
    query_lens: list[int]

    name: str = "unnamed"

    @property
    def batch_size(self):
        return len(self.seq_lens)

    def __post_init__(self):
        assert len(self.seq_lens) == len(self.query_lens)

    def compute_num_tokens(self):
        return sum(self.query_lens)


def create_common_attn_metadata(
        batch_spec: BatchSpec,
        block_size: int,
        device: torch.device,
        max_block_idx: int = 1000) -> CommonAttentionMetadata:
    """Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
    # Create query start locations
    query_start_loc = torch.zeros(batch_spec.batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
                                       dtype=torch.int32,
                                       device=device).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = batch_spec.compute_num_tokens()

    # Create sequence lengths
    seq_lens = torch.tensor(batch_spec.seq_lens,
                            dtype=torch.int32,
                            device=device)
    seq_lens_cpu = seq_lens.cpu()

    # Create computed tokens (context length for each sequence)
    context_lens = [
        batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for i in range(batch_spec.batch_size)
    ]
    num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)

    # Create block table (random for testing)
    max_blocks = max(batch_spec.seq_lens) // block_size + 1
    block_table_tensor = torch.randint(0,
                                       max_block_idx,
                                       (batch_spec.batch_size, max_blocks),
                                       dtype=torch.int32,
                                       device=device)

    # Create slot mapping
    slot_mapping = torch.randint(0,
                                 max_block_idx, (num_tokens, ),
                                 dtype=torch.int64,
                                 device=device)

    # Calculate max query length
    max_query_len = max(batch_spec.query_lens)

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=batch_spec.batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )


def get_attention_backend(backend_name: _Backend):
    """Set up attention backend classes for testing.

    Args:
        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
        vllm_config: VllmConfig instance

    Returns:
        Tuple of (backend_builder_class, backend_impl_class)
    """
    backend_map = {
        _Backend.FLASH_ATTN_VLLM_V1:
        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
        _Backend.FLASHINFER_VLLM_V1:
        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION:
        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN_VLLM_V1:
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
    }

    if backend_name not in backend_map:
        raise ValueError(f"Unknown backend: {backend_name}")

    backend_class_name = backend_map[backend_name]

    try:
        backend_class = resolve_obj_by_qualname(backend_class_name)
        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
    except ImportError as e:
        pytest.skip(f"{backend_name} not available: {e}")


def create_standard_kv_cache_spec(
        vllm_config: VllmConfig) -> FullAttentionSpec:
    """Create a FullAttentionSpec from ModelParams only."""
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config),
        head_size=vllm_config.model_config.get_head_size(),
        dtype=vllm_config.model_config.dtype,
        use_mla=vllm_config.model_config.use_mla,
        sliding_window=vllm_config.model_config.get_sliding_window(),
    )


def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
                       tensor_parallel_size: int = 1,
                       max_model_len: int = 1024,
                       dtype: Union[ModelDType, torch.dtype] = "auto",
                       block_size: int = 16,
                       max_num_seqs: int = 256,
                       max_num_batched_tokens: int = 8192,
                       add_mock_model_methods: bool = True) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults."""

    model_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        trust_remote_code=False,
        dtype=dtype,
        seed=0,
        max_model_len=max_model_len,
    )

    cache_config = CacheConfig(
        block_size=block_size,
        cache_dtype="auto",
        swap_space=0,
    )
    # Set cache blocks for testing
    # (these may be set during initialization normally)
    cache_config.num_gpu_blocks = 1000
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size, )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    device_config = DeviceConfig()
    load_config = LoadConfig()
    compilation_config = CompilationConfig()

    if add_mock_model_methods:
        # Add mock methods to satisfy backends that need them
        # This is a workaround because tests don't build full, real models,
        # but some backends expect to query the model for layer-specific
        # parameters
        import types
        model_config.get_num_layers = types.MethodType(lambda self: 1,
                                                       model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config)
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config)
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / model_config.get_head_size()**0.5,
            model_config)

    return VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        load_config=load_config,
        compilation_config=compilation_config,
    )


def create_dummy_kv_cache(block_size: int,
                          num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing."""
    kv_cache = torch.randn(
        num_blocks,
        2,  # K and V
        block_size,
        num_kv_heads,
        head_size,
        dtype=dtype,
        device=device)
    return kv_cache
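
For orientation, a minimal sketch of how these helpers compose. The batch shape, block size, and assertions are illustrative values, not part of the commit; running it requires the vLLM source tree on the import path, a CUDA device, and access to the Meta-Llama-3-8B config on Hugging Face.

# Illustrative usage of the helpers above (not part of the commit).
import torch

from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      create_vllm_config)

device = torch.device("cuda")

# Two decode requests plus one prefill request (shapes chosen for the example).
batch_spec = BatchSpec(seq_lens=[32, 40, 48],
                       query_lens=[1, 1, 48],
                       name="mixed_batch")

metadata = create_common_attn_metadata(batch_spec, block_size=16, device=device)
assert metadata.num_reqs == 3
assert metadata.num_actual_tokens == 50  # sum of query_lens
assert metadata.max_query_len == 48

# Matching model-level config (downloads the Llama-3-8B config from HF)
# and a full-attention KV cache spec derived from it.
vllm_config = create_vllm_config(block_size=16)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
assert kv_cache_spec.block_size == vllm_config.cache_config.block_size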

tests/v1/spec_decode/test_eagle.py

Lines changed: 44 additions & 24 deletions
@@ -6,6 +6,10 @@
 import pytest
 import torch
 
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
@@ -64,13 +68,19 @@ def test_prepare_inputs():
     """
     device = torch.device(current_platform.device_type)
 
-    # a = 4, b = 7, c = 5
+    # q1 = 4, q2 = 7, q3 = 5
     # n1 = 1, n2 = 3, n3 = 2
 
-    # Cumulative lengths: [0, 4, 11, 16]
-    cu_target_query_lens = torch.tensor([0, 4, 11, 16],
-                                        dtype=torch.int32,
-                                        device=device)
+    batch_spec = BatchSpec(
+        seq_lens=[4, 7, 5],
+        query_lens=[4, 7, 5],
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
 
     # Rejected tokens per request: [1, 3, 2]
     num_rejected_tokens = torch.tensor([1, 3, 2],
@@ -104,15 +114,13 @@ def test_prepare_inputs():
         ],
         dtype=torch.int32,
         device=device)
+    proposer = _create_proposer("eagle", 1)
 
-    # n1 + n2 + n3 - a - b -c
-    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
-    ).item()
+    updated_metadata, token_indices = proposer.prepare_inputs(
+        common_attn_metadata, num_rejected_tokens.cpu())
 
-    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens, num_tokens)
-
-    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
+    assert torch.equal(updated_metadata.query_start_loc,
+                       expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
     assert torch.equal(token_indices, expected_token_indices)
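
The expectations this test encodes reduce to simple bookkeeping: each request keeps its query length minus its rejected-token count, taken from the front of its slice of the flattened batch (assuming, as the test setup implies, that rejected draft tokens are dropped from the tail of each request). A standalone sketch of that arithmetic, with no vLLM imports:

# Values from the test comments above: q = [4, 7, 5], rejected n = [1, 3, 2].
query_lens = [4, 7, 5]
num_rejected = [1, 3, 2]

kept = [q - n for q, n in zip(query_lens, num_rejected)]  # [3, 4, 3]

# Cumulative token counts after filtering (what query_start_loc should hold).
cu_num_tokens = [0]
for k in kept:
    cu_num_tokens.append(cu_num_tokens[-1] + k)
print(cu_num_tokens)  # [0, 3, 7, 10]

# Indices of the surviving tokens within the original flattened batch,
# whose per-request offsets start at the old cumulative lengths [0, 4, 11].
token_indices = []
start = 0
for q, k in zip(query_lens, kept):
    token_indices.extend(range(start, start + k))
    start += q
print(token_indices)  # [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]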

@@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens):
     seq_len_2 = 3
     total_tokens = seq_len_1 + seq_len_2
     vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
 
     # Create proposer first so we can use its actual hidden_size
     proposer = _create_proposer("eagle", num_speculative_tokens)
@@ -270,9 +279,16 @@ def create_deterministic_logits(token_ids):
     proposer.attn_layer_names = ["layer.0"]
 
     # Create input tensors
-    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
-                                 dtype=torch.int32,
-                                 device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
 
     target_token_ids = torch.randint(0,
                                      vocab_size, (total_tokens, ),
@@ -284,25 +300,29 @@ def create_deterministic_logits(token_ids):
     target_hidden_states = torch.randn(total_tokens,
                                        hidden_size,
                                        device=device)
-    target_slot_mapping = torch.randint(0,
-                                        100, (total_tokens, ),
-                                        device=device)
     next_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, ),
                                    dtype=torch.int32,
                                    device=device)
-    block_table = torch.randint(0, 10, (batch_size, 10), device=device)
-
     sampling_metadata = mock.MagicMock()
 
-    # Call the method under test
+    attn_metadata_builder_cls, _ = get_attention_backend(
+        _Backend.FLASH_ATTN_VLLM_V1)
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )
+
+    # Mock runner for attention metadata building
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_metadata_builders = [attn_metadata_builder]
+
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
                               target_hidden_states=target_hidden_states,
-                              target_slot_mapping=target_slot_mapping,
                               next_token_ids=next_token_ids,
-                              cu_num_tokens=cu_num_tokens,
-                              block_table=block_table,
+                              common_attn_metadata=common_attn_metadata,
                               sampling_metadata=sampling_metadata)
 
     assert result.shape == (batch_size, num_speculative_tokens)
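
The pattern above — resolve a backend, then construct its metadata builder from a KV-cache spec, a VllmConfig, and a device — is the builder interface this commit standardizes. A condensed sketch of that setup outside the test harness, using only calls that appear in the diffs (the arrangement and values are illustrative; assumes CUDA and the tests/v1/attention/utils.py helpers shown earlier):

# Sketch of the builder setup exercised by the updated test (illustrative only).
import torch

from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)

device = torch.device("cuda")
vllm_config = create_vllm_config(block_size=16)

# Resolve the FlashAttention v1 backend; get_attention_backend calls
# pytest.skip if the backend cannot be imported on this machine.
builder_cls, impl_cls = get_attention_backend(_Backend.FLASH_ATTN_VLLM_V1)

# The refactored builder is constructed from (kv_cache_spec, vllm_config, device).
attn_metadata_builder = builder_cls(
    kv_cache_spec=create_standard_kv_cache_spec(vllm_config),
    vllm_config=vllm_config,
    device=device,
)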
