                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.attention.backends.flashinfer import (
+    get_per_layer_parameters, infer_global_hyperparameters)
+# yapf: enable
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+from flashinfer.utils import is_sm100a_supported
+
 logger = init_logger(__name__)
@@ -278,6 +286,12 @@ class ChunkedContextMetadata:
     chunked_context: Optional[ChunkedContextMetadata] = None
 
 
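+# FlashInfer prefill wrappers planned for the current batch: one causal
+# wrapper for the main (new-token) run and one non-causal wrapper per
+# chunked-context chunk.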
+@dataclass
+class FlashInferPrefillMetadata:
+    prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper]
+    prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper]
+
+
 @dataclass
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
@@ -317,6 +331,7 @@ class MLACommonMetadata(Generic[D]):
 
     decode: Optional[D] = None
     prefill: Optional[MLACommonPrefillMetadata] = None
+    fi_prefill: Optional[FlashInferPrefillMetadata] = None
 
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -330,6 +345,43 @@ def __post_init__(self):
 M = TypeVar("M", bound=MLACommonMetadata)
 
 
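+# FlashInfer's ragged prefill (cutlass backend) is only used on SM100a
+# (Blackwell) GPUs; other GPUs keep the FlashAttention varlen prefill path.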
+def use_flashinfer_prefill() -> bool:
+    return is_sm100a_supported(torch.device("cuda"))
+
+
+# Currently 394 MB; this can be tuned based on the GEMM sizes used.
+FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
+
+
+class FlashInferPrefill:
+
+    def __init__(self, runner):
+        self._device = runner.device
+        self._workspace_buffer = None
+        self._global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(runner.vllm_config))
+
+    def get_global_hyperparameters(self):
+        return self._global_hyperparameters
+
+    def _get_workspace_buffer(self) -> torch.Tensor:
+        # Note that this maintains a single workspace buffer that is reused
+        # for all prefill executions.
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self._device)
+        return self._workspace_buffer
+
+    def get_ragged_prefill(self) -> BatchPrefillWithRaggedKVCacheWrapper:
+        # Notes:
+        # 1. The kv_layout used is NHD.
+        # 2. Force the "cutlass" backend, which runs NVIDIA's new B200 kernel.
+        return BatchPrefillWithRaggedKVCacheWrapper(
+            self._get_workspace_buffer(), "NHD", backend="cutlass")
+
+
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@@ -384,6 +436,106 @@ def __init__(self,
         )
         self.block_table = block_table
 
+        self._use_fi_prefill = use_flashinfer_prefill()
+
+        if self._use_fi_prefill:
+            self._fi_prefill = FlashInferPrefill(self.runner)
+            self._fi_prefill_main: Optional[
+                BatchPrefillWithRaggedKVCacheWrapper] = None
+            self._fi_prefill_chunks: list[
+                BatchPrefillWithRaggedKVCacheWrapper] = []
+
+    def _get_fi_prefill_main(self) -> BatchPrefillWithRaggedKVCacheWrapper:
+        if self._fi_prefill_main is None:
+            self._fi_prefill_main = self._fi_prefill.get_ragged_prefill()
+
+        return self._fi_prefill_main
+
+    def _get_fi_prefill_chunks(
+            self, num_chunks) -> list[BatchPrefillWithRaggedKVCacheWrapper]:
+        if len(self._fi_prefill_chunks) < num_chunks:
+            for _ in range(len(self._fi_prefill_chunks), num_chunks):
+                self._fi_prefill_chunks.append(
+                    self._fi_prefill.get_ragged_prefill())
+
+        return self._fi_prefill_chunks
+
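+    # Plan the FlashInfer prefill wrappers for this batch: a causal plan for
+    # the main (new-token) attention and a non-causal plan per context chunk.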
+    def _build_fi_prefill(self, common_attn_metadata: CommonAttentionMetadata,
+                          attn_metadata: MLACommonMetadata):
+        assert attn_metadata.prefill is not None
+        qo_indptr = attn_metadata.prefill.query_start_loc
+
+        has_context = False
+        if attn_metadata.prefill.chunked_context is not None:
+            chunked_context = attn_metadata.prefill.chunked_context
+            has_context = True
+
+        prefill_main = self._get_fi_prefill_main()
+
+        prefill_chunks = []
+        if has_context:
+            num_chunks = chunked_context.cu_seq_lens.shape[0]
+            prefill_chunks = self._get_fi_prefill_chunks(num_chunks)
+            assert len(prefill_chunks) == num_chunks
+
+        # In MLA, the non-latent num_qo_heads == num_kv_heads
+        num_qo_heads = self.runner.num_query_heads
+        num_kv_heads = num_qo_heads
+
+        # Sanity check: num_kv_heads == 1 since the KV cache is in latent space
+        assert self.kv_cache_spec.num_kv_heads == 1
+
+        # Get the non-latent head_dim_qk and head_dim_vo
+        head_dim_qk = (self.mla_dims.qk_nope_head_dim +
+                       self.mla_dims.qk_rope_head_dim)
+        head_dim_vo = self.mla_dims.v_head_dim
+
+        global_hyperparameters = self._fi_prefill.get_global_hyperparameters()
+
+        # For the main run, qo_indptr == kv_indptr
+        kv_indptr = qo_indptr.clone()
+
+        # Prepare the main prefill
+        prefill_main.plan(
+            qo_indptr=qo_indptr,
+            kv_indptr=kv_indptr,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim_qk,
+            head_dim_vo=head_dim_vo,
+            causal=True,  # This is the main run
+            sm_scale=global_hyperparameters.sm_scale,
+            window_left=global_hyperparameters.window_left,
+            logits_soft_cap=global_hyperparameters.logits_soft_cap,
+            q_data_type=self.runner.dtype,
+            kv_data_type=self.kv_cache_spec.dtype,
+        )
+
+        # Prepare the context prefills
+        if has_context:
+            for i in range(num_chunks):
+                kv_indptr_chunk = chunked_context.cu_seq_lens[i]
+
+                prefill_chunks[i].plan(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=kv_indptr_chunk,
+                    num_qo_heads=num_qo_heads,
+                    num_kv_heads=num_kv_heads,
+                    head_dim_qk=head_dim_qk,
+                    head_dim_vo=head_dim_vo,
+                    causal=False,  # This is a context run
+                    sm_scale=global_hyperparameters.sm_scale,
+                    window_left=global_hyperparameters.window_left,
+                    logits_soft_cap=global_hyperparameters.logits_soft_cap,
+                    q_data_type=self.runner.dtype,
+                    kv_data_type=self.kv_cache_spec.dtype,
+                )
+
+        attn_metadata.fi_prefill = FlashInferPrefillMetadata(
+            prefill_main=prefill_main,
+            prefill_chunks=prefill_chunks,
+        )
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are and
@@ -578,7 +730,7 @@ def build(self, common_prefix_len: int,
                 seq_lens=seq_lens[:self._num_decodes],
             )
 
-        return self.metadata_cls(
+        attn_metadata = self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
@@ -591,6 +743,11 @@ def build(self, common_prefix_len: int,
             decode=decode_metadata,
         )
 
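+        # Plan the FlashInfer prefill wrappers once the base metadata is built.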
+        if self._use_fi_prefill and self._num_prefills > 0:
+            self._build_fi_prefill(common_attn_metadata, attn_metadata)
+
+        return attn_metadata
+
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         return common_attn_metadata.max_query_len == 1
@@ -660,6 +817,20 @@ def __init__(
             self.vllm_flash_attn_version == 3
             and current_platform.get_device_capability()[0] == 9)
 
+        # Determine if FlashInfer prefill is used
+        self._use_fi_prefill = use_flashinfer_prefill()
+        if self._use_fi_prefill:
+            # Do not use v padding when FlashInfer prefill is enabled.
+            self._pad_v = False
+
+        # Hyperparameters for the layers
+        if sliding_window is None:
+            self.sliding_window = (-1, -1)
+        else:
+            self.sliding_window = (sliding_window - 1, 0)
+
+        self.logits_soft_cap = logits_soft_cap
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
@@ -692,6 +863,27 @@ def _flash_attn_varlen_diff_headdims(self,
             return attn_out, lse
         return attn_out
 
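+    # Runs a planned BatchPrefillWithRaggedKVCacheWrapper and mirrors the
+    # return convention of _flash_attn_varlen_diff_headdims.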
+    def _run_fi_prefill(self, prefill_wrapper, q, k, v, return_softmax_lse):
+        assert not self._pad_v
+
+        attn_out = prefill_wrapper.run(
+            q,
+            k,
+            v,
+            return_lse=return_softmax_lse,
+        )
+
+        # Unpack the output if there are multiple results
+        lse = None
+        if isinstance(attn_out, tuple):
+            attn_out, lse = attn_out[0], attn_out[1]
+
+        # Remain consistent with the old `flash_attn_varlen_func`, where there
+        # is only one output tensor when `return_softmax_lse` is False.
+        if return_softmax_lse:
+            return attn_out, lse
+        return attn_out
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -790,19 +982,32 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            attn_output, attn_softmax_lse = \
-                self._flash_attn_varlen_diff_headdims(
-                    q=q,
-                    k=k,
-                    v=v,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
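+            # Context chunks are attended with the non-causal wrappers planned
+            # in _build_fi_prefill(); otherwise fall back to FA varlen.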
+            if self._use_fi_prefill:
+                assert attn_metadata.fi_prefill is not None
+
+                attn_output, attn_softmax_lse = self._run_fi_prefill(
+                    prefill_wrapper=attn_metadata.fi_prefill.prefill_chunks[i],
+                    q=q,
+                    k=k,
+                    v=v,
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse = \
+                    self._flash_attn_varlen_diff_headdims(
+                        q=q,
+                        k=k,
+                        v=v,
+                        cu_seqlens_q=prefill_metadata.query_start_loc,
+                        cu_seqlens_k=prefill_metadata.chunked_context.
+                        cu_seq_lens[i],
+                        max_seqlen_q=prefill_metadata.max_query_len,
+                        max_seqlen_k=prefill_metadata.chunked_context.
+                        max_seq_lens[i],
+                        softmax_scale=self.scale,
+                        causal=False,  # Context is unmasked
+                        return_softmax_lse=True,
+                    )
 
             if output is None:
                 output = attn_output
@@ -841,18 +1046,36 @@ def _forward_prefill(
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
-            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
-            max_seqlen_q=attn_metadata.prefill.max_query_len,
-            max_seqlen_k=attn_metadata.prefill.max_query_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
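+        # Main causal prefill over the new tokens: use the planned FlashInfer
+        # wrapper when enabled, otherwise the FlashAttention varlen path.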
+        if self._use_fi_prefill:
+            assert attn_metadata.fi_prefill is not None
+
+            output = self._run_fi_prefill(
+                prefill_wrapper=attn_metadata.fi_prefill.prefill_main,
+                q=q,
+                k=k,
+                v=v,
+                return_softmax_lse=has_context,
+            )
+        else:
+            output = self._flash_attn_varlen_diff_headdims(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+                cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+                max_seqlen_q=attn_metadata.prefill.max_query_len,
+                max_seqlen_k=attn_metadata.prefill.max_query_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=has_context,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
@@ -895,7 +1118,6 @@ def forward(
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         assert output is not None, "Output tensor must be provided."
 
         if output_scale is not None: