@@ -7,12 +7,12 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
@@ -167,13 +167,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
-    paged_kv_indices: torch.Tensor
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (CPU for plan)
+    paged_kv_indices_cpu: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
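
The worked example in the comments above can be traced end to end with the CPU-side construction that `build()` performs in a later hunk. A minimal sketch, assuming a page_size of 16 and made-up sequence lengths chosen to reproduce the file's `[0, 5, 8, 1, 6, 7, 3, 4]` / `[0, 3, 6, 8]` example (the lengths and the padding value are illustrative, not PR values):

```python
import torch

page_size = 16
seq_lens_cpu = torch.tensor([40, 35, 20], dtype=torch.int32)  # hypothetical
block_table = torch.tensor([[0, 5, 8, 9],
                            [1, 6, 7, 9],
                            [3, 4, 9, 9]], dtype=torch.int32)  # 9 = padding

# ceil-divide to get the number of valid pages per request: [3, 3, 2]
bounds = (seq_lens_cpu + page_size - 1) // page_size

# mask off the padding columns, then flatten the valid page ids row by row
mask = (torch.arange(block_table.size(1)).unsqueeze(0) < bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]       # tensor([0, 5, 8, 1, 6, 7, 3, 4])

# prefix-sum the per-request page counts to get the indptr
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    bounds.cumsum(dim=0, dtype=torch.int32)
])                                         # tensor([0, 3, 6, 8])
```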
@@ -201,17 +201,20 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None
 
     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    # CPU version for FlashInfer planning
+    qo_indptr_cpu: Optional[torch.Tensor] = None
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -285,21 +288,31 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
+        # Ensure CPU tensors are not None
+        assert attn_metadata.qo_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indices_cpu is not None
+        assert attn_metadata.paged_kv_last_page_len_cpu is not None
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
-                    attn_metadata.paged_kv_indices
+                    attn_metadata.shared_kv_page_indices_cpu,
+                    attn_metadata.paged_kv_indices_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
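
For readers unfamiliar with `MultiLevelCascadeAttentionWrapper`, each argument passed to `plan()` above is a two-level list: the first entry describes the shared prefix as a single virtual request covering every query token, the second the per-request remainders. A minimal sketch of what those CPU tensors could look like, with invented sizes (2 requests, 12 query tokens in total, a 32-token shared prefix, page_size 16):

```python
import torch

num_actual_tokens, page_size = 12, 16
num_common_kv_blocks = 32 // page_size                   # 2 shared pages

# Level 0: the shared prefix as one virtual request; its last page is
# full because common_prefix_len is asserted to be page-aligned.
shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens], dtype=torch.int32)
shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
                                         dtype=torch.int32)
shared_kv_page_indices_cpu = torch.tensor([0, 1], dtype=torch.int32)
shared_kv_last_page_len_cpu = torch.tensor([page_size], dtype=torch.int32)

# Level 1: the per-request remainders after the shared pages are removed.
qo_indptr_cpu = torch.tensor([0, 5, 12], dtype=torch.int32)
paged_kv_indptr_cpu = torch.tensor([0, 1, 3], dtype=torch.int32)

# plan() then receives [level-0, level-1] pairs, e.g.:
qo_indptr_levels = [shared_qo_indptr_cpu, qo_indptr_cpu]
kv_page_indptr_levels = [shared_kv_page_indptr_cpu, paged_kv_indptr_cpu]
```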
@@ -320,22 +333,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
-                attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
+                attn_metadata.paged_kv_indices_cpu,
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
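
The rebasing above is the subtle step: `prefill_wrapper.run()` is later called on `query[num_decode_tokens:]`, so the prefill slice of the indptr must restart at zero, while the paged-kv arrays only need slicing. A small sketch with assumed token counts (decodes first, matching the code):

```python
import torch

# hypothetical batch: 2 decode requests (1 query token each) followed
# by 2 prefill requests (5 query tokens each)
qo_indptr_cpu = torch.tensor([0, 1, 2, 7, 12], dtype=torch.int32)
num_decodes = 2
prefill_start = num_decodes

# subtract the offset of the first prefill so the indptr restarts at 0
qo_indptr_prefill = (qo_indptr_cpu[prefill_start:] -
                     qo_indptr_cpu[prefill_start])
print(qo_indptr_prefill)   # tensor([ 0,  5, 10], dtype=torch.int32)

# the decode wrapper, by contrast, takes plain leading slices such as
# paged_kv_indptr_cpu[:num_decodes + 1] with no rebasing needed
```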
@@ -356,9 +369,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
-                    attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
+                    attn_metadata.paged_kv_indices_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -382,55 +395,62 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)
 
         page_size = self.kv_cache_spec.block_size
-        device = self.device
         qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
 
-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        # Build CPU versions directly from seq_lens_cpu
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
 
         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
+                                                       dtype=torch.int32,
+                                                       device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
-                < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        # Build CPU versions directly from CPU data
+        # paged_kv_indices_cpu: extract from block_table on CPU
+        mask_cpu = (torch.arange(block_table_tensor.size(1),
+                                 dtype=torch.int32,
+                                 device='cpu').unsqueeze(0)
+                    < block_table_bounds_cpu.unsqueeze(1))
+        paged_kv_indices_cpu = block_table_tensor.cpu()[mask_cpu]
+
+        # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
+        paged_kv_indptr_cpu = torch.cat([
+            torch.zeros(1, dtype=torch.int32, device='cpu'),
+            block_table_bounds_cpu.cumsum(dim=0, dtype=torch.int32)
         ])
 
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        # paged_kv_last_page_len_cpu: from seq_lens_cpu
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
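
One line in this hunk that is easy to misread is the `torch.where` fixup: `seq_len % page_size` is zero exactly when the final page is completely full, so the zero has to be mapped back to `page_size` rather than read as "zero entries in the last page". A quick check with assumed lengths:

```python
import torch

page_size = 16
seq_lens_cpu = torch.tensor([40, 32, 20], dtype=torch.int32)  # hypothetical

last_page_len = seq_lens_cpu % page_size                      # [8, 0, 4]
# 32 % 16 == 0 means request 1's last page is full, i.e. 16 entries
last_page_len = torch.where(last_page_len == 0, page_size,
                            last_page_len)                    # [8, 16, 4]
```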
@@ -440,9 +460,10 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
-            paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+            paged_kv_indices_cpu=paged_kv_indices_cpu,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -456,10 +477,10 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,