@@ -239,11 +239,11 @@ def propose(
         return draft_token_ids
 
     def prepare_inputs(
-        self,
-        common_attn_metadata: CommonAttentionMetadata,
-        # [batch_size]
-        num_rejected_tokens: torch.Tensor,
-        num_tokens: int) -> tuple[CommonAttentionMetadata, torch.Tensor]:
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        # [batch_size]
+        num_rejected_tokens: torch.Tensor
+    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
         # query_start_loc_cpu: [0, a, a + b, a + b + c]
         # num_rejected_tokens: [n1, n2, n3]
         # num_tokens_per_req: [a - n1, b - n2, c - n3]
@@ -262,54 +262,52 @@ def prepare_inputs(
                              query_start_loc_cpu[:-1])
         # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
+        num_tokens_per_req_np = num_tokens_per_req.numpy()
 
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        spec_query_start_loc_cpu = torch.zeros_like(query_start_loc_cpu,
-                                                    pin_memory=True)
-        torch.cumsum(num_tokens_per_req,
-                     dim=0,
-                     out=spec_query_start_loc_cpu[1:])
+        spec_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape,
+                                               dtype=torch.int32,
+                                               pin_memory=True)
+        spec_query_start_loc_np = spec_query_start_loc_cpu.numpy()
+        np.cumsum(num_tokens_per_req_np, out=spec_query_start_loc_np[1:])
 
         """Get the cumulative sum and batched arange of the given array.
         # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
         # Equivalent to but faster than:
         # np.concatenate([np.arange(n) for n in num_tokens])
         """
+
         # Step 1. [2, 5, 3] -> [2, 7, 10]
-        total_num_tokens = spec_query_start_loc_cpu[-1]
+        total_num_tokens = spec_query_start_loc_np[-1]
         # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
-        cumsums_offsets = np.repeat(
-            spec_query_start_loc_cpu[1:].numpy() - num_tokens_per_req.numpy(),
-            num_tokens_per_req.numpy())
+        cumsums_offsets = np.repeat(spec_query_start_loc_np[:-1],
+                                    num_tokens_per_req_np)
         # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         arange = self.arange_np[:total_num_tokens] - cumsums_offsets
 
         # Expand starting positions to match token pattern
         query_start_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(),
-                                         num_tokens_per_req.numpy())
-        tokens_indices = arange + query_start_expanded
-
-        # Ensure tokens_indices are within valid range for slot_mapping
-        max_slot_idx = common_attn_metadata.slot_mapping.size(0) - 1
-        tokens_indices = np.clip(tokens_indices, 0, max_slot_idx)
+                                         num_tokens_per_req_np)
+        token_indices_np = arange + query_start_expanded
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
 
         spec_common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=spec_query_start_loc_cpu.to(device,
                                                         non_blocking=True),
             seq_lens=spec_seq_lens_cpu.to(device, non_blocking=True),
-            query_start_loc_cpu=spec_query_start_loc_cpu.cpu(),
-            seq_lens_cpu=spec_seq_lens_cpu.cpu(),
-            num_computed_tokens_cpu=(
-                common_attn_metadata.num_computed_tokens_cpu),
+            query_start_loc_cpu=spec_query_start_loc_cpu,
+            seq_lens_cpu=spec_seq_lens_cpu,
+            num_computed_tokens_cpu=common_attn_metadata.
+            num_computed_tokens_cpu,
             num_reqs=common_attn_metadata.num_reqs,
             num_actual_tokens=total_num_tokens,
             max_query_len=query_len_per_req.max().item(),
             block_table_tensor=common_attn_metadata.block_table_tensor,
-            slot_mapping=common_attn_metadata.slot_mapping[tokens_indices],
+            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
         )
 
-        return spec_common_attn_metadata, torch.from_numpy(tokens_indices).to(
-            device)
+        return spec_common_attn_metadata, token_indices
 
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
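
For readers following the index arithmetic: the hunk above builds `token_indices` by combining a cumulative sum with a batched arange (Steps 1-3), then offsetting by the *original* query starts so each request keeps its first `len - n_rejected` tokens. A minimal standalone sketch of the same computation, with made-up sizes (`a, b, c = 3, 4, 2` and `n = [1, 2, 0]`; all names here are illustrative, not from the vLLM source):

```python
import numpy as np

# Flattened batch layout: request boundaries at [0, a, a + b, a + b + c].
query_start_loc = np.array([0, 3, 7, 9])       # a, b, c = 3, 4, 2
num_rejected = np.array([1, 2, 0])             # n1, n2, n3

query_len = np.diff(query_start_loc)           # [3, 4, 2]
num_tokens_per_req = query_len - num_rejected  # [2, 2, 2]

# Step 1. Cumulative sum gives the new (post-rejection) start offsets.
spec_start = np.zeros(len(query_start_loc), dtype=np.int64)
np.cumsum(num_tokens_per_req, out=spec_start[1:])  # [0, 2, 4, 6]
total = spec_start[-1]

# Steps 2-3. Batched arange: repeat each new start once per kept token,
# then subtract from a flat arange. Equivalent to, but faster than,
# np.concatenate([np.arange(n) for n in num_tokens_per_req]).
offsets = np.repeat(spec_start[:-1], num_tokens_per_req)
arange = np.arange(total) - offsets            # [0, 1, 0, 1, 0, 1]

# Offset by the original starts to index into the flattened batch.
token_indices = arange + np.repeat(query_start_loc[:-1], num_tokens_per_req)
print(token_indices)  # [0 1 3 4 7 8]: first len - n tokens of each request
```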
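
The diff also swaps `torch.cumsum` on the tensor for `np.cumsum` writing through a NumPy view of a pinned CPU tensor. Because `.numpy()` on a CPU tensor is a zero-copy view of the same storage, the cumsum lands directly in pinned memory, and the later `.to(device, non_blocking=True)` copy can run asynchronously. A minimal sketch of that pattern, assuming a CUDA build (which `pin_memory` requires) and made-up sizes:

```python
import numpy as np
import torch

num_tokens_per_req_np = np.array([2, 5, 3], dtype=np.int32)

# Allocate once in pinned (page-locked) host memory.
starts_cpu = torch.zeros(len(num_tokens_per_req_np) + 1,
                         dtype=torch.int32,
                         pin_memory=True)

# Writing through the zero-copy NumPy view fills the pinned tensor directly,
# with no intermediate tensor or extra host-side copy.
starts_np = starts_cpu.numpy()
np.cumsum(num_tokens_per_req_np, out=starts_np[1:])  # starts_cpu: [0, 2, 7, 10]

# A pinned source makes the non-blocking host-to-device copy actually async.
starts_gpu = starts_cpu.to("cuda", non_blocking=True)
```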