-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
[Nvidia] Integrate cudnn prefill paged attention kernel for head_dim == 128 models, like Llama family #20850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,11 +7,12 @@ | |||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Optional | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import vllm.envs as envs | ||||||||||||||||||||||||||||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, | ||||||||||||||||||||||||||||||
BatchPrefillWithPagedKVCacheWrapper, | ||||||||||||||||||||||||||||||
MultiLevelCascadeAttentionWrapper) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
import vllm.envs as envs | ||||||||||||||||||||||||||||||
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache | ||||||||||||||||||||||||||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, | ||||||||||||||||||||||||||||||
AttentionType) | ||||||||||||||||||||||||||||||
from vllm.attention.layer import Attention | ||||||||||||||||||||||||||||||
|
@@ -33,6 +34,8 @@ | |||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
logger = init_logger(__name__) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
CUDNN_SUPPORTED_HEAD_SIZES = [128] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
class FlashInferBackend(AttentionBackend): | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -202,6 +205,12 @@ class FlashInferMetadata: | |||||||||||||||||||||||||||||
num_prefills: int | ||||||||||||||||||||||||||||||
num_prefill_tokens: int | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# For cudnn prefill | ||||||||||||||||||||||||||||||
max_query_len: int | ||||||||||||||||||||||||||||||
max_seq_len: int | ||||||||||||||||||||||||||||||
actual_seq_lens_q: torch.Tensor | ||||||||||||||||||||||||||||||
actual_seq_lens_kv: torch.Tensor | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
# For cascade attention. | ||||||||||||||||||||||||||||||
use_cascade: bool | ||||||||||||||||||||||||||||||
shared_qo_indptr: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||
|
@@ -213,6 +222,12 @@ class FlashInferMetadata: | |||||||||||||||||||||||||||||
decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None | ||||||||||||||||||||||||||||||
cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
cudnn_workspace: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||
block_table: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
def _is_cudnn_supported(self): | ||||||||||||||||||||||||||||||
return self.head_dim in CUDNN_SUPPORTED_HEAD_SIZES and envs.VLLM_USE_CUDNN_PREFILL | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||
def query_start_loc(self): | ||||||||||||||||||||||||||||||
# The GPUModelRunner expects to be able to access this property. | ||||||||||||||||||||||||||||||
|
@@ -367,7 +382,8 @@ def _plan(self, attn_metadata: FlashInferMetadata): | |||||||||||||||||||||||||||||
# Regular attention (common case). | ||||||||||||||||||||||||||||||
# Decodes are at the front and prefills are at the back, | ||||||||||||||||||||||||||||||
# according to reorder_batch() | ||||||||||||||||||||||||||||||
if self._num_prefills > 0: | ||||||||||||||||||||||||||||||
if self._num_prefills > 0 and not attn_metadata._is_cudnn_supported( | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
# Decodes are first so prefills start after the last decode | ||||||||||||||||||||||||||||||
prefill_start = self._num_decodes | ||||||||||||||||||||||||||||||
attn_metadata.prefill_wrapper = self._get_prefill_wrapper() | ||||||||||||||||||||||||||||||
|
@@ -433,6 +449,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
qo_indptr = common_attn_metadata.query_start_loc | ||||||||||||||||||||||||||||||
seq_lens = common_attn_metadata.seq_lens | ||||||||||||||||||||||||||||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] | ||||||||||||||||||||||||||||||
max_query_len = common_attn_metadata.max_query_len | ||||||||||||||||||||||||||||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( | ||||||||||||||||||||||||||||||
self.runner.device, non_blocking=True).long() | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -463,6 +480,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
shared_kv_page_indices = None | ||||||||||||||||||||||||||||||
shared_kv_last_page_len = None | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
max_seq_len = int(seq_lens.max().item()) | ||||||||||||||||||||||||||||||
mask = (torch.arange(block_table_tensor.size(1), | ||||||||||||||||||||||||||||||
dtype=block_table_tensor.dtype, | ||||||||||||||||||||||||||||||
device=block_table_tensor.device).unsqueeze(0) | ||||||||||||||||||||||||||||||
|
@@ -479,7 +497,7 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
paged_kv_last_page_len = seq_lens % page_size | ||||||||||||||||||||||||||||||
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, | ||||||||||||||||||||||||||||||
page_size, paged_kv_last_page_len) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self._get_workspace_buffer() | ||||||||||||||||||||||||||||||
attn_metadata = FlashInferMetadata( | ||||||||||||||||||||||||||||||
num_actual_tokens=num_actual_tokens, | ||||||||||||||||||||||||||||||
qo_indptr=qo_indptr, | ||||||||||||||||||||||||||||||
|
@@ -502,6 +520,12 @@ def build(self, common_prefix_len: int, | |||||||||||||||||||||||||||||
shared_kv_page_indptr=shared_kv_page_indptr, | ||||||||||||||||||||||||||||||
shared_kv_page_indices=shared_kv_page_indices, | ||||||||||||||||||||||||||||||
shared_kv_last_page_len=shared_kv_last_page_len, | ||||||||||||||||||||||||||||||
max_query_len=max_query_len, | ||||||||||||||||||||||||||||||
max_seq_len=max_seq_len, | ||||||||||||||||||||||||||||||
actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1], | ||||||||||||||||||||||||||||||
actual_seq_lens_kv=seq_lens.to(self.runner.device), | ||||||||||||||||||||||||||||||
block_table=block_table_tensor, | ||||||||||||||||||||||||||||||
cudnn_workspace=self._workspace_buffer.to(torch.int8), | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
self._plan(attn_metadata) | ||||||||||||||||||||||||||||||
|
@@ -653,13 +677,48 @@ def forward( | |||||||||||||||||||||||||||||
assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap | ||||||||||||||||||||||||||||||
or 0.0) | ||||||||||||||||||||||||||||||
assert prefill_wrapper._sm_scale == self.scale | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
prefill_wrapper.run( | ||||||||||||||||||||||||||||||
prefill_query, | ||||||||||||||||||||||||||||||
kv_cache.permute(*stride_order), | ||||||||||||||||||||||||||||||
k_scale=layer._k_scale_float, | ||||||||||||||||||||||||||||||
v_scale=layer._v_scale_float, | ||||||||||||||||||||||||||||||
out=output[num_decode_tokens:], | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
elif num_prefill_tokens > 0 and attn_metadata._is_cudnn_supported(): | ||||||||||||||||||||||||||||||
(total_num_pages, _, page_size, num_kv_heads, | ||||||||||||||||||||||||||||||
head_dim) = kv_cache.shape | ||||||||||||||||||||||||||||||
k_cache = kv_cache[:, 0].as_strided( | ||||||||||||||||||||||||||||||
(total_num_pages, num_kv_heads, page_size, head_dim), ( | ||||||||||||||||||||||||||||||
page_size * num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
head_dim, | ||||||||||||||||||||||||||||||
num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
1, | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
Comment on lines
+691
to
+697
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||||||||||||||||||||||||
v_cache = kv_cache[:, 1].as_strided( | ||||||||||||||||||||||||||||||
(total_num_pages, num_kv_heads, page_size, head_dim), ( | ||||||||||||||||||||||||||||||
page_size * num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
head_dim, | ||||||||||||||||||||||||||||||
num_kv_heads * head_dim, | ||||||||||||||||||||||||||||||
1, | ||||||||||||||||||||||||||||||
)) | ||||||||||||||||||||||||||||||
Comment on lines
+698
to
+704
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
|
||||||||||||||||||||||||||||||
output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache( | ||||||||||||||||||||||||||||||
q=query[num_decode_tokens:], | ||||||||||||||||||||||||||||||
k_cache=k_cache, | ||||||||||||||||||||||||||||||
v_cache=v_cache, | ||||||||||||||||||||||||||||||
scale=self.scale, | ||||||||||||||||||||||||||||||
workspace_buffer=attn_metadata.cudnn_workspace, | ||||||||||||||||||||||||||||||
max_token_per_sequence=attn_metadata.max_query_len, | ||||||||||||||||||||||||||||||
max_sequence_kv=attn_metadata.max_seq_len, | ||||||||||||||||||||||||||||||
block_tables=attn_metadata.block_table[num_decode_tokens:], | ||||||||||||||||||||||||||||||
actual_seq_lens_q=attn_metadata. | ||||||||||||||||||||||||||||||
actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1), | ||||||||||||||||||||||||||||||
actual_seq_lens_kv=attn_metadata. | ||||||||||||||||||||||||||||||
actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1), | ||||||||||||||||||||||||||||||
causal=True, | ||||||||||||||||||||||||||||||
return_lse=True, | ||||||||||||||||||||||||||||||
is_cuda_graph_compatible=True, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
if decode_wrapper := attn_metadata.decode_wrapper: | ||||||||||||||||||||||||||||||
decode_query = query[:num_decode_tokens] | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default value for
VLLM_USE_CUDNN_PREFILL
isFalse
in theVllmEnvs
TypedDict (line 142), but the default value inos.getenv
is"1"
, which evaluates toTrue
. This inconsistency can lead to unexpected behavior where the feature is enabled by default when the environment variable is not explicitly set. To maintain consistency, the default value inos.getenv
should be"0"
.