|
11 | 11 | VllmConfig)
|
12 | 12 | from vllm.model_executor.models.llama import LlamaForCausalLM
|
13 | 13 | from vllm.platforms import current_platform
|
| 14 | +from vllm.v1.attention.backends.utils import CommonAttentionMetadata |
14 | 15 | from vllm.v1.spec_decode.eagle import EagleProposer
|
15 | 16 |
|
16 | 17 | model_dir = "meta-llama/Llama-3.1-8B-Instruct"
|
@@ -52,6 +53,31 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
|
52 | 53 | device=current_platform.device_type)
|
53 | 54 |
|
54 | 55 |
|
def _create_common_attn_metadata(
        cu_target_query_lens: torch.Tensor,
        device: torch.device) -> CommonAttentionMetadata:
    """Build a minimal CommonAttentionMetadata from cumulative query lengths.

    Only the fields the eagle tests consume are populated: the block table
    is a single zeroed dummy block per request, and slot_mapping is simply
    the identity over all tokens.
    """
    num_requests = cu_target_query_lens.shape[0] - 1
    # Per-request lengths are the first difference of the cumulative lengths.
    per_request_lens = torch.diff(cu_target_query_lens)
    total_tokens = cu_target_query_lens[-1].item()

    dummy_block_table = torch.zeros((num_requests, 1),
                                    dtype=torch.int32,
                                    device=device)
    identity_slots = torch.arange(total_tokens,
                                  dtype=torch.int64,
                                  device=device)

    return CommonAttentionMetadata(
        query_start_loc=cu_target_query_lens,
        query_start_loc_cpu=cu_target_query_lens.cpu(),
        seq_lens=per_request_lens,
        seq_lens_cpu=per_request_lens.cpu(),
        num_computed_tokens_cpu=per_request_lens.cpu(),
        num_reqs=num_requests,
        num_actual_tokens=int(total_tokens),
        max_query_len=int(per_request_lens.max().item()),
        block_table_tensor=dummy_block_table,
        slot_mapping=identity_slots,
    )
| 79 | + |
| 80 | + |
55 | 81 | def test_prepare_inputs():
|
56 | 82 | """
|
57 | 83 | cu_target_query_lens: [0, a, a + b, a + b + c]
|
@@ -106,13 +132,19 @@ def test_prepare_inputs():
|
106 | 132 | device=device)
|
107 | 133 |
|
108 | 134 | # n1 + n2 + n3 - a - b - c
|
109 |
| - num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum( |
110 |
| - ).item() |
| 135 | + num_tokens = int(cu_target_query_lens[-1].item() - |
| 136 | + num_rejected_tokens.sum().item()) |
111 | 137 |
|
112 |
| - cu_num_tokens, token_indices = EagleProposer.prepare_inputs( |
113 |
| - cu_target_query_lens, num_rejected_tokens, num_tokens) |
| 138 | + # Create CommonAttentionMetadata for new API |
| 139 | + common_attn_metadata = _create_common_attn_metadata( |
| 140 | + cu_target_query_lens, device) |
| 141 | + proposer = _create_proposer("eagle", 1) |
114 | 142 |
|
115 |
| - assert torch.equal(cu_num_tokens, expected_cu_num_tokens) |
| 143 | + updated_metadata, token_indices = proposer.prepare_inputs( |
| 144 | + common_attn_metadata, num_rejected_tokens.cpu(), num_tokens) |
| 145 | + |
| 146 | + assert torch.equal(updated_metadata.query_start_loc, |
| 147 | + expected_cu_num_tokens) |
116 | 148 | assert token_indices.shape[0] == expected_cu_num_tokens[-1].item()
|
117 | 149 | assert torch.equal(token_indices, expected_token_indices)
|
118 | 150 |
|
@@ -284,26 +316,33 @@ def create_deterministic_logits(token_ids):
|
284 | 316 | target_hidden_states = torch.randn(total_tokens,
|
285 | 317 | hidden_size,
|
286 | 318 | device=device)
|
287 |
| - target_slot_mapping = torch.randint(0, |
288 |
| - 100, (total_tokens, ), |
289 |
| - device=device) |
290 | 319 | next_token_ids = torch.randint(0,
|
291 | 320 | vocab_size, (batch_size, ),
|
292 | 321 | dtype=torch.int32,
|
293 | 322 | device=device)
|
294 |
| - block_table = torch.randint(0, 10, (batch_size, 10), device=device) |
295 |
| - |
296 | 323 | sampling_metadata = mock.MagicMock()
|
297 | 324 |
|
298 |
| - # Call the method under test |
299 |
| - result = proposer.propose(target_token_ids=target_token_ids, |
300 |
| - target_positions=target_positions, |
301 |
| - target_hidden_states=target_hidden_states, |
302 |
| - target_slot_mapping=target_slot_mapping, |
303 |
| - next_token_ids=next_token_ids, |
304 |
| - cu_num_tokens=cu_num_tokens, |
305 |
| - block_table=block_table, |
306 |
| - sampling_metadata=sampling_metadata) |
| 325 | + # Create CommonAttentionMetadata for new API |
| 326 | + common_attn_metadata = _create_common_attn_metadata(cu_num_tokens, device) |
| 327 | + |
| 328 | + # Mock runner for attention metadata building |
| 329 | + proposer.runner = mock.MagicMock() |
| 330 | + proposer.runner.attn_metadata_builders = [mock.MagicMock()] |
| 331 | + |
| 332 | + # Create mock with required attributes for multi-token tests |
| 333 | + attn_metadata_mock = mock.MagicMock() |
| 334 | + attn_metadata_mock.max_seq_len = 10 |
| 335 | + attn_metadata_mock.seq_lens = torch.tensor([5, 3], device=device) |
| 336 | + proposer.runner.attn_metadata_builders[ |
| 337 | + 0].build.return_value = attn_metadata_mock |
| 338 | + |
| 339 | + with mock.patch('vllm.v1.spec_decode.eagle.isinstance', return_value=True): |
| 340 | + result = proposer.propose(target_token_ids=target_token_ids, |
| 341 | + target_positions=target_positions, |
| 342 | + target_hidden_states=target_hidden_states, |
| 343 | + next_token_ids=next_token_ids, |
| 344 | + common_attn_metadata=common_attn_metadata, |
| 345 | + sampling_metadata=sampling_metadata) |
307 | 346 |
|
308 | 347 | assert result.shape == (batch_size, num_speculative_tokens)
|
309 | 348 |
|
|
0 commit comments