
Commit 7242ff8

Author: Elfie Guo
Message: port cudnn API.
Parent: 6bbf179

File tree: vllm/envs.py, vllm/v1/attention/backends/mla/common.py

2 files changed, 129 insertions(+), 29 deletions(-)

vllm/envs.py

old mode 100644
new mode 100755
Lines changed: 6 additions & 1 deletion
@@ -139,6 +139,7 @@
     VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
+    VLLM_USE_CUDNN_PREFILL: bool = False
 
 
 def get_default_cache_root():
@@ -961,7 +962,11 @@ def get_vllm_port() -> Optional[int]:
     # consumer. This is only applicable when using NixlConnector in a
     # disaggregated decode-prefill setup.
     "VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
-    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
+    lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),
+
+    # Controls whether or not to use cudnn prefill
+    "VLLM_USE_CUDNN_PREFILL":
+    lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "1")))
 }
 
 # --8<-- [end:env-vars-definition]
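
As a quick reference for how the new flag resolves at runtime, here is a minimal standalone sketch (not part of the commit) of the lookup pattern registered above; the lambda parses the raw environment string with bool(int(...)):

# Illustrative sketch, not vLLM code: mirrors the registered lambda above.
import os

def use_cudnn_prefill() -> bool:
    # Unset or "1" -> True, "0" -> False.
    return bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "1")))

os.environ["VLLM_USE_CUDNN_PREFILL"] = "0"
print(use_cudnn_prefill())  # False
os.environ["VLLM_USE_CUDNN_PREFILL"] = "1"
print(use_cudnn_prefill())  # True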

vllm/v1/attention/backends/mla/common.py

old mode 100644
new mode 100755
Lines changed: 123 additions & 28 deletions
@@ -194,6 +194,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
@@ -228,6 +229,9 @@
 
 logger = init_logger(__name__)
 
+CUDNN_SUPPORTED_HEAD_DIMS = [192, 128]
+CUDNN_WORKSPACE_SIZE = 12800
+
 
 class MLACommonBackend(AttentionBackend):
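
The two constants added above gate and size the new path: CUDNN_SUPPORTED_HEAD_DIMS lists the head sizes for which the cuDNN prefill kernel is attempted, and CUDNN_WORKSPACE_SIZE is the per-sequence workspace size in bytes. A small sketch of the eligibility check used later in this diff; the tensor shapes below are made-up examples:

# Sketch of the head-dim gating used further down; shapes are illustrative only.
import torch

CUDNN_SUPPORTED_HEAD_DIMS = [192, 128]

q = torch.empty(32, 16, 192)  # (tokens, heads, head_dim)
k = torch.empty(32, 16, 192)
v = torch.empty(32, 16, 128)  # MLA: V head dim differs from the QK head dim

if all(t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v)):
    print("eligible for the cuDNN prefill path")
else:
    print("falls back to the FlashAttention path")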

@@ -282,11 +286,14 @@ class ChunkedContextMetadata:
         starts: torch.Tensor
         seq_tot: list[int]
         max_seq_lens: list[int]
+        seq_lens: torch.Tensor
         workspace: torch.Tensor
 
     block_table: torch.Tensor
     query_start_loc: torch.Tensor
+    query_seq_lens: torch.Tensor
     max_query_len: int
+    workspace: torch.Tensor
     chunked_context: Optional[ChunkedContextMetadata] = None
 
 
@@ -390,6 +397,12 @@ def __init__(self,
             dtype=model_config.dtype,
             device=runner.device,
         )
+        self.workspace = torch.empty(
+            CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
+            dtype=torch.int8,
+            device=runner.device,
+        )
+
         self.block_table = block_table
 
     def reorder_batch(self, input_batch: "InputBatch",
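
The builder now also allocates a flat int8 cuDNN workspace of CUDNN_WORKSPACE_SIZE bytes per schedulable sequence, alongside the existing chunked-prefill workspace. A rough sizing sketch; 256 is an assumed example value for scheduler_config.max_num_seqs:

# Rough sizing sketch on CPU; 256 is an assumed max_num_seqs value.
import torch

CUDNN_WORKSPACE_SIZE = 12800
max_num_seqs = 256

workspace = torch.empty(CUDNN_WORKSPACE_SIZE * max_num_seqs, dtype=torch.int8)
print(workspace.numel() / 2**20, "MiB")  # 3.125 MiB
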
@@ -566,6 +579,7 @@ def build(self, common_prefix_len: int,
                     starts=chunk_starts.to(device, non_blocking=True),
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                    seq_lens=chunk_seq_lens,
                     workspace=self.chunked_prefill_workspace,
                 )
 
@@ -576,6 +590,9 @@ def build(self, common_prefix_len: int,
                 block_table=block_table_tensor[reqs_start:, ...],
                 query_start_loc=prefill_query_start_loc,
                 max_query_len=max_query_len,
+                workspace=self.workspace,
+                query_seq_lens=prefill_query_start_loc[1:] -
+                prefill_query_start_loc[:-1],
                 chunked_context=chunked_context_metadata,
             )
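
query_seq_lens is the per-request number of new prefill tokens, obtained as the first difference of the cumulative query_start_loc offsets. A small worked example with made-up offsets:

# Worked example of the query_seq_lens derivation above; offsets are made up.
import torch

prefill_query_start_loc = torch.tensor([0, 5, 12, 20])  # 3 requests, 20 tokens
query_seq_lens = prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
print(query_seq_lens)  # tensor([5, 7, 8])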

@@ -663,9 +680,10 @@ def __init__(
         # v with 0s to match the qk head dim for attention backends that do
         # not support different headdims
         # We don't need to pad V if we are on a hopper system with FA3
-        self._pad_v = self.vllm_flash_attn_version is None or not (
-            self.vllm_flash_attn_version == 3
-            and current_platform.get_device_capability()[0] == 9)
+        self._pad_v = not envs.VLLM_USE_CUDNN_PREFILL and (
+            self.vllm_flash_attn_version is None
+            or not (self.vllm_flash_attn_version == 3
+                    and current_platform.get_device_capability()[0] == 9))
 
     def _flash_attn_varlen_diff_headdims(self,
                                          q,
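
With this change, V is no longer padded up to the QK head dim when the cuDNN prefill path is enabled, in addition to the existing FA3-on-Hopper exemption. A standalone restatement of the boolean above, with plain values standing in for the class attributes:

# Standalone restatement of the padding decision; inputs stand in for attributes.
from typing import Optional

def needs_v_pad(use_cudnn_prefill: bool, flash_attn_version: Optional[int],
                device_capability_major: int) -> bool:
    return not use_cudnn_prefill and (
        flash_attn_version is None
        or not (flash_attn_version == 3 and device_capability_major == 9))

print(needs_v_pad(False, 3, 9))  # False: FA3 on Hopper handles mismatched head dims
print(needs_v_pad(True, 2, 8))   # False: the cuDNN path accepts the narrower V
print(needs_v_pad(False, 2, 8))  # True: pad V for the remaining configurations
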
@@ -705,6 +723,40 @@ def _flash_attn_varlen_diff_headdims(self,
             return attn_out, lse
         return attn_out
 
+    def _cudnn_varlen_func_diff_headdims(
+        self,
+        q,
+        k,
+        v,
+        scale,
+        workspace,
+        max_q_seq_lens,
+        max_kv_seq_lens,
+        seq_lens_q,
+        seq_lens_kv,
+        causal,
+        is_cuda_graph_compatible=True,
+    ):
+        from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+        if not is_cuda_graph_compatible:
+            seq_lens_q = seq_lens_q.to("cpu")
+            seq_lens_kv = seq_lens_kv.to("cpu")
+        return cudnn_batch_prefill_with_kv_cache(
+            q=q,
+            k_cache=k,
+            v_cache=v,
+            scale=scale,
+            workspace_buffer=workspace,
+            max_token_per_sequence=max_q_seq_lens,
+            max_sequence_kv=max_kv_seq_lens,
+            actual_seq_lens_q=seq_lens_q.view(-1, 1, 1, 1),
+            actual_seq_lens_kv=seq_lens_kv.view(-1, 1, 1, 1),
+            causal=causal,
+            return_lse=True,
+            is_cuda_graph_compatible=is_cuda_graph_compatible,
+        )
+
     def _v_up_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
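
For orientation, a hedged call sketch of the wrapper above. It assumes a CUDA device and a FlashInfer build that exposes cudnn_batch_prefill_with_kv_cache; the packed (tokens, heads, head_dim) layout, bfloat16 dtype, and sizes are assumptions mirroring how the wrapper is invoked in this file, so treat it as illustrative rather than a verified recipe:

# Hedged, illustrative call sketch; assumes FlashInfer with cuDNN prefill support
# and a CUDA device. Shapes, dtype, and sizes are assumptions, not vLLM defaults.
import torch
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

device = "cuda"
num_heads, head_dim = 16, 192
seq_lens = torch.tensor([5, 7, 8], device=device, dtype=torch.int32)
total_tokens = int(seq_lens.sum())

q = torch.randn(total_tokens, num_heads, head_dim,
                device=device, dtype=torch.bfloat16)
k = torch.randn(total_tokens, num_heads, head_dim,
                device=device, dtype=torch.bfloat16)
v = torch.randn(total_tokens, num_heads, head_dim,
                device=device, dtype=torch.bfloat16)
workspace = torch.empty(12800 * seq_lens.numel(), dtype=torch.int8, device=device)

out, lse = cudnn_batch_prefill_with_kv_cache(
    q=q,
    k_cache=k,
    v_cache=v,
    scale=head_dim**-0.5,
    workspace_buffer=workspace,
    max_token_per_sequence=int(seq_lens.max()),
    max_sequence_kv=int(seq_lens.max()),
    actual_seq_lens_q=seq_lens.view(-1, 1, 1, 1),
    actual_seq_lens_kv=seq_lens.view(-1, 1, 1, 1),
    causal=True,
    return_lse=True,
    is_cuda_graph_compatible=True,
)
print(out.shape, lse.shape)
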
@@ -803,19 +855,41 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            attn_output, attn_softmax_lse = \
-                self._flash_attn_varlen_diff_headdims(
-                    q=q,
-                    k=k,
-                    v=v,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
+            if envs.VLLM_USE_CUDNN_PREFILL and all(
+                    t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS
+                    for t in (q, k, v)):
+                attn_output, attn_softmax_lse = (
+                    self._cudnn_varlen_func_diff_headdims(
+                        q,
+                        k,
+                        v,
+                        scale=self.scale,
+                        workspace=prefill_metadata.workspace,
+                        max_q_seq_lens=prefill_metadata.max_query_len,
+                        max_kv_seq_lens=prefill_metadata.chunked_context.
+                        max_seq_lens[i],
+                        seq_lens_q=prefill_metadata.query_seq_lens.view(
+                            -1, 1, 1, 1),
+                        seq_lens_kv=prefill_metadata.chunked_context.
+                        seq_lens[i].view(-1, 1, 1, 1),
+                        causal=False,
+                        is_cuda_graph_compatible=
+                        True,  # Indicates actual_seq_lens are on GPU or CPU.
+                    ))
+            else:
+                attn_output, attn_softmax_lse = \
+                    self._flash_attn_varlen_diff_headdims(
+                        q=q,
+                        k=k,
+                        v=v,
+                        cu_seqlens_q=prefill_metadata.query_start_loc,
+                        cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
+                        max_seqlen_q=prefill_metadata.max_query_len,
+                        max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
+                        softmax_scale=self.scale,
+                        causal=False,  # Context is unmasked
+                        return_softmax_lse=True,
+                    )
 
             if output is None:
                 output = attn_output
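
Both branches request the log-sum-exp along with the attention output (return_lse=True for the cuDNN path, return_softmax_lse=True for FlashAttention) so that the result over this context chunk can later be merged exactly with attention over the remaining chunks and the new tokens. A generic, standalone numerical sketch of LSE-based merging (not vLLM's merge kernel; shapes and sizes are assumed):

# Standalone numerical check of LSE-based merging; not vLLM's merge kernel.
import torch

torch.manual_seed(0)
T, H, D, N_ctx, N_new = 4, 2, 8, 6, 3
q = torch.randn(T, H, D)
k = torch.randn(N_ctx + N_new, H, D)
v = torch.randn(N_ctx + N_new, H, D)

def attn_part(q, k, v):
    s = torch.einsum("thd,shd->ths", q, k)   # unscaled, unmasked scores
    lse = torch.logsumexp(s, dim=-1)         # (T, H)
    o = torch.einsum("ths,shd->thd", torch.softmax(s, dim=-1), v)
    return o, lse

o_ctx, lse_ctx = attn_part(q, k[:N_ctx], v[:N_ctx])  # cached-context chunk
o_new, lse_new = attn_part(q, k[N_ctx:], v[N_ctx:])  # new-token chunk

lse = torch.logaddexp(lse_ctx, lse_new)
merged = (o_ctx * torch.exp(lse_ctx - lse).unsqueeze(-1) +
          o_new * torch.exp(lse_new - lse).unsqueeze(-1))

o_full, _ = attn_part(q, k, v)  # attention over all keys at once
print(torch.allclose(merged, o_full, atol=1e-5))  # True
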
@@ -854,18 +928,39 @@ def _forward_prefill(
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
-            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
-            max_seqlen_q=attn_metadata.prefill.max_query_len,
-            max_seqlen_k=attn_metadata.prefill.max_query_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if envs.VLLM_USE_CUDNN_PREFILL and all(
+                t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v)):
+            output = self._cudnn_varlen_func_diff_headdims(
+                q,
+                k,
+                v,
+                scale=self.scale,
+                workspace=attn_metadata.prefill.workspace,
+                max_q_seq_lens=attn_metadata.prefill.max_query_len,
+                max_kv_seq_lens=attn_metadata.prefill.max_query_len,
+                seq_lens_q=attn_metadata.prefill.query_seq_lens.view(
+                    -1, 1, 1, 1),
+                seq_lens_kv=attn_metadata.prefill.query_seq_lens.view(
+                    -1, 1, 1, 1),
+                causal=True,
+                is_cuda_graph_compatible=
+                True,  # Indicates actual_seq_lens are on GPU or CPU.
+            )
+            if not has_context:
+                output = output[0]
+        else:
+            output = self._flash_attn_varlen_diff_headdims(
+                q=q,
+                k=k,
+                v=v,
+                cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+                cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+                max_seqlen_q=attn_metadata.prefill.max_query_len,
+                max_seqlen_k=attn_metadata.prefill.max_query_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=has_context,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
