@@ -6,7 +6,8 @@
 import pytest
 import torch
 
-from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+from tests.v1.attention.utils import (BatchSpec, _Backend,
+                                      create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       get_attention_backend)
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
@@ -56,31 +57,6 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
                             device=current_platform.device_type)
 
 
-def _create_common_attn_metadata(
-        cu_target_query_lens: torch.Tensor,
-        device: torch.device) -> CommonAttentionMetadata:
-    """Create minimal CommonAttentionMetadata for testing."""
-    batch_size = cu_target_query_lens.shape[0] - 1
-    num_tokens = cu_target_query_lens[-1].item()
-    seq_lens = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
-
-    return CommonAttentionMetadata(
-        query_start_loc=cu_target_query_lens,
-        query_start_loc_cpu=cu_target_query_lens.cpu(),
-        seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.cpu(),
-        num_computed_tokens_cpu=seq_lens.cpu(),
-        num_reqs=batch_size,
-        num_actual_tokens=int(num_tokens),
-        max_query_len=int(seq_lens.max().item()),
-        block_table_tensor=torch.zeros((batch_size, 1),
-                                       dtype=torch.int32,
-                                       device=device),
-        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
-                                  device=device),
-    )
-
-
 def test_prepare_inputs():
     """
     cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -97,7 +73,6 @@ def test_prepare_inputs():
     # n1 = 1, n2 = 3, n3 = 2
 
     batch_spec = BatchSpec(
-        batch_size=4,
         seq_lens=[4, 7, 5],
         query_lens=[4, 7, 5],
     )
@@ -324,9 +299,28 @@ def create_deterministic_logits(token_ids):
                                  device=device)
     sampling_metadata = mock.MagicMock()
 
-    # Create CommonAttentionMetadata for new API
-    common_attn_metadata = _create_common_attn_metadata(cu_num_tokens, device)
-    attn_metadata_builder_cls, _ = get_attention_backend("flash_attn")
+    batch_size = cu_num_tokens.shape[0] - 1
+    num_tokens = cu_num_tokens[-1].item()
+    seq_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+
+    common_attn_metadata = CommonAttentionMetadata(
+        query_start_loc=cu_num_tokens,
+        query_start_loc_cpu=cu_num_tokens.cpu(),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.cpu(),
+        num_computed_tokens_cpu=seq_lens.cpu(),
+        num_reqs=batch_size,
+        num_actual_tokens=int(num_tokens),
+        max_query_len=int(seq_lens.max().item()),
+        block_table_tensor=torch.zeros((batch_size, 1),
+                                       dtype=torch.int32,
+                                       device=device),
+        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
+                                  device=device),
+    )
+
+    attn_metadata_builder_cls, _ = get_attention_backend(
+        _Backend.FLASH_ATTN_VLLM_V1)
 
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         vllm_config=proposer.vllm_config,
@@ -335,8 +329,7 @@ def create_deterministic_logits(token_ids):
 
     # Mock runner for attention metadata building
     proposer.runner = mock.MagicMock()
-    proposer.runner.attn_metadata_builders = [mock.MagicMock()]
-    proposer.runner.attn_metadata_builders[0] = attn_metadata_builder
+    proposer.runner.attn_metadata_builders = [attn_metadata_builder]
 
     result = proposer.propose(target_token_ids=target_token_ids,
                               target_positions=target_positions,
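
For context, the inlined construction derives every metadata field from the cumulative query-length tensor. A minimal sketch of that arithmetic, using field names from the diff above (the three-request lengths here are hypothetical):

```python
import torch

# Hypothetical cumulative query lengths for three requests of 1, 3, and 2 tokens.
cu_num_tokens = torch.tensor([0, 1, 4, 6], dtype=torch.int32)

batch_size = cu_num_tokens.shape[0] - 1            # 3 requests
num_tokens = int(cu_num_tokens[-1].item())         # 6 tokens in total
seq_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]  # tensor([1, 3, 2])

# max_query_len, as in the diff: the longest per-request query span.
max_query_len = int(seq_lens.max().item())         # 3
```

Note also that `get_attention_backend` is now called with the `_Backend.FLASH_ATTN_VLLM_V1` enum member instead of the `"flash_attn"` string.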