                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -218,6 +217,7 @@ def test_propose(num_speculative_tokens):
     seq_len_2 = 3
     total_tokens = seq_len_1 + seq_len_2
     vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
 
     # Create proposer first so we can use its actual hidden_size
     proposer = _create_proposer("eagle", num_speculative_tokens)
@@ -279,9 +279,16 @@ def create_deterministic_logits(token_ids):
     proposer.attn_layer_names = ["layer.0"]
 
     # Create input tensors
-    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
-                                 dtype=torch.int32,
-                                 device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
 
     target_token_ids = torch.randint(0,
                                      vocab_size, (total_tokens, ),
@@ -299,26 +306,6 @@ def create_deterministic_logits(token_ids):
                                    device=device)
     sampling_metadata = mock.MagicMock()
 
-    batch_size = cu_num_tokens.shape[0] - 1
-    num_tokens = cu_num_tokens[-1].item()
-    seq_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-
-    common_attn_metadata = CommonAttentionMetadata(
-        query_start_loc=cu_num_tokens,
-        query_start_loc_cpu=cu_num_tokens.cpu(),
-        seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.cpu(),
-        num_computed_tokens_cpu=seq_lens.cpu(),
-        num_reqs=batch_size,
-        num_actual_tokens=int(num_tokens),
-        max_query_len=int(seq_lens.max().item()),
-        block_table_tensor=torch.zeros((batch_size, 1),
-                                       dtype=torch.int32,
-                                       device=device),
-        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
-                                  device=device),
-    )
-
     attn_metadata_builder_cls, _ = get_attention_backend(
         _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
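Note on the refactor: the removed block above spells out which CommonAttentionMetadata fields the test previously built by hand from cu_num_tokens. The snippet below is a minimal, hypothetical sketch of that mapping, assuming only a BatchSpec-style pair of per-request seq_lens/query_lens; it is not the actual create_common_attn_metadata helper from the vLLM test utilities, and the CPU-side mirrors (query_start_loc_cpu, seq_lens_cpu, num_computed_tokens_cpu) are omitted for brevity.

# Hypothetical illustration only -- not vLLM's create_common_attn_metadata.
# It reconstructs the fields the removed block above built by hand.
import torch


def build_metadata_fields(seq_lens, query_lens, block_size=16, device="cpu"):
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device=device)
    query_lens_t = torch.tensor(query_lens, dtype=torch.int32, device=device)
    batch_size = len(seq_lens)
    num_tokens = int(query_lens_t.sum().item())

    # Cumulative per-request query lengths, i.e. the old cu_num_tokens tensor.
    query_start_loc = torch.zeros(batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.cumsum(query_lens_t, dim=0)

    # Placeholder block table and identity slot mapping, mirroring the
    # zeros/arange placeholders in the removed block.
    num_blocks = int(((seq_lens_t + block_size - 1) // block_size).max().item())
    block_table = torch.zeros((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)

    return {
        "query_start_loc": query_start_loc,
        "seq_lens": seq_lens_t,
        "num_reqs": batch_size,
        "num_actual_tokens": num_tokens,
        "max_query_len": int(query_lens_t.max().item()),
        "block_table_tensor": block_table,
        "slot_mapping": slot_mapping,
    }


# Example with two requests of query lengths 2 and 3:
fields = build_metadata_fields([2, 3], [2, 3])
assert fields["query_start_loc"].tolist() == [0, 2, 5]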