
Commit 30800b0

elfiegg and Elfie Guo authored
[Nvidia] Integrate SM100 cudnn prefill API to MLA prefill (#20411)
Signed-off-by: Elfie Guo <elfieg@nvidia.com> Co-authored-by: Elfie Guo <eflieg@nvidia.com>
1 parent 10be209 commit 30800b0

File tree

vllm/envs.py
vllm/v1/attention/backends/mla/common.py

2 files changed: +113, -5 lines

vllm/envs.py

(file mode changed: 100644 → 100755)
Lines changed: 5 additions & 0 deletions
@@ -139,6 +139,7 @@
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
+    VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_LOOPBACK_IP: str = ""
 
 
@@ -962,6 +963,10 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
     lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
 
+    # Controls whether or not to use cudnn prefill
+    "VLLM_USE_CUDNN_PREFILL":
+    lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
+
     # If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
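Note on the new flag: it is parsed with bool(int(...)), so only integer-literal values are accepted; "1" enables the cuDNN prefill path and "0" (the default) disables it. A minimal, self-contained sketch of that parsing (illustration only, not part of the commit):

import os

# Hypothetical illustration of the env-var parsing added above.
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"
enabled = bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")))
print(enabled)  # True; leaving the variable unset (or "0") yields False

Non-integer strings such as "true" would raise ValueError under this scheme.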

vllm/v1/attention/backends/mla/common.py

(file mode changed: 100644 → 100755)
Lines changed: 108 additions & 5 deletions
@@ -194,6 +194,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
@@ -225,6 +226,8 @@
 
 try:
     from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
+    from flashinfer.prefill import (  # noqa: F401
+        cudnn_batch_prefill_with_kv_cache)
     flashinfer_available = True
 except ImportError:
     flashinfer_available = False
@@ -236,6 +239,8 @@
 
 logger = init_logger(__name__)
 
+CUDNN_WORKSPACE_SIZE = 12800
+
 
 class MLACommonBackend(AttentionBackend):
 
@@ -294,6 +299,7 @@ class ChunkedContextMetadata:
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
+        seq_lens: torch.Tensor
         workspace: torch.Tensor
 
     block_table: torch.Tensor
@@ -309,6 +315,17 @@ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
         default_factory=list)
 
 
+@dataclass
+class CudnnPrefillMetadata(MLACommonPrefillMetadata):
+
+    class ChunkedContextMetadata(
+            MLACommonPrefillMetadata.ChunkedContextMetadata):
+        seq_lens: torch.Tensor
+
+    query_seq_lens: Optional[torch.Tensor] = None
+    cudnn_workspace: Optional[torch.Tensor] = None
+
+
 @dataclass
 class MLACommonDecodeMetadata:
     block_table: torch.Tensor
@@ -351,7 +368,8 @@ class MLACommonMetadata(Generic[D]):
 
     decode: Optional[D] = None
     prefill: Optional[Union[MLACommonPrefillMetadata,
-                            FlashInferPrefillMetadata]] = None
+                            FlashInferPrefillMetadata,
+                            CudnnPrefillMetadata]] = None
 
     def __post_init__(self):
         if self.head_dim is not None:
@@ -362,13 +380,19 @@ def __post_init__(self):
 
 
 def use_flashinfer_prefill() -> bool:
-    if flashinfer_available:
+    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
         # For blackwell default to flashinfer prefill if its available since
         # its faster than FA2.
         return current_platform.has_device_capability(100)
     return False
 
 
+def use_cudnn_prefill() -> bool:
+    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
+        return current_platform.has_device_capability(100)
+    return False
+
+
 # Currently 394MB, this can be tuned based on GEMM sizes used.
 # Choosen to be the same as sglang:
 # https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
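Read together, the two helpers above are mutually exclusive: both require FlashInfer to be importable and a device with compute capability 10.0 or newer (SM100/Blackwell), and VLLM_USE_CUDNN_PREFILL then decides which path wins; everything else falls through to the FlashAttention prefill. A rough sketch of that decision, using illustrative stand-in booleans rather than the module's real checks:

# Illustrative-only summary of the prefill backend choice implied by
# use_flashinfer_prefill() / use_cudnn_prefill(); `flashinfer_available`
# and `is_sm100` stand in for the import check and
# current_platform.has_device_capability(100).
def select_prefill_backend(flashinfer_available: bool, is_sm100: bool,
                           use_cudnn_env: bool) -> str:
    if flashinfer_available and is_sm100:
        return "cudnn" if use_cudnn_env else "flashinfer"
    return "flash_attn"  # default FlashAttention prefill path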
@@ -427,11 +451,15 @@ def __init__(self,
             dtype=model_config.dtype,
             device=runner.device,
         )
+
         self.block_table = block_table
 
+        self._use_cudnn_prefill = use_cudnn_prefill()
         self._use_fi_prefill = use_flashinfer_prefill()
-        self.prefill_metadata_cls = FlashInferPrefillMetadata \
-            if self._use_fi_prefill else MLACommonPrefillMetadata
+        self.prefill_metadata_cls = (
+            FlashInferPrefillMetadata
+            if self._use_fi_prefill else CudnnPrefillMetadata
+            if self._use_cudnn_prefill else MLACommonPrefillMetadata)
 
         if self._use_fi_prefill:
             self._workspace_buffer = torch.empty(
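The parenthesized prefill_metadata_cls assignment above chains two conditional expressions, which associate to the right: FlashInfer metadata if _use_fi_prefill, otherwise cuDNN metadata if _use_cudnn_prefill, otherwise the common class. A small stand-alone sketch of the same pattern, with strings standing in for the metadata classes:

# Same chained-conditional shape as the prefill_metadata_cls assignment above.
def pick_metadata(use_fi: bool, use_cudnn: bool) -> str:
    return ("flashinfer" if use_fi else "cudnn"
            if use_cudnn else "common")

assert pick_metadata(True, False) == "flashinfer"
assert pick_metadata(False, True) == "cudnn"
assert pick_metadata(False, False) == "common"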
@@ -447,6 +475,13 @@ def __init__(self,
             self._global_hyperparameters = infer_global_hyperparameters(
                 get_per_layer_parameters(runner.vllm_config, MLACommonImpl))
 
+        if self._use_cudnn_prefill:
+            self.cudnn_workspace = torch.empty(
+                CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+                dtype=torch.int8,
+                device=runner.device,
+            )
+
     def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
         qo_indptr = prefill.query_start_loc
 
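As a rough sanity check on the allocation above: the cuDNN workspace is CUDNN_WORKSPACE_SIZE int8 elements per sequence slot, so its footprint scales linearly with the scheduler's max_num_seqs. A back-of-envelope calculation (max_num_seqs=1024 is assumed purely for illustration):

# Footprint of the cuDNN workspace allocated above, under an assumed
# max_num_seqs; the 12800 constant is the one introduced in this commit.
CUDNN_WORKSPACE_SIZE = 12800            # int8 elements per sequence
max_num_seqs = 1024                     # hypothetical scheduler setting
workspace_bytes = CUDNN_WORKSPACE_SIZE * max_num_seqs  # int8 -> 1 byte each
print(workspace_bytes / 2**20)          # 12.5 (MiB)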
@@ -692,15 +727,24 @@ def build(self, common_prefix_len: int,
                              out=cu_seq_lens_cpu[:, 1:],
                              dtype=torch.int32)
 
+                chunked_context_metadata_cls = \
+                    CudnnPrefillMetadata.ChunkedContextMetadata \
+                    if self._use_cudnn_prefill else \
+                    MLACommonPrefillMetadata.ChunkedContextMetadata
+
                 chunked_context_metadata = \
-                    MLACommonPrefillMetadata.ChunkedContextMetadata(
+                    chunked_context_metadata_cls(
                         cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                         starts=chunk_starts.to(device, non_blocking=True),
                         seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        seq_lens=chunk_seq_lens,
                         workspace=self.chunked_prefill_workspace,
                     )
 
+                if self._use_cudnn_prefill:
+                    chunked_context_metadata.seq_lens = chunk_seq_lens
+
                 assert max(chunked_context_metadata.max_seq_lens) <= \
                     self.chunked_prefill_workspace_size
 
@@ -711,6 +755,12 @@ def build(self, common_prefix_len: int,
                 chunked_context=chunked_context_metadata,
             )
 
+        if self._use_cudnn_prefill:
+            assert isinstance(prefill_metadata, CudnnPrefillMetadata)
+            prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \
+                - prefill_query_start_loc[:-1]
+            prefill_metadata.cudnn_workspace = self.cudnn_workspace
+
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
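The query_seq_lens assignment above is just an adjacent difference over the prefill query_start_loc prefix sums. A tiny PyTorch sketch with made-up lengths (three prefill requests of 3, 4, and 5 tokens):

import torch

# Hypothetical query_start_loc for three prefill requests.
prefill_query_start_loc = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
print(query_seq_lens)  # tensor([3, 4, 5], dtype=torch.int32)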
@@ -794,6 +844,12 @@ def __init__(
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
             self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
             self._pad_v = False
+        elif use_cudnn_prefill():
+            logger.debug_once("Using CUDNN prefill for MLA")
+            self._run_prefill_context_chunk = \
+                self._run_prefill_context_chunk_cudnn
+            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
+            self._pad_v = False
         else:  # Use FlashAttention
             logger.debug_once("Using FlashAttention prefill for MLA")
             self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
@@ -882,6 +938,29 @@ def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q,
             return_lse=return_softmax_lse,
         )
 
+    def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
+                                      q, k, v, return_softmax_lse):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.query_seq_lens is not None
+        output, lse = cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.max_query_len,
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            causal=True,
+            return_lse=True,  # do not support False for now
+            is_cuda_graph_compatible=
+            True,  # Indicates actual_seq_lens are on GPU or CPU.
+        )
+        if return_softmax_lse:
+            return output, lse
+        return output
+
     def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata,
                                       chunk_idx: int, q, k, v):
         assert prefill.chunked_context is not None
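Both cuDNN call sites pass the per-request length tensors reshaped with view(-1, 1, 1, 1), i.e. one length per batch element in a (batch, 1, 1, 1) layout. A minimal sketch of just that reshape, with illustrative values:

import torch

# Illustrative lengths for a batch of three requests.
seq_lens = torch.tensor([3, 4, 5], dtype=torch.int32)
actual_seq_lens = seq_lens.view(-1, 1, 1, 1)
print(actual_seq_lens.shape)  # torch.Size([3, 1, 1, 1])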
@@ -908,6 +987,30 @@ def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
             return_lse=True,
         )
 
+    def _run_prefill_context_chunk_cudnn(self,
+                                         prefill: MLACommonPrefillMetadata,
+                                         chunk_idx: int, q, k, v):
+        assert isinstance(prefill, CudnnPrefillMetadata)
+        assert prefill.chunked_context is not None
+        assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+        assert prefill.query_seq_lens is not None
+        return cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=self.scale,
+            workspace_buffer=prefill.cudnn_workspace,
+            max_token_per_sequence=prefill.max_query_len,
+            max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
+            actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].
+            view(-1, 1, 1, 1),
+            causal=False,
+            return_lse=True,
+            is_cuda_graph_compatible=
+            True,  # Indicates actual_seq_lens are on GPU or CPU.
+        )
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
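One detail worth noting across the two new methods: the new-token path runs with causal=True over the current query tokens, while each context chunk runs with causal=False against previously cached KV, and both return the log-sum-exp so the partial results can later be merged. The merge itself is not part of this diff; below is a self-contained PyTorch sketch of the standard LSE-weighted combination, written from the math rather than from vLLM's own merge kernel, and assuming lse tensors of shape (tokens, heads):

import torch

def merge_attn_outputs(o1: torch.Tensor, lse1: torch.Tensor,
                       o2: torch.Tensor, lse2: torch.Tensor):
    # Merge two partial softmax-attention results computed over disjoint
    # key/value sets: o* has shape (tokens, heads, head_dim), lse* has
    # shape (tokens, heads). Uses the numerically stable LSE trick.
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m).unsqueeze(-1)
    w2 = torch.exp(lse2 - m).unsqueeze(-1)
    out = (w1 * o1 + w2 * o2) / (w1 + w2)
    lse = m + torch.log((w1 + w2).squeeze(-1))
    return out, lse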
