@@ -194,6 +194,7 @@

 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
@@ -225,6 +226,8 @@

 try:
     from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+    from flashinfer.prefill import (  # noqa: F401
+        cudnn_batch_prefill_with_kv_cache)
     flashinfer_available = True
 except ImportError:
     flashinfer_available = False
@@ -236,6 +239,8 @@

 logger = init_logger(__name__)

+CUDNN_WORKSPACE_SIZE = 12800
+

 class MLACommonBackend(AttentionBackend):

@@ -294,6 +299,7 @@ class ChunkedContextMetadata:
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
+        seq_lens: torch.Tensor
         workspace: torch.Tensor

     block_table: torch.Tensor
@@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
                                                  default_factory=list)


+@dataclass
+class CudnnPrefillMetadata(MLACommonPrefillMetadata):
+
+    class ChunkedContextMetadata(
+            MLACommonPrefillMetadata.ChunkedContextMetadata):
+        seq_lens: torch.Tensor
+
+    query_seq_lens: Optional[torch.Tensor] = None
+    cudnn_workspace: Optional[torch.Tensor] = None
+
+
 @dataclass
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
@@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):

     decode: Optional[D] = None
     prefill: Optional[Union[MLACommonPrefillMetadata,
-                            FlashInferPrefillMetadata]] = None
+                            FlashInferPrefillMetadata,
+                            CudnnPrefillMetadata]] = None

     def __post_init__(self):
         if self.head_dim is not None:
@@ -362,13 +380,19 @@ def __post_init__(self):


 def use_flashinfer_prefill() -> bool:
-    if flashinfer_available:
+    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
         # On Blackwell, default to FlashInfer prefill if it's available,
         # since it's faster than FA2.
         return current_platform.has_device_capability(100)
     return False


+def use_cudnn_prefill() -> bool:
+    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
+        return current_platform.has_device_capability(100)
+    return False
+
+
 # Currently 394MB; this can be tuned based on the GEMM sizes used.
 # Chosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
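Note on the dispatch above: the cuDNN and FlashInfer prefill paths are mutually exclusive, and both require a successful FlashInfer import plus a Blackwell (compute capability 10.0) GPU; VLLM_USE_CUDNN_PREFILL picks between them. A minimal standalone sketch of the decision logic (illustrative only, not part of the patch):

def select_prefill_backend(flashinfer_available: bool,
                           use_cudnn_env: bool,
                           is_sm100: bool) -> str:
    # Mirrors use_flashinfer_prefill()/use_cudnn_prefill(): both paths
    # need FlashInfer importable and an SM100 (Blackwell) device; the
    # VLLM_USE_CUDNN_PREFILL env var selects between them.
    if flashinfer_available and is_sm100:
        return "cudnn" if use_cudnn_env else "flashinfer"
    return "flash_attn"  # FA2 fallback used everywhere else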
@@ -427,11 +451,15 @@ def __init__(self,
             dtype=model_config.dtype,
             device=runner.device,
         )
+
         self.block_table = block_table

+        self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
-        self.prefill_metadata_cls = FlashInferPrefillMetadata \
-            if self._use_fi_prefill else MLACommonPrefillMetadata
+        self.prefill_metadata_cls = (
+            FlashInferPrefillMetadata
+            if self._use_fi_prefill else CudnnPrefillMetadata
+            if self._use_cudnn_prefill else MLACommonPrefillMetadata)

         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
@@ -447,6 +475,13 @@ def __init__(self,
             self._global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(runner.vllm_config, MLACommonImpl))

+        if self._use_cudnn_prefill:
+            self.cudnn_workspace = torch.empty(
+                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+                dtype=torch.int8,
+                device=runner.device,
+            )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
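A quick sanity check on the cuDNN workspace sizing above. CUDNN_WORKSPACE_SIZE is the constant from this patch; max_num_seqs=256 is just an assumed scheduler value for illustration:

import torch

CUDNN_WORKSPACE_SIZE = 12800  # bytes per scheduled sequence (from this patch)
max_num_seqs = 256            # assumed scheduler_config.max_num_seqs

workspace = torch.empty(CUDNN_WORKSPACE_SIZE * max_num_seqs, dtype=torch.int8)
print(f"{workspace.numel() / 2**20:.2f} MiB")  # 3.12 MiB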
@@ -692,15 +727,24 @@ def build(self, common_prefix_len: int,
                          out=cu_seq_lens_cpu[:, 1:],
                          dtype=torch.int32)

+            chunked_context_metadata_cls = \
+                CudnnPrefillMetadata.ChunkedContextMetadata \
+                if self._use_cudnn_prefill else \
+                MLACommonPrefillMetadata.ChunkedContextMetadata
+
             chunked_context_metadata = \
-                MLACommonPrefillMetadata.ChunkedContextMetadata(
+                chunked_context_metadata_cls(
                     cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )

+            if self._use_cudnn_prefill:
+                chunked_context_metadata.seq_lens = chunk_seq_lens
+
             assert max(chunked_context_metadata.max_seq_lens) <= \
                 self.chunked_prefill_workspace_size

@@ -711,6 +755,12 @@ def build(self, common_prefix_len: int,
             chunked_context=chunked_context_metadata,
         )

+        if self._use_cudnn_prefill:
+            assert isinstance(prefill_metadata, CudnnPrefillMetadata)
+            prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
+                - prefill_query_start_loc[:-1]
+            prefill_metadata.cudnn_workspace = self.cudnn_workspace
+
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
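The query_seq_lens assignment above is a first difference over the cumulative start offsets; a tiny self-contained illustration (values made up):

import torch

# Start offsets for 3 prefill sequences of lengths 3, 4 and 5.
prefill_query_start_loc = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
assert query_seq_lens.tolist() == [3, 4, 5]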
@@ -794,6 +844,12 @@ def __init__(
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_cudnn_prefill():
+            logger.debug_once("Using CUDNN prefill for MLA")
+            self._run_prefill_context_chunk = \
+                self._run_prefill_context_chunk_cudnn
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
+            self._pad_v = False
         else:  # Use FlashAttention
             logger.debug_once("Using FlashAttention prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
@@ -882,6 +938,29 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
             return_lse=return_softmax_lse,
         )

+    def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
+                                      q, k, v, return_softmax_lse):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.query_seq_lens is not None
+        output, lse = cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.max_query_len,
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            causal=True,
+            return_lse=True,  # False is not supported yet
+            is_cuda_graph_compatible=True,  # actual_seq_lens stay on the GPU
+        )
+        if return_softmax_lse:
+            return output, lse
+        return output
+
     def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
                                       chunk_idx: int, q, k, v):
         assert prefill.chunked_context is not None
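Because the cuDNN kernel always returns the log-sum-exp (return_lse=True), the caller can merge the causal new-token output with the non-causal context-chunk outputs below. A hedged sketch of the standard LSE merge, assuming (tokens, heads, dim) outputs and matching (tokens, heads) LSEs (vLLM's actual merge lives elsewhere in the MLA common code):

import torch

def merge_attn_partials(o1, lse1, o2, lse2):
    # Numerically stable combination of two partial softmax-attention
    # results computed over disjoint KV ranges.
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m).unsqueeze(-1)
    w2 = torch.exp(lse2 - m).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)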
@@ -908,6 +987,30 @@ def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
             return_lse=True,
         )

+    def _run_prefill_context_chunk_cudnn(self,
+                                         prefill: MLACommonPrefillMetadata,
+                                         chunk_idx: int, q, k, v):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+        assert prefill.query_seq_lens is not None
+        return cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
+                -1, 1, 1, 1),
+            causal=False,  # context tokens all precede the new queries
+            return_lse=True,
+            is_cuda_graph_compatible=True,  # actual_seq_lens stay on the GPU
+        )
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
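On the .view(-1, 1, 1, 1) calls in both cuDNN entry points: the kernel takes per-sequence lengths with explicit singleton dims rather than as a flat vector. A two-line illustration with made-up lengths:

import torch

seq_lens = torch.tensor([128, 64, 256], dtype=torch.int32)
print(seq_lens.view(-1, 1, 1, 1).shape)  # torch.Size([3, 1, 1, 1])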