7
7
from typing import TYPE_CHECKING , Any , Optional
8
8
9
9
import torch
10
+
11
+ import vllm .envs as envs
10
12
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper ,
11
13
BatchPrefillWithPagedKVCacheWrapper ,
12
14
MultiLevelCascadeAttentionWrapper )
13
-
14
- import vllm .envs as envs
15
+ from flashinfer .prefill import cudnn_batch_prefill_with_kv_cache
15
16
from vllm .attention .backends .abstract import (AttentionBackend , AttentionImpl ,
16
17
AttentionType )
17
18
from vllm .attention .layer import Attention
18
19
from vllm .config import VllmConfig , get_layers_from_vllm_config
19
20
from vllm .logger import init_logger
21
+ from vllm .platforms import current_platform
20
22
from vllm .v1 .attention .backends .flash_attn import use_cascade_attention
21
23
from vllm .v1 .attention .backends .utils import (AttentionMetadataBuilder ,
22
24
CommonAttentionMetadata ,
33
35
34
36
logger = init_logger (__name__ )
35
37
38
+ CUDNN_SUPPORTED_HEAD_SIZES = [128 ]
39
+
40
+
41
+ def is_cudnn_supported (head_dim : int ):
42
+ return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \
43
+ and current_platform .has_device_capability (100 )
44
+
36
45
37
46
class FlashInferBackend (AttentionBackend ):
38
47
@@ -202,6 +211,12 @@ class FlashInferMetadata:
202
211
num_prefills : int
203
212
num_prefill_tokens : int
204
213
214
+ # For cudnn prefill
215
+ max_query_len : int
216
+ max_seq_len : int
217
+ actual_seq_lens_q : torch .Tensor
218
+ actual_seq_lens_kv : torch .Tensor
219
+
205
220
# For cascade attention.
206
221
use_cascade : bool
207
222
shared_qo_indptr : Optional [torch .Tensor ] = None
@@ -213,6 +228,9 @@ class FlashInferMetadata:
213
228
decode_wrapper : Optional [BatchDecodeWithPagedKVCacheWrapper ] = None
214
229
cascade_wrapper : Optional [MultiLevelCascadeAttentionWrapper ] = None
215
230
231
+ cudnn_workspace : Optional [torch .Tensor ] = None
232
+ block_table : Optional [torch .Tensor ] = None
233
+
216
234
@property
217
235
def query_start_loc (self ):
218
236
# The GPUModelRunner expects to be able to access this property.
@@ -301,9 +319,13 @@ def reorder_batch(self, input_batch: InputBatch,
301
319
302
320
def _get_workspace_buffer (self ):
303
321
if self ._workspace_buffer is None :
322
+ if is_cudnn_supported (self .kv_cache_spec .head_size ):
323
+ dtype = torch .int8
324
+ else :
325
+ dtype = torch .uint8
304
326
self ._workspace_buffer = torch .empty (
305
327
FLASHINFER_WORKSPACE_BUFFER_SIZE ,
306
- dtype = torch . uint8 ,
328
+ dtype = dtype ,
307
329
device = self .runner .device )
308
330
return self ._workspace_buffer
309
331
@@ -367,7 +389,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
367
389
# Regular attention (common case).
368
390
# Decodes are at the front and prefills are at the back,
369
391
# according to reorder_batch()
370
- if self ._num_prefills > 0 :
392
+ if self ._num_prefills > 0 and not is_cudnn_supported (
393
+ attn_metadata .head_dim ):
371
394
# Decodes are first so prefills start after the last decode
372
395
prefill_start = self ._num_decodes
373
396
attn_metadata .prefill_wrapper = self ._get_prefill_wrapper ()
@@ -433,6 +456,7 @@ def build(self, common_prefix_len: int,
433
456
qo_indptr = common_attn_metadata .query_start_loc
434
457
seq_lens = common_attn_metadata .seq_lens
435
458
block_table_tensor = self .block_table .get_device_tensor ()[:num_reqs ]
459
+ max_query_len = common_attn_metadata .max_query_len
436
460
slot_mapping = self .block_table .slot_mapping_cpu [:num_actual_tokens ].to (
437
461
self .runner .device , non_blocking = True ).long ()
438
462
@@ -463,6 +487,7 @@ def build(self, common_prefix_len: int,
463
487
shared_kv_page_indices = None
464
488
shared_kv_last_page_len = None
465
489
490
+ max_seq_len = int (seq_lens .max ().item ())
466
491
mask = (torch .arange (block_table_tensor .size (1 ),
467
492
dtype = block_table_tensor .dtype ,
468
493
device = block_table_tensor .device ).unsqueeze (0 )
@@ -480,6 +505,10 @@ def build(self, common_prefix_len: int,
480
505
paged_kv_last_page_len = torch .where (paged_kv_last_page_len == 0 ,
481
506
page_size , paged_kv_last_page_len )
482
507
508
+ if is_cudnn_supported (self .kv_cache_spec .head_size ):
509
+ self ._get_workspace_buffer ()
510
+ assert self ._workspace_buffer is not None , "workspace_buffer is not set"
511
+
483
512
attn_metadata = FlashInferMetadata (
484
513
num_actual_tokens = num_actual_tokens ,
485
514
qo_indptr = qo_indptr ,
@@ -502,7 +531,13 @@ def build(self, common_prefix_len: int,
502
531
shared_kv_page_indptr = shared_kv_page_indptr ,
503
532
shared_kv_page_indices = shared_kv_page_indices ,
504
533
shared_kv_last_page_len = shared_kv_last_page_len ,
505
- )
534
+ max_query_len = max_query_len ,
535
+ max_seq_len = max_seq_len ,
536
+ actual_seq_lens_q = qo_indptr [1 :] - qo_indptr [:- 1 ],
537
+ actual_seq_lens_kv = seq_lens .to (self .runner .device ),
538
+ block_table = block_table_tensor ,
539
+ cudnn_workspace = self ._workspace_buffer
540
+ if is_cudnn_supported (self .kv_cache_spec .head_size ) else None )
506
541
507
542
self ._plan (attn_metadata )
508
543
@@ -653,13 +688,55 @@ def forward(
653
688
assert prefill_wrapper ._logits_soft_cap == (self .logits_soft_cap
654
689
or 0.0 )
655
690
assert prefill_wrapper ._sm_scale == self .scale
691
+
656
692
prefill_wrapper .run (
657
693
prefill_query ,
658
694
kv_cache .permute (* stride_order ),
659
695
k_scale = layer ._k_scale_float ,
660
696
v_scale = layer ._v_scale_float ,
661
697
out = output [num_decode_tokens :],
662
698
)
699
+ elif num_prefill_tokens > 0 and FlashInferBackend .is_cudnn_supported (
700
+ attn_metadata .head_dim ):
701
+ (total_num_pages , _ , page_size , num_kv_heads ,
702
+ head_dim ) = kv_cache .shape
703
+
704
+ # Validate dimensions match expected head_dim
705
+ assert head_dim == self .head_size , (
706
+ f"KV cache head_dim { head_dim } != expected { self .head_size } " )
707
+ assert attn_metadata .block_table is not None , \
708
+ "block_table is not set"
709
+ k_cache = kv_cache [:, 0 ].as_strided (
710
+ (total_num_pages , num_kv_heads , page_size , head_dim ), (
711
+ page_size * num_kv_heads * head_dim ,
712
+ head_dim ,
713
+ num_kv_heads * head_dim ,
714
+ 1 ,
715
+ ))
716
+ v_cache = kv_cache [:, 1 ].as_strided (
717
+ (total_num_pages , num_kv_heads , page_size , head_dim ), (
718
+ page_size * num_kv_heads * head_dim ,
719
+ head_dim ,
720
+ num_kv_heads * head_dim ,
721
+ 1 ,
722
+ ))
723
+ output [num_decode_tokens :], _ = cudnn_batch_prefill_with_kv_cache (
724
+ q = query [num_decode_tokens :],
725
+ k_cache = k_cache ,
726
+ v_cache = v_cache ,
727
+ scale = self .scale ,
728
+ workspace_buffer = attn_metadata .cudnn_workspace ,
729
+ max_token_per_sequence = attn_metadata .max_query_len ,
730
+ max_sequence_kv = attn_metadata .max_seq_len ,
731
+ block_tables = attn_metadata .block_table [num_decode_tokens :],
732
+ actual_seq_lens_q = attn_metadata .
733
+ actual_seq_lens_q [num_decode_tokens :].view (- 1 , 1 , 1 , 1 ),
734
+ actual_seq_lens_kv = attn_metadata .
735
+ actual_seq_lens_kv [num_decode_tokens :].view (- 1 , 1 , 1 , 1 ),
736
+ causal = True ,
737
+ return_lse = True ,
738
+ is_cuda_graph_compatible = True ,
739
+ )
663
740
664
741
if decode_wrapper := attn_metadata .decode_wrapper :
665
742
decode_query = query [:num_decode_tokens ]
0 commit comments