From 90397d5b044d2fed1c50fe4772f266420abbc517 Mon Sep 17 00:00:00 2001 From: Elfie Guo Date: Tue, 15 Jul 2025 05:54:09 +0000 Subject: [PATCH] integrate cudnn for FI. Signed-off-by: Elfie Guo --- vllm/v1/attention/backends/flashinfer.py | 87 ++++++++++++++++++++++-- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 860309faa90..aef3eddea0c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -7,16 +7,18 @@ from typing import TYPE_CHECKING, Any, Optional import torch + +import vllm.envs as envs from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) - -import vllm.envs as envs +from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata, @@ -33,6 +35,13 @@ logger = init_logger(__name__) +CUDNN_SUPPORTED_HEAD_SIZES = [128] + + +def is_cudnn_supported(head_dim: int): + return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \ + and current_platform.has_device_capability(100) + class FlashInferBackend(AttentionBackend): @@ -202,6 +211,12 @@ class FlashInferMetadata: num_prefills: int num_prefill_tokens: int + # For cudnn prefill + max_query_len: int + max_seq_len: int + actual_seq_lens_q: torch.Tensor + actual_seq_lens_kv: torch.Tensor + # For cascade attention. use_cascade: bool shared_qo_indptr: Optional[torch.Tensor] = None @@ -213,6 +228,9 @@ class FlashInferMetadata: decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None + cudnn_workspace: Optional[torch.Tensor] = None + block_table: Optional[torch.Tensor] = None + @property def query_start_loc(self): # The GPUModelRunner expects to be able to access this property. @@ -301,9 +319,13 @@ def reorder_batch(self, input_batch: InputBatch, def _get_workspace_buffer(self): if self._workspace_buffer is None: + if is_cudnn_supported(self.kv_cache_spec.head_size): + dtype = torch.int8 + else: + dtype = torch.uint8 self._workspace_buffer = torch.empty( FLASHINFER_WORKSPACE_BUFFER_SIZE, - dtype=torch.uint8, + dtype=dtype, device=self.runner.device) return self._workspace_buffer @@ -367,7 +389,8 @@ def _plan(self, attn_metadata: FlashInferMetadata): # Regular attention (common case). 
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if self._num_prefills > 0:
+            if self._num_prefills > 0 and not is_cudnn_supported(
+                    attn_metadata.head_dim):
                 # Decodes are first so prefills start after the last decode
                 prefill_start = self._num_decodes
                 attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -433,6 +456,7 @@ def build(self, common_prefix_len: int,
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        max_query_len = common_attn_metadata.max_query_len
         slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
@@ -463,6 +487,7 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
+        max_seq_len = int(seq_lens.max().item())
         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
                              device=block_table_tensor.device).unsqueeze(0)
@@ -480,6 +505,10 @@ def build(self, common_prefix_len: int,
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
 
+        if is_cudnn_supported(self.kv_cache_spec.head_size):
+            self._get_workspace_buffer()
+            assert self._workspace_buffer is not None, "workspace_buffer is not set"
+
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
             qo_indptr=qo_indptr,
@@ -502,7 +531,13 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indptr=shared_kv_page_indptr,
             shared_kv_page_indices=shared_kv_page_indices,
             shared_kv_last_page_len=shared_kv_last_page_len,
-        )
+            max_query_len=max_query_len,
+            max_seq_len=max_seq_len,
+            actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
+            actual_seq_lens_kv=seq_lens.to(self.runner.device),
+            block_table=block_table_tensor,
+            cudnn_workspace=self._workspace_buffer
+            if is_cudnn_supported(self.kv_cache_spec.head_size) else None)
 
         self._plan(attn_metadata)
 
@@ -653,6 +688,7 @@ def forward(
             assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                         or 0.0)
             assert prefill_wrapper._sm_scale == self.scale
+
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache.permute(*stride_order),
@@ -660,6 +696,47 @@ def forward(
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
             )
+        elif num_prefill_tokens > 0 and is_cudnn_supported(
+                attn_metadata.head_dim):
+            (total_num_pages, _, page_size, num_kv_heads,
+             head_dim) = kv_cache.shape
+
+            # Validate dimensions match expected head_dim
+            assert head_dim == self.head_size, (
+                f"KV cache head_dim {head_dim} != expected {self.head_size}")
+            assert attn_metadata.block_table is not None, \
+                "block_table is not set"
+            k_cache = kv_cache[:, 0].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            v_cache = kv_cache[:, 1].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
+                q=query[num_decode_tokens:],
+                k_cache=k_cache,
+                v_cache=v_cache,
+                scale=self.scale,
+                workspace_buffer=attn_metadata.cudnn_workspace,
+                max_token_per_sequence=attn_metadata.max_query_len,
+                max_sequence_kv=attn_metadata.max_seq_len,
+                block_tables=attn_metadata.block_table[num_decode_tokens:],
+                actual_seq_lens_q=attn_metadata.
+ actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1), + actual_seq_lens_kv=attn_metadata. + actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1), + causal=True, + return_lse=True, + is_cuda_graph_compatible=True, + ) if decode_wrapper := attn_metadata.decode_wrapper: decode_query = query[:num_decode_tokens]
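For reviewers: below is a small, self-contained sketch (not part of the patch; all sizes and values are made up) of how the new FlashInferMetadata fields consumed by cudnn_batch_prefill_with_kv_cache are derived from the reordered batch, with decodes at the front and prefills at the back. Since each decode request contributes exactly one query token in this backend, num_decode_tokens also equals the number of decode requests, which is why the per-request tensors can be sliced with it in forward().

import torch

# Hypothetical reordered batch: 2 decode requests (1 token each) followed by
# 2 prefill requests with 5 and 3 new tokens. Values are illustrative only.
num_decode_tokens = 2                       # == number of decode requests here
query_lens = torch.tensor([1, 1, 5, 3])     # new (query) tokens per request
seq_lens = torch.tensor([17, 9, 5, 3])      # total KV length per request

# qo_indptr is the cumulative query offset (query_start_loc in the builder).
qo_indptr = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0)

# Fields added to FlashInferMetadata for the cuDNN prefill path.
actual_seq_lens_q = qo_indptr[1:] - qo_indptr[:-1]   # per-request query lengths
actual_seq_lens_kv = seq_lens
max_query_len = int(actual_seq_lens_q.max())         # 5
max_seq_len = int(seq_lens.max())                    # 17

# forward() slices off the decode segment and reshapes the per-request lengths
# to the (batch, 1, 1, 1) layout passed to cudnn_batch_prefill_with_kv_cache.
prefill_seq_lens_q = actual_seq_lens_q[num_decode_tokens:].view(-1, 1, 1, 1)
prefill_seq_lens_kv = actual_seq_lens_kv[num_decode_tokens:].view(-1, 1, 1, 1)
print(max_query_len, max_seq_len, tuple(prefill_seq_lens_q.shape))  # 5 17 (2, 1, 1, 1)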