|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +import torch |
| 3 | +from vllm.v1.spec_decode.eagle import EagleProposer |
| 4 | + |
| 5 | + |
def prepare_inputs(
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the cumulative kept-token counts and the flat token indices
    that survive speculative-decoding rejection.

    Worked example:
        cu_target_query_lens: [0, a, a + b, a + b + c]
        num_rejected_tokens:  [n1, n2, n3]
        num_tokens_per_req:   [a - n1, b - n2, c - n3]
        cu_num_tokens:        [0, a - n1, a + b - n1 - n2,
                               a + b + c - n1 - n2 - n3]
        token_indices:        [0, 1, ..., a - n1 - 1,
                               a, a + 1, ..., a + b - n2 - 1,
                               a + b, a + b + 1, ..., a + b + c - n3 - 1]

    Returns:
        (cu_num_tokens, token_indices) as described above.
    """
    # Per-request query lengths: [0, a, a+b, a+b+c] -> [a, b, c].
    lens_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]
    # Tokens kept after rejection: [a, b, c] -> [a - n1, b - n2, c - n3].
    kept_per_req = lens_per_req - num_rejected_tokens

    # Prefix-sum with a leading zero, same shape/dtype as the input offsets.
    cu_num_tokens = torch.empty_like(cu_target_query_lens)
    cu_num_tokens[0] = 0
    cu_num_tokens[1:] = torch.cumsum(kept_per_req, dim=0)

    # FIXME(woosuk): Avoid synchronization.
    total_kept = cu_num_tokens[-1].item()
    token_indices = torch.empty(
        total_kept,
        dtype=torch.int32,
        device=cu_num_tokens.device,
    )

    prepare_input_pytorch(
        token_indices,
        cu_target_query_lens,
        cu_num_tokens,
        block_size=1024,
    )
    return cu_num_tokens, token_indices
| 45 | + |
| 46 | + |
def prepare_input_pytorch(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
                          cu_num_tokens: torch.Tensor, block_size: int):
    """Fill `out_ptr` with the kept target-token indices, request by request.

    PyTorch port of a Triton-style blocked kernel: for request `pid` it
    writes the contiguous range
        [cu_query_lens[pid], cu_query_lens[pid] + num_tokens)
    into out_ptr[cu_num_tokens[pid] : cu_num_tokens[pid + 1]].

    Args:
        out_ptr: flat int tensor of size cu_num_tokens[-1]; written in place.
        cu_query_lens: [batch_size + 1] cumulative target query lengths.
        cu_num_tokens: [batch_size + 1] cumulative kept-token counts.
        block_size: number of elements handled per inner iteration.
    """
    num_pids = cu_num_tokens.shape[0] - 1

    for pid in range(num_pids):
        start_pos = cu_num_tokens[pid].item()
        end_pos = cu_num_tokens[pid + 1].item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens[pid].item()
        # BUG FIX: ceil-divide by block_size. The original computed
        # `num_tokens + block_size - 1` (missing `// block_size`), looping
        # far too many times.
        num_blocks = (num_tokens + block_size - 1) // block_size

        for i in range(num_blocks):
            # BUG FIX: advance by i * block_size each block. The original
            # rebuilt the same [0, block_size) offsets every iteration, so
            # tokens past the first block were never written and kept
            # whatever garbage `torch.empty` left in out_ptr.
            offset = i * block_size + torch.arange(
                0,
                block_size,
                dtype=out_ptr.dtype,
                device=cu_query_lens.device)
            # Mask off the tail of the last (partial) block.
            mask = offset < num_tokens
            out_ptr[start_pos + offset[mask]] = index_start + offset[mask]
| 68 | + |
| 69 | + |
# Monkey-patch vLLM's EagleProposer so it uses the PyTorch implementation
# defined above instead of its built-in one.
# NOTE(review): this assigns a plain function as a class attribute. If the
# attribute being replaced was an instance method (rather than a module-level
# function or staticmethod), instance calls would bind `self` as
# `cu_target_query_lens` — confirm against EagleProposer's definition.
EagleProposer.prepare_inputs = prepare_inputs
0 commit comments