
Commit e52c4f1

MLA FlashInfer Ragged Prefill Support
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent 5e5baa9 commit e52c4f1

File tree

3 files changed: +250 -28 lines changed

vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/cutlass_mla.py


vllm/v1/attention/backends/flashinfer.py

Lines changed: 0 additions & 1 deletion

@@ -91,7 +91,6 @@ def get_per_layer_parameters(

     for key, layer in layers.items():
         impl = layer.impl
-        assert isinstance(impl, FlashInferImpl)

         # Infer hyperparameters from the attention layer
         window_size = impl.sliding_window
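Context for this one-line removal (my reading, not stated in the commit): the mla/common.py diff below imports get_per_layer_parameters and infer_global_hyperparameters into the MLA backend, where layer.impl is an MLA implementation rather than a FlashInferImpl, so the isinstance check would now fire. The helper only needs duck-typed access to the attributes it reads, such as impl.sliding_window. A minimal sketch of that duck-typing, using hypothetical stand-in classes that are not vLLM types:

# Hypothetical stand-ins, not vLLM classes: any impl exposing the attributes
# the helper actually reads (e.g. sliding_window) can flow through the loop.
class _FlashInferLikeImpl:
    sliding_window = (-1, -1)


class _MLALikeImpl:
    sliding_window = None


for impl in (_FlashInferLikeImpl(), _MLALikeImpl()):
    # Mirrors the surviving context line above; no isinstance check is
    # needed for this attribute access.
    window_size = impl.sliding_window
    print(window_size)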

vllm/v1/attention/backends/mla/common.py

Lines changed: 249 additions & 27 deletions
@@ -207,6 +207,11 @@
                                                UnquantizedLinearMethod)
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.v1.attention.backends.flashinfer import (
+    get_per_layer_parameters, infer_global_hyperparameters)
+# yapf: enable
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -225,6 +230,9 @@
     from vllm.v1.worker.gpu_input_batch import InputBatch
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

+from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+from flashinfer.utils import is_sm100a_supported
+
 logger = init_logger(__name__)


@@ -278,6 +286,12 @@ class ChunkedContextMetadata:
     chunked_context: Optional[ChunkedContextMetadata] = None


+@dataclass
+class FlashInferPrefillMetadata:
+    prefill_main: Optional[BatchPrefillWithRaggedKVCacheWrapper]
+    prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper]
+
+
 @dataclass
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
@@ -317,6 +331,7 @@ class MLACommonMetadata(Generic[D]):

     decode: Optional[D] = None
     prefill: Optional[MLACommonPrefillMetadata] = None
+    fi_prefill: Optional[FlashInferPrefillMetadata] = None

     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -330,6 +345,43 @@ def __post_init__(self):
 M = TypeVar("M", bound=MLACommonMetadata)


+def use_flashinfer_prefill() -> bool:
+    return is_sm100a_supported(torch.device("cuda"))
+
+
+# Currently 394MB; this can be tuned based on the GEMM sizes used.
+FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024
+
+
+class FlashInferPrefill:
+
+    def __init__(self, runner):
+        self._device = runner.device
+        self._workspace_buffer = None
+        self._global_hyperparameters = infer_global_hyperparameters(
+            get_per_layer_parameters(runner.vllm_config))
+
+    def get_global_hyperparameters(self):
+        return self._global_hyperparameters
+
+    def _get_workspace_buffer(self) -> torch.Tensor:
+        # Note that this maintains a single workspace buffer that is reused
+        # for all prefill executions.
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self._device)
+        return self._workspace_buffer
+
+    def get_ragged_prefill(self) -> BatchPrefillWithRaggedKVCacheWrapper:
+        # Notes:
+        # 1. The kv_layout used is NHD.
+        # 2. Force the "cutlass" backend, which runs NVIDIA's new B200 kernel.
+        return BatchPrefillWithRaggedKVCacheWrapper(
+            self._get_workspace_buffer(), "NHD", backend="cutlass")
+
+
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
@@ -384,6 +436,106 @@ def __init__(self,
         )
         self.block_table = block_table

+        self._use_fi_prefill = use_flashinfer_prefill()
+
+        if self._use_fi_prefill:
+            self._fi_prefill = FlashInferPrefill(self.runner)
+            self._fi_prefill_main: Optional[
+                BatchPrefillWithRaggedKVCacheWrapper] = None
+            self._fi_prefill_chunks: list[
+                BatchPrefillWithRaggedKVCacheWrapper] = []
+
+    def _get_fi_prefill_main(self) -> BatchPrefillWithRaggedKVCacheWrapper:
+        if self._fi_prefill_main is None:
+            self._fi_prefill_main = self._fi_prefill.get_ragged_prefill()
+
+        return self._fi_prefill_main
+
+    def _get_fi_prefill_chunks(
+            self, num_chunks) -> list[BatchPrefillWithRaggedKVCacheWrapper]:
+        if len(self._fi_prefill_chunks) < num_chunks:
+            for _ in range(len(self._fi_prefill_chunks), num_chunks):
+                self._fi_prefill_chunks.append(
+                    self._fi_prefill.get_ragged_prefill())
+
+        return self._fi_prefill_chunks
+
+    def _build_fi_prefill(self, common_attn_metadata: CommonAttentionMetadata,
+                          attn_metadata: MLACommonMetadata):
+        assert attn_metadata.prefill is not None
+        qo_indptr = attn_metadata.prefill.query_start_loc
+
+        has_context = False
+        if attn_metadata.prefill.chunked_context is not None:
+            chunked_context = attn_metadata.prefill.chunked_context
+            has_context = True
+
+        prefill_main = self._get_fi_prefill_main()
+
+        prefill_chunks = []
+        if has_context:
+            num_chunks = chunked_context.cu_seq_lens.shape[0]
+            prefill_chunks = self._get_fi_prefill_chunks(num_chunks)
+            assert len(prefill_chunks) == num_chunks
+
+        # In MLA, the non-latent num_qo_heads == num_kv_heads
+        num_qo_heads = self.runner.num_query_heads
+        num_kv_heads = num_qo_heads
+
+        # Sanity check: the latent KV cache has num_kv_heads == 1
+        assert self.kv_cache_spec.num_kv_heads == 1
+
+        # Get non-latent head_dim_qk and head_dim_vo
+        head_dim_qk = (self.mla_dims.qk_nope_head_dim +
+                       self.mla_dims.qk_rope_head_dim)
+        head_dim_vo = self.mla_dims.v_head_dim
+
+        global_hyperparameters = self._fi_prefill.get_global_hyperparameters()
+
+        # For the main run, qo_indptr == kv_indptr
+        kv_indptr = qo_indptr.clone()
+
+        # Prepare main prefill
+        prefill_main.plan(
+            qo_indptr=qo_indptr,
+            kv_indptr=kv_indptr,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim_qk,
+            head_dim_vo=head_dim_vo,
+            causal=True,  # This is the main run
+            sm_scale=global_hyperparameters.sm_scale,
+            window_left=global_hyperparameters.window_left,
+            logits_soft_cap=global_hyperparameters.logits_soft_cap,
+            q_data_type=self.runner.dtype,
+            kv_data_type=self.kv_cache_spec.dtype,
+        )
+
+        # Prepare context prefills
+        if has_context:
+            for i in range(num_chunks):
+                kv_indptr_chunk = chunked_context.cu_seq_lens[i]
+
+                prefill_chunks[i].plan(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=kv_indptr_chunk,
+                    num_qo_heads=num_qo_heads,
+                    num_kv_heads=num_kv_heads,
+                    head_dim_qk=head_dim_qk,
+                    head_dim_vo=head_dim_vo,
+                    causal=False,  # This is a context run
+                    sm_scale=global_hyperparameters.sm_scale,
+                    window_left=global_hyperparameters.window_left,
+                    logits_soft_cap=global_hyperparameters.logits_soft_cap,
+                    q_data_type=self.runner.dtype,
+                    kv_data_type=self.kv_cache_spec.dtype,
+                )
+
+        attn_metadata.fi_prefill = FlashInferPrefillMetadata(
+            prefill_main=prefill_main,
+            prefill_chunks=prefill_chunks,
+        )
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are and
@@ -578,7 +730,7 @@ def build(self, common_prefix_len: int,
                 seq_lens=seq_lens[:self._num_decodes],
             )

-        return self.metadata_cls(
+        attn_metadata = self.metadata_cls(
             num_actual_tokens=num_actual_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
@@ -591,6 +743,11 @@ def build(self, common_prefix_len: int,
             decode=decode_metadata,
         )

+        if self._use_fi_prefill and self._num_prefills > 0:
+            self._build_fi_prefill(common_attn_metadata, attn_metadata)
+
+        return attn_metadata
+
     def can_run_in_cudagraph(
             self, common_attn_metadata: CommonAttentionMetadata) -> bool:
         return common_attn_metadata.max_query_len == 1
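The planning in _build_fi_prefill above is entirely indptr-based: the batch stays ragged (no padding), the main plan attends each request's new tokens to themselves causally, and each context chunk gets its own non-causal plan whose kv_indptr comes from chunked_context.cu_seq_lens[i]. A standalone sketch of what those tensors look like for a toy batch; the sizes are made up, and only the construction pattern follows the code above:

import torch

# Toy batch: two prefill requests with 3 and 5 new (query) tokens.
query_lens = torch.tensor([3, 5], dtype=torch.int32)
qo_indptr = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0)      # tensor([0, 3, 8])

# Main (causal) run: keys/values are the same new tokens, so the KV layout
# is just a copy of the query layout, as in `kv_indptr = qo_indptr.clone()`.
kv_indptr_main = qo_indptr.clone()

# One context chunk: say the same two requests have 16 and 7 cached context
# tokens in this chunk; `chunked_context.cu_seq_lens[i]` plays this role in
# the real builder.
context_lens = torch.tensor([16, 7], dtype=torch.int32)
kv_indptr_chunk0 = torch.zeros(context_lens.numel() + 1, dtype=torch.int32)
kv_indptr_chunk0[1:] = torch.cumsum(context_lens, dim=0)  # tensor([0, 16, 23])

print(qo_indptr.tolist(), kv_indptr_main.tolist(), kv_indptr_chunk0.tolist())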
@@ -660,6 +817,20 @@ def __init__(
                 self.vllm_flash_attn_version == 3
                 and current_platform.get_device_capability()[0] == 9)

+        # Determine if FlashInfer prefill is used
+        self._use_fi_prefill = use_flashinfer_prefill()
+        if self._use_fi_prefill:
+            # Do not use v padding when FlashInfer prefill is enabled.
+            self._pad_v = False
+
+        # Hyperparameters for the layer
+        if sliding_window is None:
+            self.sliding_window = (-1, -1)
+        else:
+            self.sliding_window = (sliding_window - 1, 0)
+
+        self.logits_soft_cap = logits_soft_cap
+
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
                                          k,
@@ -692,6 +863,27 @@ def _flash_attn_varlen_diff_headdims(self,
             return attn_out, lse
         return attn_out

+    def _run_fi_prefill(self, prefill_wrapper, q, k, v, return_softmax_lse):
+        assert not self._pad_v
+
+        attn_out = prefill_wrapper.run(
+            q,
+            k,
+            v,
+            return_lse=return_softmax_lse,
+        )
+
+        # Unpack the output if there are multiple results
+        lse = None
+        if isinstance(attn_out, tuple):
+            attn_out, lse = attn_out[0], attn_out[1]
+
+        # Remain consistent with the old `flash_attn_varlen_func`, where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            return attn_out, lse
+        return attn_out
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -790,19 +982,32 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)

-            attn_output, attn_softmax_lse = \
-                self._flash_attn_varlen_diff_headdims(
-                q=q,
-                k=k,
-                v=v,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if self._use_fi_prefill:
+                assert attn_metadata.fi_prefill is not None
+
+                attn_output, attn_softmax_lse = self._run_fi_prefill(
+                    prefill_wrapper=attn_metadata.fi_prefill.prefill_chunks[i],
+                    q=q,
+                    k=k,
+                    v=v,
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse = \
+                    self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.chunked_context.
+                    cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.chunked_context.
+                    max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )

             if output is None:
                 output = attn_output
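Both branches request the log-sum-exp (return_softmax_lse=True) because each context chunk only sees part of the keys; the partial outputs are combined with the causal suffix output by code outside this hunk. For readers wondering why the LSE has to be carried along, here is the standard log-sum-exp merge of partial attention results. This is an illustrative sketch, not the vLLM merge kernel; the shapes, the all-unmasked reference attention, and the helper name merge_partial_attn are mine:

import torch


def merge_partial_attn(o1, lse1, o2, lse2):
    # o*: [tokens, heads, dim], lse*: [tokens, heads], natural-log LSE assumed.
    lse = torch.logaddexp(lse1, lse2)
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse


def ref_attn(q, k, v):
    # Unmasked attention over one ragged KV segment (scale omitted for
    # brevity; the merge identity holds as long as it is applied consistently).
    s = torch.einsum("thd,khd->thk", q, k)
    o = torch.einsum("thk,khd->thd", torch.softmax(s, dim=-1), v)
    return o, torch.logsumexp(s, dim=-1)


torch.manual_seed(0)
q = torch.randn(4, 2, 8)                              # [tokens, heads, dim]
k1, v1 = torch.randn(6, 2, 8), torch.randn(6, 2, 8)   # "context" partial
k2, v2 = torch.randn(3, 2, 8), torch.randn(3, 2, 8)   # "suffix" partial

o1, lse1 = ref_attn(q, k1, v1)
o2, lse2 = ref_attn(q, k2, v2)
o_merged, _ = merge_partial_attn(o1, lse1, o2, lse2)
o_full, _ = ref_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(o_merged, o_full, atol=1e-4)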
@@ -841,18 +1046,36 @@ def _forward_prefill(

         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

-        output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
-            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
-            max_seqlen_q=attn_metadata.prefill.max_query_len,
-            max_seqlen_k=attn_metadata.prefill.max_query_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        # print("_forward_prefill")
+        # print(" q.shape = {}".format(q.shape))
+        # print(" k.shape = {}".format(k.shape))
+        # print(" v.shape = {}".format(v.shape))
+        # print(" has_context = {}".format(has_context))
+        # print(" use_fi_prefill = {}".format(self._use_fi_prefill))
+
+        if self._use_fi_prefill:
+            assert attn_metadata.fi_prefill is not None
+
+            output = self._run_fi_prefill(
+                prefill_wrapper=attn_metadata.fi_prefill.prefill_main,
+                q=q,
+                k=k,
+                v=v,
+                return_softmax_lse=has_context,
+            )
+        else:
+            output = self._flash_attn_varlen_diff_headdims(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+                cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+                max_seqlen_q=attn_metadata.prefill.max_query_len,
+                max_seqlen_k=attn_metadata.prefill.max_query_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=has_context,
+            )

         if has_context:
             suffix_output, suffix_lse = output
@@ -895,7 +1118,6 @@ def forward(
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         assert output is not None, "Output tensor must be provided."

         if output_scale is not None:
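Putting the pieces of this file together, the new prefill path boils down to: allocate one workspace buffer, create a BatchPrefillWithRaggedKVCacheWrapper per use (one main, one per context chunk), plan() it with the ragged indptr layouts and MLA head dims, then run() it on ragged q/k/v. The sketch below strings those calls together outside vLLM. The argument names come from the diff, but the toy sizes, tensor shapes, dtypes, and the assumption of an SM100a (B200-class) GPU with FlashInfer's cutlass backend installed are mine, and it has not been validated against a specific FlashInfer release:

import torch
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper

device = torch.device("cuda")
workspace = torch.empty(394 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD",
                                               backend="cutlass")

# Two requests with 3 and 5 query tokens; self-attention over the new tokens
# only, so kv_indptr mirrors qo_indptr exactly as in the "main" plan above.
qo_indptr = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)
num_heads, head_dim_qk, head_dim_vo = 16, 192, 128  # made-up MLA-like dims

wrapper.plan(
    qo_indptr=qo_indptr,
    kv_indptr=qo_indptr.clone(),
    num_qo_heads=num_heads,
    num_kv_heads=num_heads,
    head_dim_qk=head_dim_qk,
    head_dim_vo=head_dim_vo,
    causal=True,
    sm_scale=head_dim_qk**-0.5,
    window_left=-1,
    logits_soft_cap=0.0,
    q_data_type=torch.bfloat16,
    kv_data_type=torch.bfloat16,
)

# Ragged NHD layout: [total_tokens, num_heads, head_dim].
q = torch.randn(8, num_heads, head_dim_qk, dtype=torch.bfloat16, device=device)
k = torch.randn(8, num_heads, head_dim_qk, dtype=torch.bfloat16, device=device)
v = torch.randn(8, num_heads, head_dim_vo, dtype=torch.bfloat16, device=device)
out, lse = wrapper.run(q, k, v, return_lse=True)  # out: [8, heads, head_dim_vo]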

vllm/v1/attention/backends/mla/cutlass_mla.py

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ def _forward_decode(
         # Clone q_nope and q_pe to make sure strides computation is correct.
         q_nope = q_nope.clone()
         q_pe = q_pe.clone()
+
         ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                                attn_metadata.decode.seq_lens,
                                attn_metadata.decode.block_table, self.scale)
