
Commit 56ea170

cleanup
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent: cf3796e

File tree: 3 files changed, +85 −61 lines


vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def get_per_layer_parameters(
     """
 
     layers = get_layers_from_vllm_config(vllm_config, Attention)
-    per_layer_params: Dict[str, PerLayerParameters] = {}
+    per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
         impl = layer.impl
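
The change here is purely a typing cleanup: the `typing.Dict` alias is replaced with the builtin `dict` generic, which is subscriptable from Python 3.9 onward (PEP 585). A minimal sketch of the same pattern; the `PerLayerParameters` dataclass below is a hypothetical stand-in, not vLLM's actual definition:

```python
# Minimal sketch of the PEP 585 style used above: the builtin ``dict`` is
# parameterized directly, with no ``from typing import Dict`` required.
from dataclasses import dataclass


@dataclass
class PerLayerParameters:
    """Hypothetical stand-in for vLLM's per-layer attention parameters."""
    window_left: int = -1  # illustrative field, not the real schema


per_layer_params: dict[str, PerLayerParameters] = {}
per_layer_params["model.layers.0.attn"] = PerLayerParameters(window_left=128)
print(per_layer_params)
```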

vllm/v1/attention/backends/utils.py

Lines changed: 15 additions & 13 deletions
@@ -366,20 +366,22 @@ def split_decodes_and_prefills(
     num_tokens = common_attn_metadata.num_actual_tokens
     query_start_loc = common_attn_metadata.query_start_loc_cpu
 
-    if max_query_len == 1:
+    if max_query_len <= decode_threshold:
         return num_reqs, 0, num_tokens, 0
-    else:
-        query_lens = query_start_loc[1:] - query_start_loc[:-1]
-        first_prefill = (query_lens
-                         > decode_threshold).int().argmax(dim=-1).item()
-        assert torch.all(query_lens[first_prefill:] > decode_threshold)
-        assert torch.all(query_lens[:first_prefill] <= decode_threshold)
-        num_decodes = first_prefill
-        num_prefills = num_reqs - num_decodes
-        num_decode_tokens = first_prefill
-        num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
-        return (num_decodes, num_prefills, num_decode_tokens,
-                num_prefill_tokens)
+
+    query_lens = query_start_loc[1:] - query_start_loc[:-1]
+    is_prefill = query_lens > decode_threshold
+    if not torch.any(is_prefill):
+        return num_reqs, 0, num_tokens, 0
+
+    first_prefill = is_prefill.int().argmax(dim=-1).item()
+    assert torch.all(query_lens[first_prefill:] > decode_threshold)
+    assert torch.all(query_lens[:first_prefill] <= decode_threshold)
+    num_decodes = first_prefill
+    num_prefills = num_reqs - num_decodes
+    num_decode_tokens = query_start_loc[first_prefill].item()
+    num_prefill_tokens = num_tokens - num_decode_tokens
+    return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
 
 
 def reoder_batch_to_split_decodes_and_prefills(
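
To make the new control flow concrete, here is a self-contained sketch of the same decode/prefill split on toy tensors. `split_toy` and the example values are illustrative only; the real `split_decodes_and_prefills` takes a `CommonAttentionMetadata` object and relies on the batch being ordered with all decode requests before any prefill request:

```python
# Toy reimplementation of the split logic added above (illustrative only).
import torch


def split_toy(query_start_loc: torch.Tensor, decode_threshold: int = 1):
    num_reqs = query_start_loc.numel() - 1
    num_tokens = int(query_start_loc[-1])
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        # Pure-decode batch: no request exceeds the decode threshold.
        return num_reqs, 0, num_tokens, 0
    first_prefill = int(is_prefill.int().argmax())
    num_decodes = first_prefill
    num_prefills = num_reqs - num_decodes
    # Decode tokens are everything before the first prefill's start offset.
    num_decode_tokens = int(query_start_loc[first_prefill])
    num_prefill_tokens = num_tokens - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Two single-token decodes followed by prefills of 4 and 3 tokens.
qsl = torch.tensor([0, 1, 2, 6, 9])
print(split_toy(qsl))  # (2, 2, 2, 7)
```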

vllm/v1/spec_decode/eagle.py

Lines changed: 69 additions & 47 deletions
@@ -44,7 +44,7 @@ def __init__(
             self.speculative_config.num_speculative_tokens)
         self.max_num_tokens = (
             vllm_config.scheduler_config.max_num_batched_tokens)
-        self.arange_np = np.arange(self.max_num_tokens)
+        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need to get the hidden size from the draft model config because
         # the draft model's hidden size can be different from the target model's
         # hidden size (e.g., Llama 3.3 70B).
@@ -245,65 +245,87 @@ def prepare_inputs(
         # [batch_size]
         num_rejected_tokens: torch.Tensor
     ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
-        # query_start_loc_cpu: [0, a, a + b, a + b + c]
-        # num_rejected_tokens: [n1, n2, n3]
-        # num_tokens_per_req: [a - n1, b - n2, c - n3]
-        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        # token_indices: [0, 1, ..., a - n1 - 1,
-        #                 a, a + 1, ..., a + b - n2 - 1,
-        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
+        """
+        This function is used to prepare the inputs for the spec decode.
+        It updates to the common_attn_metadata to account for the rejected
+        tokens (and newly sampled tokens). It also returns the token indices
+        of the tokens that should be fed to the speculator.
+        """
+        # E.g.
+        #   common_attn_metadata.query_start_loc{_cpu}:
+        #       [0, q1, q1 + q2, q1 + q2 + q3]
+        #   common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
+        #   num_rejected_tokens: [n1, n2, n3]
+        # This function computes the intermediate values:
+        #   num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
+        # And returns:
+        #   common_attn_metadata.query_start_loc{_cpu}:
+        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        #   common_attn_metadata.seq_lens{_cpu}:
+        #       [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
+        #   token_indices: [0, 1, ..., q1 - n1 - 1,
+        #                   q1, q1 + 1, ..., q1 + q2 - n2 - 1,
+        #                   q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
-        spec_seq_lens_cpu =\
-            common_attn_metadata.seq_lens_cpu - num_rejected_tokens + 1
-
-        # [0, a, a + b, a + b + c] -> [a, b, c]
-        query_len_per_req = (query_start_loc_cpu[1:] -
-                             query_start_loc_cpu[:-1])
-        # [a, b, c] -> [a - n1, b - n2, c - n3]
-        num_tokens_per_req = query_len_per_req - num_rejected_tokens
-        num_tokens_per_req_np = num_tokens_per_req.numpy()
-
-        # [a - n1, b - n2, c - n3] ->
-        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        spec_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
-                                               dtype=torch.int32,
-                                               pin_memory=True)
-        spec_query_start_loc_np = spec_query_start_loc_cpu.numpy()
-        np.cumsum(num_tokens_per_req_np, out=spec_query_start_loc_np[1:])
-        """Get the cumulative sum and batched arange of the given array.
-        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
-        # Equivalent to but faster than:
-        # np.concatenate([np.arange(n) for n in num_tokens])
-        """
-
-        # Step 1. [2, 5, 3] -> [2, 7, 10]
-        total_num_tokens = spec_query_start_loc_np[-1]
-        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
-        cumsums_offsets = np.repeat(spec_query_start_loc_np[:-1],
-                                    num_tokens_per_req_np)
-        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
+            - num_rejected_tokens
+
+        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
+        new_query_len_per_req = (query_start_loc_cpu[1:] -
+                                 query_start_loc_cpu[:-1])
+        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
+        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
+        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
+
+        # [q1 - n1, q2 - n2, q3 - n3] ->
+        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        new_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
+                                              dtype=torch.int32,
+                                              pin_memory=True)
+        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
+        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
+
+        total_num_tokens = new_query_start_loc_np[-1]
+        # Example assuming num_tokens_per_req_np = [2, 4, 3]
+        # this implies that `new_query_start_locs` is:
+        #   [0, 2, 6, 9] ->
+        #   [0, 0, 2, 2, 2, 2, 6, 6, 6]
+        #    _r1_  ____r2____  ___r3__
+        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
+                                                  new_num_tokens_per_req_np)
+        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
+        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
+        #  _r1_  ____r2____  ___r3__
+        token_offests = self.token_arange_np[:total_num_tokens] \
+            - new_query_start_locs_expanded
 
         # Expand starting positions to match token pattern
-        query_start_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
-                                         num_tokens_per_req_np)
-        token_indices_np = arange + query_start_expanded
+        # [0, q1, q1 + q2] ->
+        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
+        #  _r1_  _____r2_______  ___________r3____________
+        old_query_start_locs_expanded = np.repeat(
+            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
+        # Final token indices are:
+        # [0, 1,                                   // req 1
+        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
+        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
+        token_indices_np = token_offests + old_query_start_locs_expanded
         token_indices = torch.from_numpy(token_indices_np).to(
             device, non_blocking=True)
 
         spec_common_attn_metadata = CommonAttentionMetadata(
-            query_start_loc=spec_query_start_loc_cpu.to(device,
-                                                        non_blocking=True),
-            seq_lens=spec_seq_lens_cpu.to(device, non_blocking=True),
-            query_start_loc_cpu=spec_query_start_loc_cpu,
-            seq_lens_cpu=spec_seq_lens_cpu,
+            query_start_loc=new_query_start_loc_cpu.to(device,
+                                                       non_blocking=True),
+            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
+            query_start_loc_cpu=new_query_start_loc_cpu,
+            seq_lens_cpu=new_seq_lens_cpu,
             num_computed_tokens_cpu=common_attn_metadata.
             num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
-            max_query_len=query_len_per_req.max().item(),
+            max_query_len=new_query_len_per_req.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
             slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         )
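
The heart of this rewrite is the cumsum + repeat + arange pattern that builds `token_indices` without a Python loop over requests. Below is a NumPy-only sketch with made-up numbers (per-request query lengths q = [3, 5, 4], one rejected token per request); the variable names mirror the diff, but this is an illustration, not the vLLM code itself:

```python
# Illustrative NumPy walk-through of the batched-arange trick used above.
import numpy as np

query_start_loc = np.array([0, 3, 8, 12])  # old starts, per-request lens q = [3, 5, 4]
num_rejected = np.array([1, 1, 1])         # rejected draft tokens per request

query_lens = np.diff(query_start_loc)      # [3, 5, 4]
new_lens = query_lens - num_rejected       # [2, 4, 3] tokens kept per request

new_query_start_loc = np.zeros_like(query_start_loc)
np.cumsum(new_lens, out=new_query_start_loc[1:])    # [0, 2, 6, 9]
total = new_query_start_loc[-1]                     # 9 surviving tokens

# Batched arange: each token's offset within its own (shrunken) request.
starts_expanded = np.repeat(new_query_start_loc[:-1], new_lens)  # [0 0 2 2 2 2 6 6 6]
token_offsets = np.arange(total) - starts_expanded               # [0 1 0 1 2 3 0 1 2]

# Shift those offsets by each request's *old* start to index the flat buffer.
old_starts_expanded = np.repeat(query_start_loc[:-1], new_lens)  # [0 0 3 3 3 3 8 8 8]
token_indices = token_offsets + old_starts_expanded

print(token_indices)  # [ 0  1  3  4  5  6  8  9 10]
```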
