
Commit d534e4e

eagle passing
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent fa85ae0 commit d534e4e

File tree

3 files changed: +72 -27 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 58 additions & 19 deletions
@@ -11,6 +11,7 @@
     VllmConfig)
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.platforms import current_platform
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.spec_decode.eagle import EagleProposer
 
 model_dir = "meta-llama/Llama-3.1-8B-Instruct"
@@ -52,6 +53,31 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
         device=current_platform.device_type)
 
 
+def _create_common_attn_metadata(
+        cu_target_query_lens: torch.Tensor,
+        device: torch.device) -> CommonAttentionMetadata:
+    """Create minimal CommonAttentionMetadata for testing."""
+    batch_size = cu_target_query_lens.shape[0] - 1
+    num_tokens = cu_target_query_lens[-1].item()
+    seq_lens = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
+
+    return CommonAttentionMetadata(
+        query_start_loc=cu_target_query_lens,
+        query_start_loc_cpu=cu_target_query_lens.cpu(),
+        seq_lens=seq_lens,
+        seq_lens_cpu=seq_lens.cpu(),
+        num_computed_tokens_cpu=seq_lens.cpu(),
+        num_reqs=batch_size,
+        num_actual_tokens=int(num_tokens),
+        max_query_len=int(seq_lens.max().item()),
+        block_table_tensor=torch.zeros((batch_size, 1),
+                                       dtype=torch.int32,
+                                       device=device),
+        slot_mapping=torch.arange(num_tokens, dtype=torch.int64,
+                                  device=device),
+    )
+
+
 def test_prepare_inputs():
     """
     cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -106,13 +132,19 @@ def test_prepare_inputs():
         device=device)
 
     # n1 + n2 + n3 - a - b -c
-    num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum(
-    ).item()
+    num_tokens = int(cu_target_query_lens[-1].item() -
+                     num_rejected_tokens.sum().item())
 
-    cu_num_tokens, token_indices = EagleProposer.prepare_inputs(
-        cu_target_query_lens, num_rejected_tokens, num_tokens)
+    # Create CommonAttentionMetadata for new API
+    common_attn_metadata = _create_common_attn_metadata(
+        cu_target_query_lens, device)
+    proposer = _create_proposer("eagle", 1)
 
-    assert torch.equal(cu_num_tokens, expected_cu_num_tokens)
+    updated_metadata, token_indices = proposer.prepare_inputs(
+        common_attn_metadata, num_rejected_tokens.cpu(), num_tokens)
+
+    assert torch.equal(updated_metadata.query_start_loc,
+                       expected_cu_num_tokens)
     assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
     assert torch.equal(token_indices, expected_token_indices)
@@ -284,26 +316,33 @@ def create_deterministic_logits(token_ids):
     target_hidden_states = torch.randn(total_tokens,
                                        hidden_size,
                                        device=device)
-    target_slot_mapping = torch.randint(0,
-                                        100, (total_tokens, ),
-                                        device=device)
     next_token_ids = torch.randint(0,
                                    vocab_size, (batch_size, ),
                                    dtype=torch.int32,
                                    device=device)
-    block_table = torch.randint(0, 10, (batch_size, 10), device=device)
-
     sampling_metadata = mock.MagicMock()
 
-    # Call the method under test
-    result = proposer.propose(target_token_ids=target_token_ids,
-                              target_positions=target_positions,
-                              target_hidden_states=target_hidden_states,
-                              target_slot_mapping=target_slot_mapping,
-                              next_token_ids=next_token_ids,
-                              cu_num_tokens=cu_num_tokens,
-                              block_table=block_table,
-                              sampling_metadata=sampling_metadata)
+    # Create CommonAttentionMetadata for new API
+    common_attn_metadata = _create_common_attn_metadata(cu_num_tokens, device)
+
+    # Mock runner for attention metadata building
+    proposer.runner = mock.MagicMock()
+    proposer.runner.attn_metadata_builders = [mock.MagicMock()]
+
+    # Create mock with required attributes for multi-token tests
+    attn_metadata_mock = mock.MagicMock()
+    attn_metadata_mock.max_seq_len = 10
+    attn_metadata_mock.seq_lens = torch.tensor([5, 3], device=device)
+    proposer.runner.attn_metadata_builders[
+        0].build.return_value = attn_metadata_mock
+
+    with mock.patch('vllm.v1.spec_decode.eagle.isinstance', return_value=True):
+        result = proposer.propose(target_token_ids=target_token_ids,
+                                  target_positions=target_positions,
+                                  target_hidden_states=target_hidden_states,
+                                  next_token_ids=next_token_ids,
+                                  common_attn_metadata=common_attn_metadata,
+                                  sampling_metadata=sampling_metadata)
 
     assert result.shape == (batch_size, num_speculative_tokens)

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 2 deletions
@@ -290,8 +290,8 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             prefix_kv_lens = torch.tensor([common_prefix_len],
                                           dtype=torch.int32,
                                           device=self.device)
-            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len)
-            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
+            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
+                self.device, non_blocking=True)
             prefix_scheduler_metadata = schedule(
                 batch_size=1,
                 cu_query_lens=cu_prefix_query_lens,
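
Context for the two-line change: seq_lens_cpu is treated here as a CPU torch.Tensor, so the intermediate torch.from_numpy step (which assumed a numpy array) drops out and the subtraction and transfer fuse into one expression. A minimal sketch of the pattern with made-up lengths; non_blocking=True only makes the host-to-device copy truly asynchronous when the source tensor sits in pinned host memory:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens_cpu = torch.tensor([12, 9, 17], dtype=torch.int32)  # CPU tensor
num_reqs, common_prefix_len = 2, 4

# Old: torch.from_numpy(seq_lens_cpu[:num_reqs] - common_prefix_len).to(device)
# New: operate on the torch tensor directly and request an async copy.
suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
    device, non_blocking=True)
print(suffix_kv_lens)  # tensor([8, 5], ...)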

vllm/v1/spec_decode/eagle.py

Lines changed: 12 additions & 6 deletions
@@ -37,16 +37,14 @@ def __init__(
         self.method = self.speculative_config.method
 
         self.runner = runner
-        self.arange_np = np.arange(vllm_config.scheduler_config.max_num_seqs +
-                                   1)
-
         self.dtype = vllm_config.model_config.dtype
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         self.num_speculative_tokens = (
             self.speculative_config.num_speculative_tokens)
         self.max_num_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
+        self.arange_np = np.arange(self.max_num_tokens)
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
@@ -286,7 +284,14 @@ def prepare_inputs(
         # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         arange = self.arange_np[:total_num_tokens] - cumsums_offsets
 
-        tokens_indices = arange + query_start_loc_cpu[:-1]
+        # Expand starting positions to match token pattern
+        query_start_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
+                                         num_tokens_per_req.numpy())
+        tokens_indices = arange + query_start_expanded
+
+        # Ensure tokens_indices are within valid range for slot_mapping
+        max_slot_idx = common_attn_metadata.slot_mapping.size(0) - 1
+        tokens_indices = np.clip(tokens_indices, 0, max_slot_idx)
 
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=spec_query_start_loc_cpu.to(device,
@@ -297,13 +302,14 @@
             num_computed_tokens_cpu=(
                 common_attn_metadata.num_computed_tokens_cpu),
             num_reqs=common_attn_metadata.num_reqs,
-            num_actual_tokens=num_tokens,
+            num_actual_tokens=total_num_tokens,
             max_query_len=query_len_per_req.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[tokens_indices],
         )
 
-        return spec_common_attn_metadata, tokens_indices
+        return spec_common_attn_metadata, torch.from_numpy(tokens_indices).to(
+            device)
 
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
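
Two things change in eagle.py. First, the precomputed arange buffer is sized by max_num_batched_tokens rather than max_num_seqs + 1, since it now indexes tokens, not sequences. Second, the per-request start offsets are expanded with np.repeat before being added to the per-token arange: the old `tokens_indices = arange + query_start_loc_cpu[:-1]` adds a length-total_num_tokens array to a length-num_reqs array, which only broadcasts when every request contributes exactly one token. A minimal sketch with assumed values (three requests keeping 2, 5, and 3 tokens):

import numpy as np

query_start_loc = np.array([0, 2, 7, 10])      # per-request start offsets
num_tokens_per_req = np.diff(query_start_loc)  # [2 5 3]
total_num_tokens = int(query_start_loc[-1])    # 10

# Per-token local offsets, as in "Step 3": [0 1 0 1 2 3 4 0 1 2]
arange = (np.arange(total_num_tokens) -
          np.repeat(query_start_loc[:-1], num_tokens_per_req))

# Broken: arange + query_start_loc[:-1] mixes lengths 10 and 3.
# Fixed: repeat each request's start once per token, then add.
starts = np.repeat(query_start_loc[:-1], num_tokens_per_req)
token_indices = arange + starts                # [0 1 2 3 4 5 6 7 8 9]

In this toy case with no rejected tokens the result is the identity; with rejections, the target start offsets differ from the kept cumulative sums and token_indices picks out only the surviving positions, as in the sketch after the test diff above. The added np.clip simply keeps the gathered slot_mapping indices in bounds.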
