Skip to content

Commit 35c2d71

Browse files
review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 834b496 commit 35c2d71

File tree

3 files changed

+40
-50
lines changed

3 files changed

+40
-50
lines changed

tests/v1/attention/test_attention_backends.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,27 @@ def _convert_dtype_to_torch(dtype):
4444
# Define common batch configurations
4545
BATCH_SPECS = {
4646
"small_decode":
47-
BatchSpec(batch_size=2, seq_lens=[32, 40], query_lens=[1, 1]),
47+
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
4848
"small_prefill":
49-
BatchSpec(batch_size=2, seq_lens=[32, 40], query_lens=[8, 8]),
49+
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
5050
"mixed_small":
51-
BatchSpec(batch_size=4, seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5,
52-
5]),
51+
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
5352
"medium_decode":
54-
BatchSpec(batch_size=8,
55-
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
53+
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
5654
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
5755
"medium_prefill":
58-
BatchSpec(batch_size=4,
59-
seq_lens=[256, 512, 1024, 2048],
60-
query_lens=[16, 16, 16, 16]),
56+
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
6157
"mixed_medium":
62-
BatchSpec(batch_size=6,
63-
seq_lens=[512, 1024, 2048, 512, 1024, 2048],
58+
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
6459
query_lens=[1, 1, 1, 7, 7, 7]),
6560
"large_decode":
66-
BatchSpec(batch_size=32, seq_lens=[2048] * 32, query_lens=[1] * 32),
61+
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
6762
"large_prefill":
68-
BatchSpec(batch_size=8, seq_lens=[4096] * 8, query_lens=[32] * 8),
63+
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
6964
"single_decode":
70-
BatchSpec(batch_size=1, seq_lens=[1024], query_lens=[1]),
65+
BatchSpec(seq_lens=[1024], query_lens=[1]),
7166
"single_prefill":
72-
BatchSpec(batch_size=1, seq_lens=[1024], query_lens=[64]),
67+
BatchSpec(seq_lens=[1024], query_lens=[64]),
7368
}
7469

7570

tests/v1/attention/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,17 @@
2020
@dataclass
2121
class BatchSpec:
2222
"""Specification for a batch configuration (workload shape only)."""
23-
batch_size: int
2423
seq_lens: list[int]
2524
query_lens: list[int]
2625

2726
name: str = "unnamed"
2827

28+
@property
29+
def batch_size(self):
30+
return len(self.seq_lens)
31+
2932
def __post_init__(self):
30-
assert len(self.seq_lens) == self.batch_size
31-
assert len(self.query_lens) == self.batch_size
33+
assert len(self.seq_lens) == len(self.query_lens)
3234

3335
def compute_num_tokens(self):
3436
return sum(self.query_lens)

tests/v1/spec_decode/test_eagle.py

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import pytest
77
import torch
88

9-
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
9+
from tests.v1.attention.utils import (BatchSpec, _Backend,
10+
create_common_attn_metadata,
1011
create_standard_kv_cache_spec,
1112
get_attention_backend)
1213
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -56,31 +57,6 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
5657
device=current_platform.device_type)
5758

5859

59-
def _create_common_attn_metadata(
60-
cu_target_query_lens: torch.Tensor,
61-
device: torch.device) -> CommonAttentionMetadata:
62-
"""Create minimal CommonAttentionMetadata for testing."""
63-
batch_size = cu_target_query_lens.shape[0] - 1
64-
num_tokens = cu_target_query_lens[-1].item()
65-
seq_lens = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
66-
67-
return CommonAttentionMetadata(
68-
query_start_loc=cu_target_query_lens,
69-
query_start_loc_cpu=cu_target_query_lens.cpu(),
70-
seq_lens=seq_lens,
71-
seq_lens_cpu=seq_lens.cpu(),
72-
num_computed_tokens_cpu=seq_lens.cpu(),
73-
num_reqs=batch_size,
74-
num_actual_tokens=int(num_tokens),
75-
max_query_len=int(seq_lens.max().item()),
76-
block_table_tensor=torch.zeros((batch_size, 1),
77-
dtype=torch.int32,
78-
device=device),
79-
slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
80-
device=device),
81-
)
82-
83-
8460
def test_prepare_inputs():
8561
"""
8662
cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -97,7 +73,6 @@ def test_prepare_inputs():
9773
# n1 = 1, n2 = 3, n3 = 2
9874

9975
batch_spec = BatchSpec(
100-
batch_size=4,
10176
seq_lens=[4, 7, 5],
10277
query_lens=[4, 7, 5],
10378
)
@@ -324,9 +299,28 @@ def create_deterministic_logits(token_ids):
324299
device=device)
325300
sampling_metadata = mock.MagicMock()
326301

327-
# Create CommonAttentionMetadata for new API
328-
common_attn_metadata = _create_common_attn_metadata(cu_num_tokens, device)
329-
attn_metadata_builder_cls, _ = get_attention_backend("flash_attn")
302+
batch_size = cu_num_tokens.shape[0] - 1
303+
num_tokens = cu_num_tokens[-1].item()
304+
seq_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
305+
306+
common_attn_metadata = CommonAttentionMetadata(
307+
query_start_loc=cu_num_tokens,
308+
query_start_loc_cpu=cu_num_tokens.cpu(),
309+
seq_lens=seq_lens,
310+
seq_lens_cpu=seq_lens.cpu(),
311+
num_computed_tokens_cpu=seq_lens.cpu(),
312+
num_reqs=batch_size,
313+
num_actual_tokens=int(num_tokens),
314+
max_query_len=int(seq_lens.max().item()),
315+
block_table_tensor=torch.zeros((batch_size, 1),
316+
dtype=torch.int32,
317+
device=device),
318+
slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
319+
device=device),
320+
)
321+
322+
attn_metadata_builder_cls, _ = get_attention_backend(
323+
_Backend.FLASH_ATTN_VLLM_V1)
330324
attn_metadata_builder = attn_metadata_builder_cls(
331325
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
332326
vllm_config=proposer.vllm_config,
@@ -335,8 +329,7 @@ def create_deterministic_logits(token_ids):
335329

336330
# Mock runner for attention metadata building
337331
proposer.runner = mock.MagicMock()
338-
proposer.runner.attn_metadata_builders = [mock.MagicMock()]
339-
proposer.runner.attn_metadata_builders[0] = attn_metadata_builder
332+
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
340333

341334
result = proposer.propose(target_token_ids=target_token_ids,
342335
target_positions=target_positions,

0 commit comments

Comments
 (0)