@@ -44,7 +44,7 @@ def __init__(
44
44
self .speculative_config .num_speculative_tokens )
45
45
self .max_num_tokens = (
46
46
vllm_config .scheduler_config .max_num_batched_tokens )
47
- self .arange_np = np .arange (self .max_num_tokens )
47
+ self .token_arange_np = np .arange (self .max_num_tokens )
48
48
# We need to get the hidden size from the draft model config because
49
49
# the draft model's hidden size can be different from the target model's
50
50
# hidden size (e.g., Llama 3.3 70B).
@@ -245,65 +245,87 @@ def prepare_inputs(
245
245
# [batch_size]
246
246
num_rejected_tokens : torch .Tensor
247
247
) -> tuple [CommonAttentionMetadata , torch .Tensor ]:
248
- # query_start_loc_cpu: [0, a, a + b, a + b + c]
249
- # num_rejected_tokens: [n1, n2, n3]
250
- # num_tokens_per_req: [a - n1, b - n2, c - n3]
251
- # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
252
- # token_indices: [0, 1, ..., a - n1 - 1,
253
- # a, a + 1, ..., a + b - n2 - 1,
254
- # a + b, a + b + 1, ..., a + b + c - n3 - 1]
248
+ """
249
+ This function is used to prepare the inputs for the spec decode.
250
+ It updates to the common_attn_metadata to account for the rejected
251
+ tokens (and newly sampled tokens). It also returns the token indices
252
+ of the tokens that should be fed to the speculator.
253
+ """
254
+ # E.g.
255
+ # common_attn_metadata.query_start_loc{_cpu}:
256
+ # [0, q1, q1 + q2, q1 + q2 + q3]
257
+ # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
258
+ # num_rejected_tokens: [n1, n2, n3]
259
+ # This function computes the intermediate values:
260
+ # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
261
+ # And returns:
262
+ # common_attn_metadata.query_start_loc{_cpu}:
263
+ # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
264
+ # common_attn_metadata.seq_lens{_cpu}:
265
+ # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
266
+ # token_indices: [0, 1, ..., q1 - n1 - 1,
267
+ # q1, q1 + 1, ..., q1 + q2 - n2 - 1,
268
+ # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
255
269
256
270
device = common_attn_metadata .query_start_loc .device
257
271
query_start_loc_cpu = common_attn_metadata .query_start_loc_cpu
258
- spec_seq_lens_cpu = \
259
- common_attn_metadata .seq_lens_cpu - num_rejected_tokens + 1
260
-
261
- # [0, a, a + b, a + b + c] -> [a, b, c]
262
- query_len_per_req = (query_start_loc_cpu [1 :] -
263
- query_start_loc_cpu [:- 1 ])
264
- # [a, b, c] -> [a - n1, b - n2, c - n3]
265
- num_tokens_per_req = query_len_per_req - num_rejected_tokens
266
- num_tokens_per_req_np = num_tokens_per_req .numpy ()
267
-
268
- # [a - n1, b - n2, c - n3] ->
269
- # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
270
- spec_query_start_loc_cpu = torch .zeros (query_start_loc_cpu .shape ,
271
- dtype = torch .int32 ,
272
- pin_memory = True )
273
- spec_query_start_loc_np = spec_query_start_loc_cpu .numpy ()
274
- np .cumsum (num_tokens_per_req_np , out = spec_query_start_loc_np [1 :])
275
- """Get the cumulative sum and batched arange of the given array.
276
- # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
277
- # Equivalent to but faster than:
278
- # np.concatenate([np.arange(n) for n in num_tokens])
279
- """
280
-
281
- # Step 1. [2, 5, 3] -> [2, 7, 10]
282
- total_num_tokens = spec_query_start_loc_np [- 1 ]
283
- # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
284
- cumsums_offsets = np .repeat (spec_query_start_loc_np [:- 1 ],
285
- num_tokens_per_req_np )
286
- # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
287
- arange = self .arange_np [:total_num_tokens ] - cumsums_offsets
272
+ new_seq_lens_cpu = common_attn_metadata .seq_lens_cpu \
273
+ - num_rejected_tokens
274
+
275
+ # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
276
+ new_query_len_per_req = (query_start_loc_cpu [1 :] -
277
+ query_start_loc_cpu [:- 1 ])
278
+ # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
279
+ new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
280
+ new_num_tokens_per_req_np = new_num_tokens_per_req .numpy ()
281
+
282
+ # [q1 - n1, q2 - n2, q3 - n3] ->
283
+ # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
284
+ new_query_start_loc_cpu = torch .zeros (query_start_loc_cpu .shape ,
285
+ dtype = torch .int32 ,
286
+ pin_memory = True )
287
+ new_query_start_loc_np = new_query_start_loc_cpu .numpy ()
288
+ np .cumsum (new_num_tokens_per_req_np , out = new_query_start_loc_np [1 :])
289
+
290
+ total_num_tokens = new_query_start_loc_np [- 1 ]
291
+ # Example assuming num_tokens_per_req_np = [2, 4, 3]
292
+ # this implies that `new_query_start_locs` is:
293
+ # [0, 2, 6, 9] ->
294
+ # [0, 0, 2, 2, 2, 2, 6, 6, 6]
295
+ # _r1_ ____r2____ ___r3__
296
+ new_query_start_locs_expanded = np .repeat (new_query_start_loc_np [:- 1 ],
297
+ new_num_tokens_per_req_np )
298
+ # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
299
+ # [0, 1, 0, 1, 2, 3, 0, 1, 2]
300
+ # _r1_ ____r2____ ___r3__
301
+ token_offests = self .token_arange_np [:total_num_tokens ] \
302
+ - new_query_start_locs_expanded
288
303
289
304
# Expand starting positions to match token pattern
290
- query_start_expanded = np .repeat (query_start_loc_cpu [:- 1 ].numpy (),
291
- num_tokens_per_req_np )
292
- token_indices_np = arange + query_start_expanded
305
+ # [0, q1, q1 + q2] ->
306
+ # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
307
+ # _r1_ _____r2_______ ___________r3____________
308
+ old_query_start_locs_expanded = np .repeat (
309
+ query_start_loc_cpu [:- 1 ].numpy (), new_num_tokens_per_req_np )
310
+ # Final token indices are:
311
+ # [0, 1, // req 1
312
+ # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
313
+ # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
314
+ token_indices_np = token_offests + old_query_start_locs_expanded
293
315
token_indices = torch .from_numpy (token_indices_np ).to (
294
316
device , non_blocking = True )
295
317
296
318
spec_common_attn_metadata = CommonAttentionMetadata (
297
- query_start_loc = spec_query_start_loc_cpu .to (device ,
298
- non_blocking = True ),
299
- seq_lens = spec_seq_lens_cpu .to (device , non_blocking = True ),
300
- query_start_loc_cpu = spec_query_start_loc_cpu ,
301
- seq_lens_cpu = spec_seq_lens_cpu ,
319
+ query_start_loc = new_query_start_loc_cpu .to (device ,
320
+ non_blocking = True ),
321
+ seq_lens = new_seq_lens_cpu .to (device , non_blocking = True ),
322
+ query_start_loc_cpu = new_query_start_loc_cpu ,
323
+ seq_lens_cpu = new_seq_lens_cpu ,
302
324
num_computed_tokens_cpu = common_attn_metadata .
303
325
num_computed_tokens_cpu ,
304
326
num_reqs = common_attn_metadata .num_reqs ,
305
327
num_actual_tokens = total_num_tokens ,
306
- max_query_len = query_len_per_req .max ().item (),
328
+ max_query_len = new_query_len_per_req .max ().item (),
307
329
block_table_tensor = common_attn_metadata .block_table_tensor ,
308
330
slot_mapping = common_attn_metadata .slot_mapping [token_indices ],
309
331
)
0 commit comments