@@ -7,17 +7,18 @@
 from typing import TYPE_CHECKING, Any, Optional

 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
@@ -167,13 +168,13 @@ class FlashInferMetadata:
     # [0, 5, 8, 1, 6, 7, 3, 4]
     # paged_kv_indptr is used to index into paged_kv_indices:
     # [0, 3, 6, 8]
-    # The indptr of the paged kv cache, shape: [batch_size + 1]
-    paged_kv_indptr: torch.Tensor
-    # The page indices of the paged kv cache
+    # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
+    paged_kv_indptr_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
     paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
-    # the paged kv cache, shape: [batch_size]
-    paged_kv_last_page_len: torch.Tensor
+    # the paged kv cache, shape: [batch_size] (CPU for plan)
+    paged_kv_last_page_len_cpu: torch.Tensor
     # The number of query/output heads
     num_qo_heads: int
     # The number of key/value heads
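
A worked sketch (toy values, not from this diff) of how the CSR-style layout documented above is consumed: paged_kv_indptr_cpu delimits each request's slice of paged_kv_indices, and paged_kv_last_page_len_cpu records how full each request's final page is.

    import torch

    # Three requests owning 3, 3, and 2 pages respectively (page_size = 16).
    paged_kv_indices = torch.tensor([0, 5, 8, 1, 6, 7, 3, 4])
    paged_kv_indptr_cpu = torch.tensor([0, 3, 6, 8])
    paged_kv_last_page_len_cpu = torch.tensor([16, 4, 9])

    for i in range(paged_kv_indptr_cpu.numel() - 1):
        start, end = paged_kv_indptr_cpu[i], paged_kv_indptr_cpu[i + 1]
        pages = paged_kv_indices[start:end]
        # Every page is full except the last, which holds
        # paged_kv_last_page_len_cpu[i] entries.
        print(f"request {i}: pages {pages.tolist()}, "
              f"last page holds {int(paged_kv_last_page_len_cpu[i])} entries")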
@@ -201,17 +202,20 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int

-    # For cascade attention.
+    # For cascade attention (CPU for planning).
     use_cascade: bool
-    shared_qo_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indptr: Optional[torch.Tensor] = None
-    shared_kv_page_indices: Optional[torch.Tensor] = None
-    shared_kv_last_page_len: Optional[torch.Tensor] = None
+    shared_qo_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indptr_cpu: Optional[torch.Tensor] = None
+    shared_kv_page_indices_cpu: Optional[torch.Tensor] = None
+    shared_kv_last_page_len_cpu: Optional[torch.Tensor] = None

     prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

+    # CPU version for FlashInfer planning
+    qo_indptr_cpu: Optional[torch.Tensor] = None
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -238,6 +242,12 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
         self.vllm_config = vllm_config
         self.cache_config = vllm_config.cache_config
         self.kv_cache_spec = kv_cache_spec
+        max_num_blocks_per_request = cdiv(
+            vllm_config.model_config.max_model_len,
+            self.kv_cache_spec.block_size)
+        self.block_table_arange = torch.arange(max_num_blocks_per_request,
+                                               dtype=torch.int32,
+                                               device=self.device)

     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
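
The new __init__ lines precompute a device-side arange that build() later slices on every step instead of re-allocating torch.arange per batch. A minimal sketch of the idea, assuming cdiv has the usual ceiling-division semantics of vllm.utils.cdiv:

    import torch

    def cdiv(a: int, b: int) -> int:
        # Ceiling division without floats.
        return -(a // -b)

    max_model_len, block_size = 8192, 16
    max_num_blocks_per_request = cdiv(max_model_len, block_size)  # 512

    # Allocated once at init; later sliced as block_table_arange[:max_num_blocks]
    # rather than building a fresh arange on each scheduler step.
    block_table_arange = torch.arange(max_num_blocks_per_request,
                                      dtype=torch.int32)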
@@ -285,21 +295,31 @@ def _plan(self, num_prefills: int, num_decodes: int,
         if self.global_hyperparameters is None:
             self.global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(self.vllm_config, FlashInferImpl))
+
+        # Ensure CPU tensors are not None
+        assert attn_metadata.qo_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indptr_cpu is not None
+        assert attn_metadata.paged_kv_indices is not None
+        assert attn_metadata.paged_kv_last_page_len_cpu is not None
+
         if attn_metadata.use_cascade:
             attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
             attn_metadata.cascade_wrapper.plan(
-                [attn_metadata.shared_qo_indptr, attn_metadata.qo_indptr],
                 [
-                    attn_metadata.shared_kv_page_indptr,
-                    attn_metadata.paged_kv_indptr
+                    attn_metadata.shared_qo_indptr_cpu,
+                    attn_metadata.qo_indptr_cpu
+                ],
+                [
+                    attn_metadata.shared_kv_page_indptr_cpu,
+                    attn_metadata.paged_kv_indptr_cpu
                 ],
                 [
-                    attn_metadata.shared_kv_page_indices,
+                    attn_metadata.shared_kv_page_indices_cpu,
                     attn_metadata.paged_kv_indices
                 ],
                 [
-                    attn_metadata.shared_kv_last_page_len,
-                    attn_metadata.paged_kv_last_page_len
+                    attn_metadata.shared_kv_last_page_len_cpu,
+                    attn_metadata.paged_kv_last_page_len_cpu
                 ],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
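
In the cascade path, each positional argument to MultiLevelCascadeAttentionWrapper.plan() is a per-level list: level 0 describes the shared prefix as a single pseudo-request covering all tokens, level 1 the per-request remainders. An illustrative shape sketch with invented numbers (two requests, 10 tokens total, a 4-page shared prefix, page_size = 16):

    # [level 0, level 1] -- one entry per cascade level.
    qo_indptr = [[0, 10], [0, 6, 10]]            # all tokens vs. per-request splits
    kv_page_indptr = [[0, 4], [0, 2, 5]]         # shared pages vs. unique pages
    kv_page_indices = [[7, 8, 9, 10], [0, 3, 1, 4, 2]]
    kv_last_page_len = [[16], [5, 12]]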
@@ -320,22 +340,22 @@ def _plan(self, num_prefills: int, num_decodes: int,
             # Decodes are first so prefills start after the last decode
             prefill_start = num_decodes
             attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
-            assert attn_metadata.qo_indptr[prefill_start:].shape[
+            assert attn_metadata.qo_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
+            assert attn_metadata.paged_kv_indptr_cpu[prefill_start:].shape[
                 0] == num_prefills + 1
-            assert attn_metadata.paged_kv_last_page_len[
+            assert attn_metadata.paged_kv_last_page_len_cpu[
                 prefill_start:].shape[0] == num_prefills
             # Since prefill_wrapper.run() will be called with
             # query[num_decode_tokens:] we need to adjust the qo_indptr
             # to be relative to the start of the prefill queries.
-            qo_indptr = attn_metadata.qo_indptr[
-                prefill_start:] - attn_metadata.qo_indptr[prefill_start]
+            qo_indptr_cpu = attn_metadata.qo_indptr_cpu[
+                prefill_start:] - attn_metadata.qo_indptr_cpu[prefill_start]
             attn_metadata.prefill_wrapper.plan(
-                qo_indptr,
-                attn_metadata.paged_kv_indptr[prefill_start:],
+                qo_indptr_cpu,
+                attn_metadata.paged_kv_indptr_cpu[prefill_start:],
                 attn_metadata.paged_kv_indices,
-                attn_metadata.paged_kv_last_page_len[prefill_start:],
+                attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
                 attn_metadata.head_dim,
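
The qo_indptr rebasing above is a plain offset subtraction; with invented numbers (2 decodes of 1 token each, then prefills of 4 and 5 tokens):

    import torch

    qo_indptr_cpu = torch.tensor([0, 1, 2, 6, 11], dtype=torch.int32)
    prefill_start = 2  # num_decodes

    rebased = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[prefill_start]
    # tensor([0, 4, 9], dtype=torch.int32): offsets into query[num_decode_tokens:]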
@@ -356,9 +376,9 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
-                    attn_metadata.paged_kv_indptr[:num_decodes + 1],
+                    attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
-                    attn_metadata.paged_kv_last_page_len[:num_decodes],
+                    attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
                     attn_metadata.head_dim,
@@ -382,55 +402,62 @@ def build(self,
             split_decodes_and_prefills(common_attn_metadata)

         page_size = self.kv_cache_spec.block_size
-        device = self.device
         qo_indptr = common_attn_metadata.query_start_loc
         max_seq_len = common_attn_metadata.seq_lens_cpu.max()
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor

-        block_table_bounds = (seq_lens + page_size - 1) // page_size
+        # Build CPU versions directly from seq_lens_cpu
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size

         use_cascade = common_prefix_len > 0
         if use_cascade:
             # Grab the blocks of the shared prefix from the first request.
             assert common_prefix_len % page_size == 0
             num_common_kv_blocks = common_prefix_len // page_size
-            shared_qo_indptr = torch.tensor([0, num_actual_tokens],
-                                            dtype=torch.int32,
-                                            device=device)
-            shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
-                                                 dtype=torch.int32,
-                                                 device=device)
-            shared_kv_page_indices = block_table_tensor[
+
+            # Create CPU versions directly for cascade (no GPU versions needed)
+            shared_qo_indptr_cpu = torch.tensor([0, num_actual_tokens],
+                                                dtype=torch.int32,
+                                                device='cpu')
+            shared_kv_page_indptr_cpu = torch.tensor([0, num_common_kv_blocks],
+                                                     dtype=torch.int32,
+                                                     device='cpu')
+            shared_kv_page_indices_cpu = block_table_tensor[
                 0, :num_common_kv_blocks]
-            shared_kv_last_page_len = torch.tensor([page_size],
-                                                   dtype=torch.int32,
-                                                   device=device)
+            shared_kv_last_page_len_cpu = torch.tensor([page_size],
                                                        dtype=torch.int32,
                                                        device='cpu')
+
             # Remove the blocks of the shared prefix from all requests.
             block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
-            block_table_bounds -= num_common_kv_blocks
+            block_table_bounds_cpu -= num_common_kv_blocks
         else:
-            shared_qo_indptr = None
-            shared_kv_page_indptr = None
-            shared_kv_page_indices = None
-            shared_kv_last_page_len = None
-
-        mask = (torch.arange(block_table_tensor.size(1),
-                             dtype=block_table_tensor.dtype,
-                             device=block_table_tensor.device).unsqueeze(0)
+            shared_qo_indptr_cpu = None
+            shared_kv_page_indptr_cpu = None
+            shared_kv_page_indices_cpu = None
+            shared_kv_last_page_len_cpu = None
+
+        max_num_blocks = block_table_bounds_cpu.max()
+        block_table_bounds = block_table_bounds_cpu.to(self.device,
+                                                       non_blocking=True)
+        mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table_tensor[mask]
-
-        paged_kv_indptr = torch.cat([
-            torch.zeros(1,
-                        dtype=block_table_bounds.dtype,
-                        device=block_table_bounds.device),
-            block_table_bounds.cumsum(dim=0, dtype=torch.int32)
-        ])
-
-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
+        paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
+
+        # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
+        paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
+                                          dtype=torch.int32,
+                                          device='cpu')
+        paged_kv_indptr_cpu[1:] = block_table_bounds_cpu.cumsum(
+            dim=0, dtype=torch.int32)
+
+        # paged_kv_last_page_len_cpu: from seq_lens_cpu
+        paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
+        paged_kv_last_page_len_cpu = torch.where(
+            paged_kv_last_page_len_cpu == 0, page_size,
+            paged_kv_last_page_len_cpu)
         cache_dtype = self.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
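
End to end, the rewritten build() gathers the ragged block table by broadcasting the cached arange against per-request page counts, then derives the indptr and last-page lengths on CPU. A self-contained sketch with toy values (-1 marks unused block-table slots):

    import torch

    page_size = 16
    seq_lens_cpu = torch.tensor([40, 33, 16])
    block_table = torch.tensor([[0, 5, 8, -1],
                                [1, 6, 7, -1],
                                [3, -1, -1, -1]], dtype=torch.int32)

    bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size   # [3, 3, 1]
    max_num_blocks = bounds_cpu.max()                          # 3

    arange = torch.arange(block_table.size(1), dtype=torch.int32)
    mask = arange[:max_num_blocks].unsqueeze(0) < bounds_cpu.unsqueeze(1)
    paged_kv_indices = block_table[:, :max_num_blocks][mask]
    # tensor([0, 5, 8, 1, 6, 7, 3], dtype=torch.int32)

    paged_kv_indptr_cpu = torch.zeros(len(bounds_cpu) + 1, dtype=torch.int32)
    paged_kv_indptr_cpu[1:] = bounds_cpu.cumsum(dim=0, dtype=torch.int32)
    # tensor([0, 3, 6, 7], dtype=torch.int32)

    last = seq_lens_cpu % page_size
    paged_kv_last_page_len_cpu = torch.where(last == 0, page_size, last)
    # tensor([ 8,  1, 16])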
@@ -440,9 +467,10 @@ def build(self,
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
-            paged_kv_indptr=paged_kv_indptr,
+            qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
+            paged_kv_indptr_cpu=paged_kv_indptr_cpu,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_kv_last_page_len,
+            paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),
             num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -456,10 +484,10 @@ def build(self,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             use_cascade=use_cascade,
-            shared_qo_indptr=shared_qo_indptr,
-            shared_kv_page_indptr=shared_kv_page_indptr,
-            shared_kv_page_indices=shared_kv_page_indices,
-            shared_kv_last_page_len=shared_kv_last_page_len,
+            shared_qo_indptr_cpu=shared_qo_indptr_cpu,
+            shared_kv_page_indptr_cpu=shared_kv_page_indptr_cpu,
+            shared_kv_page_indices_cpu=shared_kv_page_indices_cpu,
+            shared_kv_last_page_len_cpu=shared_kv_last_page_len_cpu,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,