Commit 822ca6b

cleanup eagle tests
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 6bce371 commit 822ca6b

File tree: 2 files changed (+162, -29 lines)

tests/v1/attention/utils.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Utility functions for attention-related v1 tests."""
+
+from dataclasses import dataclass
+
+import pytest
+import torch
+
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import FullAttentionSpec
+
+
+@dataclass
+class BatchSpec:
+    """Specification for a batch configuration (workload shape only)."""
+    batch_size: int
+    seq_lens: list[int]
+    query_lens: list[int]
+
+    name: str = "unnamed"
+
+    def __post_init__(self):
+        assert len(self.seq_lens) == self.batch_size
+        assert len(self.query_lens) == self.batch_size
+
+    def compute_num_tokens(self):
+        return sum(self.seq_lens)
+
+
+def create_common_attn_metadata(
+        batch_spec: BatchSpec,
+        block_size: int,
+        device: torch.device,
+        max_block_idx: int = 1000) -> CommonAttentionMetadata:
+    """Create CommonAttentionMetadata from a BatchSpec, block size and device."""
+    # Create query start locations
+    query_start_loc = torch.zeros(batch_spec.batch_size + 1,
+                                  dtype=torch.int32,
+                                  device=device)
+    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
+                                       dtype=torch.int32,
+                                       device=device).cumsum(0)
+    query_start_loc_cpu = query_start_loc.cpu()
+    num_tokens = batch_spec.compute_num_tokens()
+
+    # Create sequence lengths
+    seq_lens = torch.tensor(batch_spec.seq_lens,
+                            dtype=torch.int32,
+                            device=device)
+    seq_lens_cpu = seq_lens.cpu()
+
+    # Create computed tokens (assume all tokens are computed for simplicity)
+    num_computed_tokens_cpu = seq_lens_cpu.clone()
+
+    # Create block table (random for testing)
+    max_blocks = max(batch_spec.seq_lens) // block_size + 1
+    block_table_tensor = torch.randint(0,
+                                       max_block_idx,
+                                       (batch_spec.batch_size, max_blocks),
+                                       dtype=torch.int32,
+                                       device=device)
+
+    # Create slot mapping
+    slot_mapping = torch.randint(0,
+                                 max_block_idx, (num_tokens, ),
+                                 dtype=torch.int64,
+                                 device=device)
+
+    # Calculate max query length
+    max_query_len = max(batch_spec.query_lens)
+
+    return CommonAttentionMetadata(
+        query_start_loc=query_start_loc,
+        query_start_loc_cpu=query_start_loc_cpu,
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens_cpu,
+        num_computed_tokens_cpu=num_computed_tokens_cpu,
+        num_reqs=batch_spec.batch_size,
+        num_actual_tokens=num_tokens,
+        max_query_len=max_query_len,
+        block_table_tensor=block_table_tensor,
+        slot_mapping=slot_mapping,
+    )
+
+
+def get_attention_backend(backend_name: str):
+    """Set up attention backend classes for testing.
+
+    Args:
+        backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
+
+    Returns:
+        Tuple of (backend_builder_class, backend_impl_class)
+    """
+    backend_map = {
+        "flash_attn":
+        ("vllm.v1.attention.backends.flash_attn", "FlashAttentionBackend"),
+        "flashinfer":
+        ("vllm.v1.attention.backends.flashinfer", "FlashInferBackend"),
+        "flex_attention":
+        ("vllm.v1.attention.backends.flex_attention", "FlexAttentionBackend"),
+    }
+
+    if backend_name not in backend_map:
+        raise ValueError(f"Unknown backend: {backend_name}")
+
+    module_name, backend_class_name = backend_map[backend_name]
+
+    try:
+        import importlib
+        module = importlib.import_module(module_name)
+        backend_class = getattr(module, backend_class_name)
+        return backend_class.get_builder_cls(), backend_class.get_impl_cls()
+    except ImportError as e:
+        pytest.skip(f"{backend_name} not available: {e}")
+
+
+def create_standard_kv_cache_spec(
+        vllm_config: VllmConfig) -> FullAttentionSpec:
+    """Create a FullAttentionSpec from a VllmConfig."""
+    return FullAttentionSpec(
+        block_size=vllm_config.cache_config.block_size,
+        num_kv_heads=vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config),
+        head_size=vllm_config.model_config.get_head_size(),
+        dtype=vllm_config.model_config.dtype,
+        use_mla=vllm_config.model_config.use_mla,
+        sliding_window=vllm_config.model_config.get_sliding_window(),
+    )
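
The helpers above are meant to be composed: describe the workload shape with a BatchSpec, expand it into CommonAttentionMetadata, resolve a backend, and construct a real metadata builder from a standard KV-cache spec. The snippet below is an illustrative sketch only (not part of the commit); it mirrors the wiring used in test_eagle.py further down and assumes an existing VllmConfig named vllm_config plus an available accelerator device.

import torch

from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)

device = torch.device("cuda")

# Workload shape only: three requests, seq_lens == query_lens (pure prefill).
batch_spec = BatchSpec(batch_size=3, seq_lens=[4, 7, 5], query_lens=[4, 7, 5])

# Random block table / slot mapping are sufficient for shape-level tests.
common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                   block_size=16,
                                                   device=device)

# Resolve the backend classes; pytest.skip() fires if the backend import fails.
builder_cls, impl_cls = get_attention_backend("flash_attn")
builder = builder_cls(
    kv_cache_spec=create_standard_kv_cache_spec(vllm_config),  # vllm_config assumed
    vllm_config=vllm_config,
    device=device,
)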

tests/v1/spec_decode/test_eagle.py

Lines changed: 30 additions & 29 deletions
@@ -6,6 +6,9 @@
 import pytest
 import torch

+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      get_attention_backend)
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          VllmConfig)
@@ -90,13 +93,20 @@ def test_prepare_inputs():
     """
     device = torch.device(current_platform.device_type)

-    # a = 4, b = 7, c = 5
+    # q1 = 4, q2 = 7, q3 = 5
     # n1 = 1, n2 = 3, n3 = 2

-    # Cumulative lengths: [0, 4, 11, 16]
-    cu_target_query_lens = torch.tensor([0, 4, 11, 16],
-                                        dtype=torch.int32,
-                                        device=device)
+    batch_spec = BatchSpec(
+        batch_size=3,
+        seq_lens=[4, 7, 5],
+        query_lens=[4, 7, 5],
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )

     # Rejected tokens per request: [1, 3, 2]
     num_rejected_tokens = torch.tensor([1, 3, 2],
@@ -130,18 +140,10 @@ def test_prepare_inputs():
                                           ],
                                          dtype=torch.int32,
                                          device=device)
-
-    # n1 + n2 + n3 - a - b -c
-    num_tokens = int(cu_target_query_lens[-1].item() -
-                     num_rejected_tokens.sum().item())
-
-    # Create CommonAttentionMetadata for new API
-    common_attn_metadata = _create_common_attn_metadata(
-        cu_target_query_lens, device)
     proposer = _create_proposer("eagle", 1)

     updated_metadata, token_indices = proposer.prepare_inputs(
-        common_attn_metadata, num_rejected_tokens.cpu(), num_tokens)
+        common_attn_metadata, num_rejected_tokens.cpu())

     assert torch.equal(updated_metadata.query_start_loc,
                        expected_cu_num_tokens)
@@ -324,25 +326,24 @@ def create_deterministic_logits(token_ids):

     # Create CommonAttentionMetadata for new API
     common_attn_metadata = _create_common_attn_metadata(cu_num_tokens, device)
+    attn_metadata_builder_cls, _ = get_attention_backend("flash_attn")
+    attn_metadata_builder = attn_metadata_builder_cls(
+        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
+        vllm_config=proposer.vllm_config,
+        device=device,
+    )

     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_metadata_builders = [mock.MagicMock()]
-
-    # Create mock with required attributes for multi-token tests
-    attn_metadata_mock = mock.MagicMock()
-    attn_metadata_mock.max_seq_len = 10
-    attn_metadata_mock.seq_lens = torch.tensor([5, 3], device=device)
-    proposer.runner.attn_metadata_builders[
-        0].build.return_value = attn_metadata_mock
-
-    with mock.patch('vllm.v1.spec_decode.eagle.isinstance', return_value=True):
-        result = proposer.propose(target_token_ids=target_token_ids,
-                                  target_positions=target_positions,
-                                  target_hidden_states=target_hidden_states,
-                                  next_token_ids=next_token_ids,
-                                  common_attn_metadata=common_attn_metadata,
-                                  sampling_metadata=sampling_metadata)
+    proposer.runner.attn_metadata_builders[0] = attn_metadata_builder
+
+    result = proposer.propose(target_token_ids=target_token_ids,
+                              target_positions=target_positions,
+                              target_hidden_states=target_hidden_states,
+                              next_token_ids=next_token_ids,
+                              common_attn_metadata=common_attn_metadata,
+                              sampling_metadata=sampling_metadata)

     assert result.shape == (batch_size, num_speculative_tokens)
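
For reference, the arithmetic encoded in test_prepare_inputs() can be worked out by hand: with query lens [4, 7, 5] and rejected counts [1, 3, 2], each request keeps its first (query_len - n_rejected) tokens, so the trimmed lengths are [3, 4, 3] and the expected query_start_loc is [0, 3, 7, 10]. The snippet below is an illustrative sketch of that calculation based on the test's comments, not vLLM code.

import torch

query_lens = torch.tensor([4, 7, 5], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 3, 2], dtype=torch.int32)

kept_lens = query_lens - num_rejected_tokens              # [3, 4, 3]
expected_cu_num_tokens = torch.zeros(4, dtype=torch.int32)
expected_cu_num_tokens[1:] = kept_lens.cumsum(0)          # [0, 3, 7, 10]

# Indices into the flattened target batch: each request starts at the cumsum
# of the original query lens ([0, 4, 11]) and keeps its first kept_lens[i]
# positions.
starts = torch.tensor([0, 4, 11])
token_indices = torch.cat(
    [starts[i] + torch.arange(int(kept_lens[i])) for i in range(3)])
# -> [0, 1, 2, 4, 5, 6, 7, 11, 12, 13]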
