Commit c61a88a

Add cudnn prefill for running llama.
1 parent 7242ff8 commit c61a88a

File tree: 1 file changed (+63 −4 lines)


vllm/v1/attention/backends/flashinfer.py

Lines changed: 63 additions & 4 deletions
@@ -7,11 +7,12 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-
-import vllm.envs as envs
+from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
@@ -33,6 +34,8 @@
 
 logger = init_logger(__name__)
 
+CUDNN_SUPPORTED_HEAD_SIZES = [128]
+
 
 class FlashInferBackend(AttentionBackend):
 
@@ -202,6 +205,12 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
+    # For cudnn prefill
+    max_query_len: int
+    max_seq_len: int
+    actual_seq_lens_q: torch.Tensor
+    actual_seq_lens_kv: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     shared_qo_indptr: Optional[torch.Tensor] = None
@@ -213,6 +222,12 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    cudnn_workspace: Optional[torch.Tensor] = None
+    block_table: Optional[torch.Tensor] = None
+
+    def _is_cudnn_supported(self):
+        return self.head_dim in CUDNN_SUPPORTED_HEAD_SIZES and envs.VLLM_USE_CUDNN_PREFILL
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
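
Note: the new _is_cudnn_supported() gate is just a conjunction of a head-size whitelist and the VLLM_USE_CUDNN_PREFILL flag, so the cudnn path stays opt-in. A minimal standalone sketch of the same check, with a hypothetical head_dim value and the flag read straight from the environment rather than through vllm.envs (whose exact parsing may differ):

import os

# Hypothetical values for illustration; in the commit these come from the
# FlashInferMetadata instance and from vllm.envs.
CUDNN_SUPPORTED_HEAD_SIZES = [128]
head_dim = 128

# Stand-in for envs.VLLM_USE_CUDNN_PREFILL: an opt-in environment switch.
use_cudnn_prefill = os.environ.get("VLLM_USE_CUDNN_PREFILL", "0") == "1"

is_cudnn_supported = (head_dim in CUDNN_SUPPORTED_HEAD_SIZES
                      and use_cudnn_prefill)
print(is_cudnn_supported)  # True only for head_dim 128 with the flag set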
@@ -367,7 +382,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if self._num_prefills > 0:
+            if self._num_prefills > 0 and not attn_metadata._is_cudnn_supported(
+            ):
                 # Decodes are first so prefills start after the last decode
                 prefill_start = self._num_decodes
                 attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -433,6 +449,7 @@ def build(self, common_prefix_len: int,
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        max_query_len = common_attn_metadata.max_query_len
         slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
@@ -463,6 +480,7 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
+        max_seq_len = int(seq_lens.max().item())
         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
                              device=block_table_tensor.device).unsqueeze(0)
@@ -479,7 +497,7 @@ def build(self, common_prefix_len: int,
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
-
+        self._get_workspace_buffer()
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
@@ -502,6 +520,12 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indptr=shared_kv_page_indptr,
             shared_kv_page_indices=shared_kv_page_indices,
             shared_kv_last_page_len=shared_kv_last_page_len,
+            max_query_len=max_query_len,
+            max_seq_len=max_seq_len,
+            actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
+            actual_seq_lens_kv=seq_lens.to(self.runner.device),
+            block_table=block_table_tensor,
+            cudnn_workspace=self._workspace_buffer.to(torch.int8),
         )
 
         self._plan(attn_metadata)
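
Note: actual_seq_lens_q is the per-request query length, recovered as the first difference of the cumulative query_start_loc offsets (qo_indptr); actual_seq_lens_kv is simply seq_lens moved to the device. A small sketch with made-up lengths:

import torch

# Hypothetical batch: two decode requests (1 token each) followed by one
# prefill of 5 tokens, matching the decode-first ordering of reorder_batch().
qo_indptr = torch.tensor([0, 1, 2, 7])

# Same expression as in the diff: per-request query lengths.
actual_seq_lens_q = qo_indptr[1:] - qo_indptr[:-1]
print(actual_seq_lens_q)  # tensor([1, 1, 5])

# The cudnn prefill call later reshapes these to (batch, 1, 1, 1).
print(actual_seq_lens_q.view(-1, 1, 1, 1).shape)  # torch.Size([3, 1, 1, 1])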
@@ -653,13 +677,48 @@ def forward(
             assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                         or 0.0)
             assert prefill_wrapper._sm_scale == self.scale
+
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
             )
+        elif num_prefill_tokens > 0 and attn_metadata._is_cudnn_supported():
+            (total_num_pages, _, page_size, num_kv_heads,
+             head_dim) = kv_cache.shape
+            k_cache = kv_cache[:, 0].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            v_cache = kv_cache[:, 1].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
+                q=query[num_decode_tokens:],
+                k_cache=k_cache,
+                v_cache=v_cache,
+                scale=self.scale,
+                workspace_buffer=attn_metadata.cudnn_workspace,
+                max_token_per_sequence=attn_metadata.max_query_len,
+                max_sequence_kv=attn_metadata.max_seq_len,
+                block_tables=attn_metadata.block_table[num_decode_tokens:],
+                actual_seq_lens_q=attn_metadata.
+                actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1),
+                actual_seq_lens_kv=attn_metadata.
+                actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1),
+                causal=True,
+                return_lse=True,
+                is_cuda_graph_compatible=True,
+            )
 
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
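
Note: the as_strided calls re-expose each page's K (or V) block, stored token-major as (page_size, num_kv_heads, head_dim), in the head-major (num_kv_heads, page_size, head_dim) layout that cudnn_batch_prefill_with_kv_cache consumes, without copying. A minimal sketch of the same stride arithmetic on a standalone contiguous K cache (toy shapes; the real cache also carries the interleaved K/V dimension unpacked in the diff):

import torch

# Toy shapes for illustration only.
total_num_pages, page_size, num_kv_heads, head_dim = 4, 16, 2, 8

# A standalone, contiguous K cache: (pages, page_size, heads, head_dim).
k_pages = torch.randn(total_num_pages, page_size, num_kv_heads, head_dim)

# Same stride arithmetic as the diff: view each page head-major,
# (pages, heads, page_size, head_dim), without moving any data.
k_hnd = k_pages.as_strided(
    (total_num_pages, num_kv_heads, page_size, head_dim),
    (
        page_size * num_kv_heads * head_dim,  # step to the next page
        head_dim,                             # step to the next head
        num_kv_heads * head_dim,              # step to the next token in a page
        1,                                    # step within a head
    ))

# The view is just a per-page transpose of the token and head axes.
assert torch.equal(k_hnd, k_pages.permute(0, 2, 1, 3))
assert k_hnd.data_ptr() == k_pages.data_ptr()  # no copy was made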
