import torch

+import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                              AttentionMetadata,
@@ ... @@
logger = init_logger(__name__)

+CUDNN_SUPPORTED_HEAD_DIMS = [192, 128]
+CUDNN_WORKSPACE_SIZE = 12800
+

class MLACommonBackend(AttentionBackend):

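The two new module-level constants gate the cuDNN path: CUDNN_SUPPORTED_HEAD_DIMS whitelists the head dimensions the kernel accepts, and CUDNN_WORKSPACE_SIZE is the per-sequence scratch size in bytes used to size the workspace buffer added below. A minimal sketch of the head-dim guard, mirroring the `all(t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v))` check used later in this patch; the helper name and tensor shapes are illustrative, not part of the change:

import torch

CUDNN_SUPPORTED_HEAD_DIMS = [192, 128]

def cudnn_prefill_supported(q: torch.Tensor, k: torch.Tensor,
                            v: torch.Tensor) -> bool:
    # Hypothetical helper: the same predicate as the inline guard in the patch.
    return all(t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v))

# DeepSeek-style MLA prefill uses 192-dim q/k (128 "nope" + 64 rope) and a
# 128-dim v, which is why exactly these two head dims are whitelisted.
q = torch.empty(16, 8, 192)
k = torch.empty(16, 8, 192)
v = torch.empty(16, 8, 128)
assert cudnn_prefill_supported(q, k, v)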
@@ -282,11 +286,14 @@ class ChunkedContextMetadata:
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
+        seq_lens: torch.Tensor
        workspace: torch.Tensor

    block_table: torch.Tensor
    query_start_loc: torch.Tensor
+    query_seq_lens: torch.Tensor
    max_query_len: int
+    workspace: torch.Tensor
    chunked_context: Optional[ChunkedContextMetadata] = None


@@ -390,6 +397,12 @@ def __init__(self,
                dtype=model_config.dtype,
                device=runner.device,
            )
+        self.workspace = torch.empty(
+            CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+            dtype=torch.int8,
+            device=runner.device,
+        )
+
        self.block_table = block_table

    def reorder_batch(self, input_batch: "InputBatch",
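The new builder-level `self.workspace` is a flat torch.int8 scratch buffer handed to the cuDNN kernel, sized at CUDNN_WORKSPACE_SIZE bytes per schedulable sequence. A quick illustration of the resulting allocation; the max_num_seqs value is an assumed example, not taken from the patch:

CUDNN_WORKSPACE_SIZE = 12800   # bytes reserved per sequence (torch.int8 elements)
max_num_seqs = 256             # e.g. scheduler_config.max_num_seqs

workspace_bytes = CUDNN_WORKSPACE_SIZE * max_num_seqs
print(workspace_bytes)         # 3276800 bytes, i.e. ~3.1 MiB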
@@ -566,6 +579,7 @@ def build(self, common_prefix_len: int,
                    starts=chunk_starts.to(device, non_blocking=True),
                    seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    seq_lens=chunk_seq_lens,
                    workspace=self.chunked_prefill_workspace,
                )

@@ -576,6 +590,9 @@ def build(self, common_prefix_len: int,
            block_table=block_table_tensor[reqs_start:, ...],
            query_start_loc=prefill_query_start_loc,
            max_query_len=max_query_len,
+            workspace=self.workspace,
+            query_seq_lens=prefill_query_start_loc[1:] -
+            prefill_query_start_loc[:-1],
            chunked_context=chunked_context_metadata,
        )

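`query_seq_lens` is derived from the cumulative `prefill_query_start_loc` offsets by a first difference, giving the number of query tokens per prefill request. A small sketch with illustrative values:

import torch

prefill_query_start_loc = torch.tensor([0, 3, 8, 12])  # cumulative starts, 3 requests
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
print(query_seq_lens)  # tensor([3, 5, 4])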
@@ -663,9 +680,10 @@ def __init__(
        # v with 0s to match the qk head dim for attention backends that do
        # not support different headdims
        # We don't need to pad V if we are on a hopper system with FA3
-        self._pad_v = self.vllm_flash_attn_version is None or not (
-            self.vllm_flash_attn_version == 3
-            and current_platform.get_device_capability()[0] == 9)
+        self._pad_v = not envs.VLLM_USE_CUDNN_PREFILL and (
+            self.vllm_flash_attn_version is None
+            or not (self.vllm_flash_attn_version == 3
+                    and current_platform.get_device_capability()[0] == 9))

    def _flash_attn_varlen_diff_headdims(self,
                                         q,
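The reworked `_pad_v` condition means v is left unpadded whenever the cuDNN prefill path is enabled, and otherwise only when FlashAttention 3 runs on a Hopper (SM 9.x) device, as before. A minimal sketch of that decision as a standalone predicate; `use_cudnn_prefill`, `fa_version`, and `sm_major` are illustrative stand-ins for `envs.VLLM_USE_CUDNN_PREFILL`, `self.vllm_flash_attn_version`, and `current_platform.get_device_capability()[0]`:

def needs_v_padding(use_cudnn_prefill: bool, fa_version, sm_major: int) -> bool:
    # cuDNN prefill accepts q/k and v with different head dims, so skip padding.
    if use_cudnn_prefill:
        return False
    # Otherwise pad unless FlashAttention 3 is running on Hopper (SM 9.x).
    return fa_version is None or not (fa_version == 3 and sm_major == 9)

assert needs_v_padding(False, 2, 9)        # FA2 on Hopper: pad
assert not needs_v_padding(False, 3, 9)    # FA3 on Hopper: no padding
assert not needs_v_padding(True, None, 8)  # cuDNN prefill: no padding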
@@ -705,6 +723,40 @@ def _flash_attn_varlen_diff_headdims(self,
            return attn_out, lse
        return attn_out

+    def _cudnn_varlen_func_diff_headdims(
+        self,
+        q,
+        k,
+        v,
+        scale,
+        workspace,
+        max_q_seq_lens,
+        max_kv_seq_lens,
+        seq_lens_q,
+        seq_lens_kv,
+        causal,
+        is_cuda_graph_compatible=True,
+    ):
+        # Varlen prefill via the cuDNN-backed FlashInfer kernel, which accepts
+        # q/k and v with different head dims (see self._pad_v above).
+        from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+        if not is_cuda_graph_compatible:
+            seq_lens_q = seq_lens_q.to("cpu")
+            seq_lens_kv = seq_lens_kv.to("cpu")
+        return cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=scale,
+            workspace_buffer=workspace,
+            max_token_per_sequence=max_q_seq_lens,
+            max_sequence_kv=max_kv_seq_lens,
+            actual_seq_lens_q=seq_lens_q.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=seq_lens_kv.view(-1, 1, 1, 1),
+            causal=causal,
+            return_lse=True,
+            is_cuda_graph_compatible=is_cuda_graph_compatible,
+        )
+
    def _v_up_proj(self, x):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
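The cuDNN entry point expects the per-request sequence lengths as [batch, 1, 1, 1] tensors (`actual_seq_lens_q` / `actual_seq_lens_kv`), which is why both the call sites and the helper reshape them with `.view(-1, 1, 1, 1)`. Illustrative values:

import torch

seq_lens_q = torch.tensor([3, 5, 4], dtype=torch.int32)  # flat [batch] lengths
actual_seq_lens_q = seq_lens_q.view(-1, 1, 1, 1)
print(actual_seq_lens_q.shape)  # torch.Size([3, 1, 1, 1])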
@@ -803,19 +855,41 @@ def _compute_prefill_context(
            k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                          dim=-1)

-            attn_output, attn_softmax_lse = \
-                self._flash_attn_varlen_diff_headdims(
-                    q=q,
-                    k=k,
-                    v=v,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
+            if envs.VLLM_USE_CUDNN_PREFILL and all(
+                    t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS
+                    for t in (q, k, v)):
+                attn_output, attn_softmax_lse = (
+                    self._cudnn_varlen_func_diff_headdims(
+                        q,
+                        k,
+                        v,
+                        scale=self.scale,
+                        workspace=prefill_metadata.workspace,
+                        max_q_seq_lens=prefill_metadata.max_query_len,
+                        max_kv_seq_lens=prefill_metadata.chunked_context.
+                        max_seq_lens[i],
+                        seq_lens_q=prefill_metadata.query_seq_lens.view(
+                            -1, 1, 1, 1),
+                        seq_lens_kv=prefill_metadata.chunked_context.
+                        seq_lens[i].view(-1, 1, 1, 1),
+                        causal=False,
+                        is_cuda_graph_compatible=True,
+                        # actual_seq_lens stay on the GPU in this mode.
+                    ))
+            else:
+                attn_output, attn_softmax_lse = \
+                    self._flash_attn_varlen_diff_headdims(
+                        q=q,
+                        k=k,
+                        v=v,
+                        cu_seqlens_q=prefill_metadata.query_start_loc,
+                        cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
+                        max_seqlen_q=prefill_metadata.max_query_len,
+                        max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
+                        softmax_scale=self.scale,
+                        causal=False,  # Context is unmasked
+                        return_softmax_lse=True,
+                    )

            if output is None:
                output = attn_output
@@ -854,18 +928,39 @@ def _forward_prefill(

        k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
-            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
-            max_seqlen_q=attn_metadata.prefill.max_query_len,
-            max_seqlen_k=attn_metadata.prefill.max_query_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if envs.VLLM_USE_CUDNN_PREFILL and all(
+                t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v)):
+            output = self._cudnn_varlen_func_diff_headdims(
+                q,
+                k,
+                v,
+                scale=self.scale,
+                workspace=attn_metadata.prefill.workspace,
+                max_q_seq_lens=attn_metadata.prefill.max_query_len,
+                max_kv_seq_lens=attn_metadata.prefill.max_query_len,
+                seq_lens_q=attn_metadata.prefill.query_seq_lens.view(
+                    -1, 1, 1, 1),
+                seq_lens_kv=attn_metadata.prefill.query_seq_lens.view(
+                    -1, 1, 1, 1),
+                causal=True,
+                is_cuda_graph_compatible=True,
+                # actual_seq_lens stay on the GPU in this mode.
+            )
+            if not has_context:
+                output = output[0]
+        else:
+            output = self._flash_attn_varlen_diff_headdims(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+                cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+                max_seqlen_q=attn_metadata.prefill.max_query_len,
+                max_seqlen_k=attn_metadata.prefill.max_query_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=has_context,
+            )

        if has_context:
            suffix_output, suffix_lse = output