Commit 4c56bb0

review comments
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent b7e7301 commit 4c56bb0

1 file changed (+11, -24 lines)

tests/v1/spec_decode/test_eagle.py

Lines changed: 11 additions & 24 deletions
@@ -15,7 +15,6 @@
                          VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -218,6 +217,7 @@ def test_propose(num_speculative_tokens):
     seq_len_2 = 3
     total_tokens = seq_len_1 + seq_len_2
     vocab_size = 100
+    seq_lens = [seq_len_1, seq_len_2]
 
     # Create proposer first so we can use its actual hidden_size
     proposer = _create_proposer("eagle", num_speculative_tokens)
@@ -279,9 +279,16 @@ def create_deterministic_logits(token_ids):
     proposer.attn_layer_names = ["layer.0"]
 
     # Create input tensors
-    cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
-                                 dtype=torch.int32,
-                                 device=device)
+    batch_spec = BatchSpec(
+        seq_lens=seq_lens,
+        query_lens=seq_lens,
+    )
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec,
+        block_size=16,
+        device=device,
+    )
 
     target_token_ids = torch.randint(0,
                                      vocab_size, (total_tokens, ),
@@ -299,26 +306,6 @@ def create_deterministic_logits(token_ids):
                                      device=device)
     sampling_metadata = mock.MagicMock()
 
-    batch_size = cu_num_tokens.shape[0] - 1
-    num_tokens = cu_num_tokens[-1].item()
-    seq_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-
-    common_attn_metadata = CommonAttentionMetadata(
-        query_start_loc=cu_num_tokens,
-        query_start_loc_cpu=cu_num_tokens.cpu(),
-        seq_lens=seq_lens,
-        seq_lens_cpu=seq_lens.cpu(),
-        num_computed_tokens_cpu=seq_lens.cpu(),
-        num_reqs=batch_size,
-        num_actual_tokens=int(num_tokens),
-        max_query_len=int(seq_lens.max().item()),
-        block_table_tensor=torch.zeros((batch_size, 1),
-                                       dtype=torch.int32,
-                                       device=device),
-        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
-                                  device=device),
-    )
-
     attn_metadata_builder_cls, _ = get_attention_backend(
         _Backend.FLASH_ATTN_VLLM_V1)
     attn_metadata_builder = attn_metadata_builder_cls(
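
For context on what the shared helper replaces, below is a minimal sketch of building CommonAttentionMetadata from per-request sequence lengths, mirroring the hand-rolled construction deleted above. The function name build_common_attn_metadata_sketch is hypothetical; the real BatchSpec / create_common_attn_metadata utilities in vLLM's test helpers may fill these fields differently (for example, sizing the block table from the block_size argument rather than using a single dummy block).

import torch

from vllm.v1.attention.backends.utils import CommonAttentionMetadata


def build_common_attn_metadata_sketch(seq_lens, device):
    # Hypothetical stand-in for create_common_attn_metadata(); it reproduces
    # the test-local construction this commit removes from test_eagle.py.
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device=device)

    # Cumulative query start locations: [0, seq_len_1, seq_len_1 + seq_len_2, ...]
    cu_num_tokens = torch.zeros(len(seq_lens) + 1,
                                dtype=torch.int32,
                                device=device)
    cu_num_tokens[1:] = torch.cumsum(seq_lens_t, dim=0)

    batch_size = len(seq_lens)
    num_tokens = int(cu_num_tokens[-1].item())

    return CommonAttentionMetadata(
        query_start_loc=cu_num_tokens,
        query_start_loc_cpu=cu_num_tokens.cpu(),
        seq_lens=seq_lens_t,
        seq_lens_cpu=seq_lens_t.cpu(),
        num_computed_tokens_cpu=seq_lens_t.cpu(),
        num_reqs=batch_size,
        num_actual_tokens=num_tokens,
        max_query_len=int(seq_lens_t.max().item()),
        # A real block table and slot mapping would come from the KV-cache
        # manager; the deleted test code used these trivial placeholders.
        block_table_tensor=torch.zeros((batch_size, 1),
                                       dtype=torch.int32,
                                       device=device),
        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
                                  device=device),
    )

The commit swaps this boilerplate for BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) plus create_common_attn_metadata(batch_spec, block_size=16, device=device), so the spec-decode test no longer has to keep its metadata construction in sync with CommonAttentionMetadata's constructor by hand.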
