[Nvidia] Integrate cudnn prefill paged attention kernel for head_dim == 128 models, like Llama family #20850

Draft · wants to merge 2 commits into base: main
7 changes: 6 additions & 1 deletion vllm/envs.py
100644 → 100755
@@ -139,6 +139,7 @@
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
VLLM_USE_CUDNN_PREFILL: bool = False


def get_default_cache_root():
@@ -961,7 +962,11 @@ def get_vllm_port() -> Optional[int]:
# consumer. This is only applicable when using NixlConnector in a
# disaggregated decode-prefill setup.
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT":
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120"))
lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")),

# Controls whether or not to use cudnn prefill
"VLLM_USE_CUDNN_PREFILL":
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "1")))
Contributor
medium

The default value for VLLM_USE_CUDNN_PREFILL is False in the VllmEnvs TypedDict (line 142), but the default value in os.getenv is "1", which evaluates to True. This inconsistency can lead to unexpected behavior where the feature is enabled by default when the environment variable is not explicitly set. To maintain consistency, the default value in os.getenv should be "0".

Suggested change
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "1")))
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")))
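
For reference, a minimal standalone Python sketch (not part of this diff) showing why the fallback string matters when the variable is unset; the parsing pattern mirrors the lambda above:

import os

# With the variable unset, the fallback string decides the parsed value.
os.environ.pop("VLLM_USE_CUDNN_PREFILL", None)
assert bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "1"))) is True   # enabled by default
assert bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))) is False  # matches the TypedDict default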

}

# --8<-- [end:env-vars-definition]
67 changes: 63 additions & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -7,11 +7,12 @@
from typing import TYPE_CHECKING, Any, Optional

import torch

import vllm.envs as envs
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)

import vllm.envs as envs
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
@@ -33,6 +34,8 @@

logger = init_logger(__name__)

CUDNN_SUPPORTED_HEAD_SIZES = [128]


class FlashInferBackend(AttentionBackend):

@@ -202,6 +205,12 @@ class FlashInferMetadata:
num_prefills: int
num_prefill_tokens: int

# For cudnn prefill
max_query_len: int
max_seq_len: int
actual_seq_lens_q: torch.Tensor
actual_seq_lens_kv: torch.Tensor

# For cascade attention.
use_cascade: bool
shared_qo_indptr: Optional[torch.Tensor] = None
@@ -213,6 +222,12 @@
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None

cudnn_workspace: Optional[torch.Tensor] = None
block_table: Optional[torch.Tensor] = None

def _is_cudnn_supported(self):
return self.head_dim in CUDNN_SUPPORTED_HEAD_SIZES and envs.VLLM_USE_CUDNN_PREFILL

@property
def query_start_loc(self):
# The GPUModelRunner expects to be able to access this property.
@@ -367,7 +382,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if self._num_prefills > 0:
if self._num_prefills > 0 and not attn_metadata._is_cudnn_supported(
):
# Decodes are first so prefills start after the last decode
prefill_start = self._num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -433,6 +449,7 @@ def build(self, common_prefix_len: int,
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
max_query_len = common_attn_metadata.max_query_len
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

@@ -463,6 +480,7 @@
shared_kv_page_indices = None
shared_kv_last_page_len = None

max_seq_len = int(seq_lens.max().item())
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
@@ -479,7 +497,7 @@
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)

self._get_workspace_buffer()
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
@@ -502,6 +520,12 @@
shared_kv_page_indptr=shared_kv_page_indptr,
shared_kv_page_indices=shared_kv_page_indices,
shared_kv_last_page_len=shared_kv_last_page_len,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
actual_seq_lens_kv=seq_lens.to(self.runner.device),
block_table=block_table_tensor,
cudnn_workspace=self._workspace_buffer.to(torch.int8),
)

self._plan(attn_metadata)
@@ -653,13 +677,48 @@ def forward(
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
or 0.0)
assert prefill_wrapper._sm_scale == self.scale

prefill_wrapper.run(
prefill_query,
kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
)
elif num_prefill_tokens > 0 and attn_metadata._is_cudnn_supported():
(total_num_pages, _, page_size, num_kv_heads,
head_dim) = kv_cache.shape
k_cache = kv_cache[:, 0].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
Comment on lines +691 to +697
Contributor
critical

The as_strided call for k_cache uses an incorrect stride for the first dimension. The stride of kv_cache[:, 0] along the page dimension is kv_cache.stride(0) (for a contiguous cache that is 2 * page_size * num_kv_heads * head_dim, since the second dimension holds both K and V), but page_size * num_kv_heads * head_dim is used. This will lead to incorrect memory access. The correct first stride should be kv_cache.stride(0).

Suggested change
k_cache = kv_cache[:, 0].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
k_cache = kv_cache[:, 0].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
kv_cache.stride(0),
head_dim,
num_kv_heads * head_dim,
1,
))

v_cache = kv_cache[:, 1].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
Comment on lines +698 to +704
Contributor
critical

Similar to k_cache, the as_strided call for v_cache has an incorrect stride for the first dimension. It should also be kv_cache.stride(0).

Suggested change
v_cache = kv_cache[:, 1].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
page_size * num_kv_heads * head_dim,
head_dim,
num_kv_heads * head_dim,
1,
))
v_cache = kv_cache[:, 1].as_strided(
(total_num_pages, num_kv_heads, page_size, head_dim), (
kv_cache.stride(0),
head_dim,
num_kv_heads * head_dim,
1,
))
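
To illustrate the point raised for both k_cache and v_cache, a small standalone PyTorch sketch (toy sizes, assuming a contiguous cache of shape (total_num_pages, 2, page_size, num_kv_heads, head_dim)) comparing the slice's actual page stride with the hard-coded value:

import torch

# Toy paged KV cache: (total_num_pages, 2, page_size, num_kv_heads, head_dim)
total_num_pages, page_size, num_kv_heads, head_dim = 4, 16, 8, 128
kv_cache = torch.zeros(total_num_pages, 2, page_size, num_kv_heads, head_dim)

# The page-dimension stride of the K slice is inherited from the full tensor.
print(kv_cache[:, 0].stride(0))             # 32768 == 2 * page_size * num_kv_heads * head_dim
print(kv_cache.stride(0))                   # 32768 -- what the suggested change passes to as_strided
print(page_size * num_kv_heads * head_dim)  # 16384 -- the value hard-coded in the diff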

output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
q=query[num_decode_tokens:],
k_cache=k_cache,
v_cache=v_cache,
scale=self.scale,
workspace_buffer=attn_metadata.cudnn_workspace,
max_token_per_sequence=attn_metadata.max_query_len,
max_sequence_kv=attn_metadata.max_seq_len,
block_tables=attn_metadata.block_table[num_decode_tokens:],
actual_seq_lens_q=attn_metadata.
actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1),
actual_seq_lens_kv=attn_metadata.
actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1),
causal=True,
return_lse=True,
is_cuda_graph_compatible=True,
)

if decode_wrapper := attn_metadata.decode_wrapper:
decode_query = query[:num_decode_tokens]
151 changes: 123 additions & 28 deletions vllm/v1/attention/backends/mla/common.py
100644 → 100755
@@ -194,6 +194,7 @@

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
@@ -228,6 +229,9 @@

logger = init_logger(__name__)

CUDNN_SUPPORTED_HEAD_DIMS = [192, 128]
CUDNN_WORKSPACE_SIZE = 12800


class MLACommonBackend(AttentionBackend):

@@ -282,11 +286,14 @@ class ChunkedContextMetadata:
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
seq_lens: torch.Tensor
workspace: torch.Tensor

block_table: torch.Tensor
query_start_loc: torch.Tensor
query_seq_lens: torch.Tensor
max_query_len: int
workspace: torch.Tensor
chunked_context: Optional[ChunkedContextMetadata] = None


@@ -390,6 +397,12 @@ def __init__(self,
dtype=model_config.dtype,
device=runner.device,
)
self.workspace = torch.empty(
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
dtype=torch.int8,
device=runner.device,
)

self.block_table = block_table

def reorder_batch(self, input_batch: "InputBatch",
Expand Down Expand Up @@ -566,6 +579,7 @@ def build(self, common_prefix_len: int,
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)

@@ -576,6 +590,9 @@
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
workspace=self.workspace,
query_seq_lens=prefill_query_start_loc[1:] -
prefill_query_start_loc[:-1],
chunked_context=chunked_context_metadata,
)

@@ -663,9 +680,10 @@ def __init__(
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
self._pad_v = not envs.VLLM_USE_CUDNN_PREFILL and (
self.vllm_flash_attn_version is None
or not (self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9))

def _flash_attn_varlen_diff_headdims(self,
q,
@@ -705,6 +723,40 @@ def _flash_attn_varlen_diff_headdims(self,
return attn_out, lse
return attn_out

def _cudnn_varlen_func_diff_headdims(
self,
q,
k,
v,
scale,
workspace,
max_q_seq_lens,
max_kv_seq_lens,
seq_lens_q,
seq_lens_kv,
causal,
is_cuda_graph_compatible=True,
):
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache

if not is_cuda_graph_compatible:
seq_lens_q = seq_lens_q.to("cpu")
seq_lens_kv = seq_lens_kv.to("cpu")
return cudnn_batch_prefill_with_kv_cache(
q=q,
k_cache=k,
v_cache=v,
scale=scale,
workspace_buffer=workspace,
max_token_per_sequence=max_q_seq_lens,
max_sequence_kv=max_kv_seq_lens,
actual_seq_lens_q=seq_lens_q.view(-1, 1, 1, 1),
actual_seq_lens_kv=seq_lens_kv.view(-1, 1, 1, 1),
causal=causal,
return_lse=True,
is_cuda_graph_compatible=is_cuda_graph_compatible,
)

def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -803,19 +855,41 @@ def _compute_prefill_context(
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)

attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if envs.VLLM_USE_CUDNN_PREFILL and all(
t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS
for t in (q, k, v)):
attn_output, attn_softmax_lse = (
self._cudnn_varlen_func_diff_headdims(
q,
k,
v,
scale=self.scale,
workspace=prefill_metadata.workspace,
max_q_seq_lens=prefill_metadata.max_query_len,
max_kv_seq_lens=prefill_metadata.chunked_context.
max_seq_lens[i],
seq_lens_q=prefill_metadata.query_seq_lens.view(
-1, 1, 1, 1),
seq_lens_kv=prefill_metadata.chunked_context.
seq_lens[i].view(-1, 1, 1, 1),
causal=False,
is_cuda_graph_compatible=
True, #Indicates actual_seq_lens are on GPU or CPU.
))
else:
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)

if output is None:
output = attn_output
@@ -854,18 +928,39 @@ def _forward_prefill(

k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
if envs.VLLM_USE_CUDNN_PREFILL and all(
t.shape[-1] in CUDNN_SUPPORTED_HEAD_DIMS for t in (q, k, v)):
output = self._cudnn_varlen_func_diff_headdims(
q,
k,
v,
scale=self.scale,
workspace=attn_metadata.prefill.workspace,
max_q_seq_lens=attn_metadata.prefill.max_query_len,
max_kv_seq_lens=attn_metadata.prefill.max_query_len,
seq_lens_q=attn_metadata.prefill.query_seq_lens.view(
-1, 1, 1, 1),
seq_lens_kv=attn_metadata.prefill.query_seq_lens.view(
-1, 1, 1, 1),
causal=True,
is_cuda_graph_compatible=
True, #Indicates actual_seq_lens are on GPU or CPU.
)
if not has_context:
output = output[0]
else:
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)

if has_context:
suffix_output, suffix_lse = output