@@ -7,11 +7,12 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-
-import vllm.envs as envs
+from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
@@ -33,6 +34,8 @@
 
 logger = init_logger(__name__)
 
+CUDNN_SUPPORTED_HEAD_SIZES = [128]
+
 
 class FlashInferBackend(AttentionBackend):
 
@@ -202,6 +205,12 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
+    # For cudnn prefill
+    max_query_len: int
+    max_seq_len: int
+    actual_seq_lens_q: torch.Tensor
+    actual_seq_lens_kv: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     shared_qo_indptr: Optional[torch.Tensor] = None
@@ -213,6 +222,12 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    cudnn_workspace: Optional[torch.Tensor] = None
+    block_table: Optional[torch.Tensor] = None
+
+    def _is_cudnn_supported(self):
+        return self.head_dim in CUDNN_SUPPORTED_HEAD_SIZES and envs.VLLM_USE_CUDNN_PREFILL
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
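Note (illustrative, not part of the diff): the _is_cudnn_supported() gate above is what the rest of the change keys off; the cuDNN prefill path is only eligible when the head size appears in CUDNN_SUPPORTED_HEAD_SIZES (currently just 128) and VLLM_USE_CUDNN_PREFILL is enabled. A minimal sketch of turning the flag on, assuming vllm.envs reads it from an environment variable of the same name (the usual vllm.envs convention):

import os

# Assumption: envs.VLLM_USE_CUDNN_PREFILL mirrors the VLLM_USE_CUDNN_PREFILL
# environment variable. Set it before vLLM is initialized.
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"

# Mirrors the check in FlashInferMetadata._is_cudnn_supported(): both
# conditions must hold for the cuDNN prefill branch to run.
CUDNN_SUPPORTED_HEAD_SIZES = [128]
head_dim = 128
print(head_dim in CUDNN_SUPPORTED_HEAD_SIZES
      and os.environ.get("VLLM_USE_CUDNN_PREFILL") == "1")  # True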
@@ -367,7 +382,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if self._num_prefills > 0:
+            if self._num_prefills > 0 and not attn_metadata._is_cudnn_supported(
+            ):
                 # Decodes are first so prefills start after the last decode
                 prefill_start = self._num_decodes
                 attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -433,6 +449,7 @@ def build(self, common_prefix_len: int,
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        max_query_len = common_attn_metadata.max_query_len
         slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
@@ -463,6 +480,7 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
+        max_seq_len = int(seq_lens.max().item())
         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
                              device=block_table_tensor.device).unsqueeze(0)
@@ -479,7 +497,7 @@ def build(self, common_prefix_len: int,
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
-
+        self._get_workspace_buffer()
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
@@ -502,6 +520,12 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indptr=shared_kv_page_indptr,
             shared_kv_page_indices=shared_kv_page_indices,
             shared_kv_last_page_len=shared_kv_last_page_len,
+            max_query_len=max_query_len,
+            max_seq_len=max_seq_len,
+            actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
+            actual_seq_lens_kv=seq_lens.to(self.runner.device),
+            block_table=block_table_tensor,
+            cudnn_workspace=self._workspace_buffer.to(torch.int8),
         )
 
         self._plan(attn_metadata)
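Note (illustrative, not from the PR): actual_seq_lens_q is derived from qo_indptr, which holds cumulative query-token offsets per request, so adjacent differences give per-request query lengths; the cuDNN prefill call later consumes them reshaped to (batch, 1, 1, 1). A tiny sketch with made-up numbers:

import torch

# Three requests with 3, 4 and 5 query tokens: qo_indptr is the running sum.
qo_indptr = torch.tensor([0, 3, 7, 12])
actual_seq_lens_q = qo_indptr[1:] - qo_indptr[:-1]   # tensor([3, 4, 5])
print(actual_seq_lens_q.view(-1, 1, 1, 1).shape)     # torch.Size([3, 1, 1, 1])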
@@ -653,13 +677,48 @@ def forward(
             assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                         or 0.0)
             assert prefill_wrapper._sm_scale == self.scale
+
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
             )
+        elif num_prefill_tokens > 0 and attn_metadata._is_cudnn_supported():
+            (total_num_pages, _, page_size, num_kv_heads,
+             head_dim) = kv_cache.shape
+            k_cache = kv_cache[:, 0].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            v_cache = kv_cache[:, 1].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
+                q=query[num_decode_tokens:],
+                k_cache=k_cache,
+                v_cache=v_cache,
+                scale=self.scale,
+                workspace_buffer=attn_metadata.cudnn_workspace,
+                max_token_per_sequence=attn_metadata.max_query_len,
+                max_sequence_kv=attn_metadata.max_seq_len,
+                block_tables=attn_metadata.block_table[num_decode_tokens:],
+                actual_seq_lens_q=attn_metadata.
+                actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1),
+                actual_seq_lens_kv=attn_metadata.
+                actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1),
+                causal=True,
+                return_lse=True,
+                is_cuda_graph_compatible=True,
+            )
 
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
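Note on the slicing in the new branch (an observation about the surrounding code, not part of the diff): reorder_batch() places decode requests ahead of prefills, and in this backend a decode request contributes exactly one token, so num_decode_tokens also equals the number of decode requests. That is why the same offset can slice both per-token tensors (query[num_decode_tokens:]) and per-request tensors (attn_metadata.block_table[num_decode_tokens:], actual_seq_lens_q[num_decode_tokens:]). A small sketch with hypothetical sizes:

import torch

# Hypothetical batch: 2 decode requests (1 token each) followed by 1 prefill
# request with 5 tokens, matching the decode-first ordering of reorder_batch().
num_decode_tokens = 2                                # == number of decode requests
query = torch.randn(2 + 5, 4, 128)                   # (tokens, num_heads, head_dim)
block_table = torch.zeros(3, 8, dtype=torch.int32)   # one row per request

prefill_query = query[num_decode_tokens:]            # 5 prefill token rows
prefill_blocks = block_table[num_decode_tokens:]     # 1 prefill request row
print(prefill_query.shape)   # torch.Size([5, 4, 128])
print(prefill_blocks.shape)  # torch.Size([1, 8])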