Commit 90397d5

integrate cudnn for FI.
Signed-off-by: Elfie Guo <elfieg@nvidia.com>
1 parent 6bbf179 commit 90397d5

1 file changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 82 additions & 5 deletions
@@ -7,16 +7,18 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
-
-import vllm.envs as envs
+from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
@@ -33,6 +35,13 @@
 
 logger = init_logger(__name__)
 
+CUDNN_SUPPORTED_HEAD_SIZES = [128]
+
+
+def is_cudnn_supported(head_dim: int):
+    return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \
+        and current_platform.has_device_capability(100)
+
 
 class FlashInferBackend(AttentionBackend):
@@ -202,6 +211,12 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
+    # For cudnn prefill
+    max_query_len: int
+    max_seq_len: int
+    actual_seq_lens_q: torch.Tensor
+    actual_seq_lens_kv: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     shared_qo_indptr: Optional[torch.Tensor] = None
@@ -213,6 +228,9 @@ class FlashInferMetadata:
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
     cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None
 
+    cudnn_workspace: Optional[torch.Tensor] = None
+    block_table: Optional[torch.Tensor] = None
+
     @property
     def query_start_loc(self):
         # The GPUModelRunner expects to be able to access this property.
@@ -301,9 +319,13 @@ def reorder_batch(self, input_batch: InputBatch,
 
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
+            if is_cudnn_supported(self.kv_cache_spec.head_size):
+                dtype = torch.int8
+            else:
+                dtype = torch.uint8
             self._workspace_buffer = torch.empty(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                dtype=torch.uint8,
+                dtype=dtype,
                 device=self.runner.device)
         return self._workspace_buffer
 
@@ -367,7 +389,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if self._num_prefills > 0:
+            if self._num_prefills > 0 and not is_cudnn_supported(
+                    attn_metadata.head_dim):
                 # Decodes are first so prefills start after the last decode
                 prefill_start = self._num_decodes
                 attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -433,6 +456,7 @@ def build(self, common_prefix_len: int,
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        max_query_len = common_attn_metadata.max_query_len
         slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
@@ -463,6 +487,7 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
+        max_seq_len = int(seq_lens.max().item())
         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
                              device=block_table_tensor.device).unsqueeze(0)
@@ -480,6 +505,10 @@ def build(self, common_prefix_len: int,
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
 
+        if is_cudnn_supported(self.kv_cache_spec.head_size):
+            self._get_workspace_buffer()
+            assert self._workspace_buffer is not None, "workspace_buffer is not set"
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
@@ -502,7 +531,13 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indptr=shared_kv_page_indptr,
             shared_kv_page_indices=shared_kv_page_indices,
             shared_kv_last_page_len=shared_kv_last_page_len,
-        )
+            max_query_len=max_query_len,
+            max_seq_len=max_seq_len,
+            actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
+            actual_seq_lens_kv=seq_lens.to(self.runner.device),
+            block_table=block_table_tensor,
+            cudnn_workspace=self._workspace_buffer
+            if is_cudnn_supported(self.kv_cache_spec.head_size) else None)
 
         self._plan(attn_metadata)
 
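Note (illustration only, not part of the commit): the new FlashInferMetadata fields feed the cudnn_batch_prefill_with_kv_cache call in forward() below, which takes per-request lengths (actual_seq_lens_q / actual_seq_lens_kv) rather than the cumulative qo_indptr used by the FlashInfer wrappers. A minimal sketch of that conversion, with made-up values:

import torch

# Hypothetical batch: three prefill requests with 4, 1, and 7 query tokens.
qo_indptr = torch.tensor([0, 4, 5, 12])

# First difference recovers per-request lengths, mirroring
# actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1] in build().
actual_seq_lens_q = qo_indptr[1:] - qo_indptr[:-1]
assert actual_seq_lens_q.tolist() == [4, 1, 7]

# forward() reshapes these to (batch, 1, 1, 1) before handing them to
# cudnn_batch_prefill_with_kv_cache.
print(actual_seq_lens_q.view(-1, 1, 1, 1).shape)  # torch.Size([3, 1, 1, 1])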

@@ -653,13 +688,55 @@ def forward(
             assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                         or 0.0)
             assert prefill_wrapper._sm_scale == self.scale
+
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
             )
+        elif num_prefill_tokens > 0 and FlashInferBackend.is_cudnn_supported(
+                attn_metadata.head_dim):
+            (total_num_pages, _, page_size, num_kv_heads,
+             head_dim) = kv_cache.shape
+
+            # Validate dimensions match expected head_dim
+            assert head_dim == self.head_size, (
+                f"KV cache head_dim {head_dim} != expected {self.head_size}")
+            assert attn_metadata.block_table is not None, \
+                "block_table is not set"
+            k_cache = kv_cache[:, 0].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            v_cache = kv_cache[:, 1].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
+                q=query[num_decode_tokens:],
+                k_cache=k_cache,
+                v_cache=v_cache,
+                scale=self.scale,
+                workspace_buffer=attn_metadata.cudnn_workspace,
+                max_token_per_sequence=attn_metadata.max_query_len,
+                max_sequence_kv=attn_metadata.max_seq_len,
+                block_tables=attn_metadata.block_table[num_decode_tokens:],
+                actual_seq_lens_q=attn_metadata.
+                actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1),
+                actual_seq_lens_kv=attn_metadata.
+                actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1),
+                causal=True,
+                return_lse=True,
+                is_cuda_graph_compatible=True,
+            )
 
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
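A side note on the as_strided calls above (illustration only, not part of the commit): they build zero-copy views over the paged KV cache by pairing a target shape with explicit per-dimension strides into the existing storage. The toy example below uses tiny, made-up dimensions rather than vLLM's real cache layout, just to show how such a view addresses a flat buffer:

import torch

# Made-up page geometry, standing in for (page_size, num_kv_heads, head_dim).
pages, heads, tokens, dim = 2, 3, 4, 5
flat = torch.arange(pages * heads * tokens * dim, dtype=torch.float32)

# Same stride pattern as the k_cache view in the diff:
# (page_size * num_kv_heads * head_dim, head_dim, num_kv_heads * head_dim, 1)
view = flat.as_strided(
    (pages, heads, tokens, dim),
    (tokens * heads * dim, dim, heads * dim, 1),
)

# Element [p, h, s, d] is read from flat storage at
# p*s0 + h*s1 + s*s2 + d*s3 -- nothing is copied.
p, h, s, d = 1, 2, 3, 4
offset = p * tokens * heads * dim + h * dim + s * heads * dim + d
assert view[p, h, s, d].item() == flat[offset].item()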
