
Commit 53e54da

optimize eagle
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent: d534e4e

File tree: 2 files changed, +26 −30 lines


vllm/v1/spec_decode/eagle.py

Lines changed: 25 additions & 27 deletions
@@ -239,11 +239,11 @@ def propose(
         return draft_token_ids
 
     def prepare_inputs(
-            self,
-            common_attn_metadata: CommonAttentionMetadata,
-            # [batch_size]
-            num_rejected_tokens: torch.Tensor,
-            num_tokens: int) -> tuple[CommonAttentionMetadata, torch.Tensor]:
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        # [batch_size]
+        num_rejected_tokens: torch.Tensor
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         # query_start_loc_cpu: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
         # num_tokens_per_req: [a - n1, b - n2, c - n3]
@@ -262,54 +262,52 @@ def prepare_inputs(
                                      query_start_loc_cpu[:-1])
         # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
+        num_tokens_per_req_np = num_tokens_per_req.numpy()
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        spec_query_start_loc_cpu = torch.zeros_like(query_start_loc_cpu,
-                                                    pin_memory=True)
-        torch.cumsum(num_tokens_per_req,
-                     dim=0,
-                     out=spec_query_start_loc_cpu[1:])
+        spec_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
+                                               dtype=torch.int32,
+                                               pin_memory=True)
+        spec_query_start_loc_np = spec_query_start_loc_cpu.numpy()
+        np.cumsum(num_tokens_per_req_np, out=spec_query_start_loc_np[1:])
         """Get the cumulative sum and batched arange of the given array.
         # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
         # Equivalent to but faster than:
         # np.concatenate([np.arange(n) for n in num_tokens])
         """
+
         # Step 1. [2, 5, 3] -> [2, 7, 10]
-        total_num_tokens = spec_query_start_loc_cpu[-1]
+        total_num_tokens = spec_query_start_loc_np[-1]
         # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
-        cumsums_offsets = np.repeat(
-            spec_query_start_loc_cpu[1:].numpy() - num_tokens_per_req.numpy(),
-            num_tokens_per_req.numpy())
+        cumsums_offsets = np.repeat(spec_query_start_loc_np[:-1],
+                                    num_tokens_per_req_np)
         # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         arange = self.arange_np[:total_num_tokens] - cumsums_offsets
 
         # Expand starting positions to match token pattern
         query_start_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
-                                         num_tokens_per_req.numpy())
-        tokens_indices = arange + query_start_expanded
-
-        # Ensure tokens_indices are within valid range for slot_mapping
-        max_slot_idx = common_attn_metadata.slot_mapping.size(0) - 1
-        tokens_indices = np.clip(tokens_indices, 0, max_slot_idx)
+                                         num_tokens_per_req_np)
+        token_indices_np = arange + query_start_expanded
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
 
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=spec_query_start_loc_cpu.to(device,
                                                         non_blocking=True),
             seq_lens=spec_seq_lens_cpu.to(device, non_blocking=True),
-            query_start_loc_cpu=spec_query_start_loc_cpu.cpu(),
-            seq_lens_cpu=spec_seq_lens_cpu.cpu(),
-            num_computed_tokens_cpu=(
-                common_attn_metadata.num_computed_tokens_cpu),
+            query_start_loc_cpu=spec_query_start_loc_cpu,
+            seq_lens_cpu=spec_seq_lens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=query_len_per_req.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
-            slot_mapping=common_attn_metadata.slot_mapping[tokens_indices],
+            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         )
 
-        return spec_common_attn_metadata, torch.from_numpy(tokens_indices).to(
-            device)
+        return spec_common_attn_metadata, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
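
The heart of this change is moving the prefix-sum and batched-arange math off torch CPU tensors and onto the pinned buffer's .numpy() view. Below is a minimal standalone sketch of that trick with made-up sizes; the concrete array values, and arange_np standing in for the proposer's preallocated self.arange_np buffer, are illustrative assumptions, not the real runtime state:

import numpy as np

# Per-request query lengths after dropping rejected draft tokens,
# e.g. [a - n1, b - n2, c - n3] = [3 - 1, 6 - 1, 4 - 1].
num_tokens_per_req = np.array([2, 5, 3], dtype=np.int32)

# Stand-in for the proposer's preallocated arange buffer (assumption).
arange_np = np.arange(1024, dtype=np.int32)

# Step 1. Exclusive prefix sum [2, 5, 3] -> [0, 2, 7, 10]; the tail is
# the total number of kept tokens.
spec_query_start_loc = np.zeros(len(num_tokens_per_req) + 1, dtype=np.int32)
np.cumsum(num_tokens_per_req, out=spec_query_start_loc[1:])
total_num_tokens = spec_query_start_loc[-1]  # 10

# Step 2. Repeat each request's new start offset once per kept token:
# [0, 2, 7] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7].
cumsums_offsets = np.repeat(spec_query_start_loc[:-1], num_tokens_per_req)

# Step 3. Batched arange [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]; same result as
# np.concatenate([np.arange(n) for n in num_tokens_per_req]) but with no
# per-request allocations.
arange = arange_np[:total_num_tokens] - cumsums_offsets

# Map back into the original (pre-rejection) token layout, whose
# per-request starts are query_start_loc[:-1] = [0, 3, 9].
query_start_loc = np.array([0, 3, 9, 13], dtype=np.int32)
query_start_expanded = np.repeat(query_start_loc[:-1], num_tokens_per_req)
token_indices = arange + query_start_expanded
print(token_indices)  # [ 0  1  3  4  5  6  7  9 10 11]

Writing through .numpy() mutates the pinned tensor in place, so spec_query_start_loc_cpu is filled without an extra copy, and because the buffer is pinned, the later .to(device, non_blocking=True) transfers stay genuinely asynchronous.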

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 3 deletions
@@ -1654,11 +1654,9 @@ def propose_draft_token_ids(
             ]
             num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
                                                    dtype=torch.int32)
-            num_tokens = (num_scheduled_tokens -
-                          num_rejected_tokens_cpu.sum())
             common_attn_metadata, token_indices =\
                 self.drafter.prepare_inputs(
-                    common_attn_metadata, num_rejected_tokens_cpu)
+                    common_attn_metadata, num_rejected_tokens_cpu)
 
             target_token_ids = self.input_ids[token_indices]
             # TODO(woosuk): Support M-RoPE.
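
The num_tokens argument dropped here was redundant: the same total now falls out of the prefix sum inside prepare_inputs as its last element (total_num_tokens). A small sketch of that identity, using hypothetical per-request counts:

import torch

# Hypothetical scheduled total and per-request lengths/rejections.
num_scheduled_tokens = 13  # a + b + c = 3 + 6 + 4
query_len_per_req = torch.tensor([3, 6, 4], dtype=torch.int32)
num_rejected_tokens_cpu = torch.tensor([1, 1, 1], dtype=torch.int32)

# What the caller used to pass in explicitly:
num_tokens = num_scheduled_tokens - num_rejected_tokens_cpu.sum()

# What prepare_inputs now derives internally: the tail of the
# cumulative sum over (query_len_per_req - num_rejected_tokens).
total_num_tokens = (query_len_per_req - num_rejected_tokens_cpu).cumsum(0)[-1]

assert num_tokens == total_num_tokens  # both 10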
