Commit 8f7ffc4

Create patch_eagle.py
1 parent a4928a4 commit 8f7ffc4

File tree: 1 file changed (patch_eagle.py)

patch_eagle.py: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.v1.spec_decode.eagle import EagleProposer


def prepare_inputs(
    # [batch_size + 1]
    cu_target_query_lens: torch.Tensor,
    # [batch_size]
    num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # cu_target_query_lens: [0, a, a + b, a + b + c]
    # num_rejected_tokens: [n1, n2, n3]
    # num_tokens_per_req: [a - n1, b - n2, c - n3]
    # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
    # token_indices: [0, 1, ..., a - n1 - 1,
    #                 a, a + 1, ..., a + b - n2 - 1,
    #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_len_per_req = (cu_target_query_lens[1:] - cu_target_query_lens[:-1])
    # [a, b, c] -> [a - n1, b - n2, c - n3]
    num_tokens_per_req = query_len_per_req - num_rejected_tokens

    cu_num_tokens = torch.empty_like(cu_target_query_lens)
    torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
    cu_num_tokens[0] = 0

    # FIXME(woosuk): Avoid synchronization.
    num_tokens = cu_num_tokens[-1].item()
    token_indices = torch.empty(
        num_tokens,
        dtype=torch.int32,
        device=cu_num_tokens.device,
    )

    BLOCK_SIZE = 1024
    prepare_input_pytorch(
        token_indices,
        cu_target_query_lens,
        cu_num_tokens,
        block_size=BLOCK_SIZE,
    )
    return cu_num_tokens, token_indices


def prepare_input_pytorch(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,
                          cu_num_tokens: torch.Tensor, block_size: int):
    # Loop over requests (one "pid" per request) and scatter the indices of
    # each request's accepted tokens into out_ptr, block by block.
    num_pids = cu_num_tokens.shape[0] - 1

    for pid in range(num_pids):
        start_pos = cu_num_tokens[pid].item()
        end_pos = cu_num_tokens[pid + 1].item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens[pid].item()
        # Round up so requests longer than block_size are fully covered.
        num_blocks = (num_tokens + block_size - 1) // block_size

        for i in range(num_blocks):
            # Per-block offsets within this request.
            offset = i * block_size + torch.arange(
                0,
                block_size,
                dtype=out_ptr.dtype,
                device=cu_query_lens.device)
            global_indices = start_pos + offset
            values = index_start + offset
            mask = offset < num_tokens
            out_ptr[global_indices[mask]] = values[mask]


EagleProposer.prepare_inputs = prepare_inputs
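
For reference, a minimal sanity check of the helper above (not part of the commit): it assumes the file is importable as patch_eagle, that vLLM is installed, and exercises prepare_inputs on small CPU tensors.

# Hypothetical sanity check, not part of the commit. Assumes the file above
# is importable as `patch_eagle` and that vLLM is installed.
import torch

from patch_eagle import prepare_inputs

# Three requests with query lengths a=3, b=4, c=5 and 1, 0, 2 rejected tokens.
cu_target_query_lens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)

cu_num_tokens, token_indices = prepare_inputs(cu_target_query_lens,
                                              num_rejected_tokens)

print(cu_num_tokens.tolist())   # [0, 2, 6, 9]
print(token_indices.tolist())   # [0, 1, 3, 4, 5, 6, 7, 8, 9]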
