diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index f922e6e4c9e..c5bed7d2cde 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -7,12 +7,13 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+
+import vllm.envs as envs
 from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
                         BatchPrefillWithPagedKVCacheWrapper,
                         MultiLevelCascadeAttentionWrapper)
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache
-
-import vllm.envs as envs
+from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.logger import init_logger
@@ -36,6 +37,13 @@
 
 logger = init_logger(__name__)
 
+CUDNN_SUPPORTED_HEAD_SIZES = [128]
+
+
+def is_cudnn_supported(head_dim: int):
+    return head_dim in CUDNN_SUPPORTED_HEAD_SIZES \
+        and current_platform.has_device_capability(100)
+
 
 class FlashInferBackend(AttentionBackend):
 
@@ -203,6 +211,10 @@ class FlashInferMetadata:
     num_prefills: int
     num_prefill_tokens: int
 
+    # For cudnn prefill
+    max_query_len: int
+    actual_seq_lens_q: torch.Tensor
+
     # For cascade attention.
     use_cascade: bool
     shared_qo_indptr: Optional[torch.Tensor] = None
@@ -302,9 +314,13 @@ def reorder_batch(self, input_batch: InputBatch,
 
     def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
+            if is_cudnn_supported(self.kv_cache_spec.head_size):
+                dtype = torch.int8
+            else:
+                dtype = torch.uint8
             self._workspace_buffer = torch.empty(
                 FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                dtype=torch.uint8,
+                dtype=dtype,
                 device=self.runner.device)
         return self._workspace_buffer
 
@@ -369,7 +385,8 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
-            if self._num_prefills > 0:
+            if self._num_prefills > 0 and not is_cudnn_supported(
+                    attn_metadata.head_dim):
                 # Decodes are first so prefills start after the last decode
                 prefill_start = self._num_decodes
                 attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
@@ -441,6 +458,7 @@ def build(self, common_prefix_len: int,
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        max_query_len = common_attn_metadata.max_query_len
 
         slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
@@ -471,6 +489,7 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
+        max_seq_len = int(seq_lens.max().item())
         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
                              device=block_table_tensor.device).unsqueeze(0)
@@ -487,6 +506,7 @@ def build(self, common_prefix_len: int,
         paged_kv_last_page_len = seq_lens % page_size
         paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                              page_size, paged_kv_last_page_len)
+
         cache_dtype = self.runner.cache_config.cache_dtype
         if cache_dtype.startswith("fp8"):
             kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -515,7 +535,9 @@ def build(self, common_prefix_len: int,
             shared_kv_page_indptr=shared_kv_page_indptr,
             shared_kv_page_indices=shared_kv_page_indices,
             shared_kv_last_page_len=shared_kv_last_page_len,
+            max_query_len=max_query_len,
             max_seq_len=max_seq_len,
+            actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1],
             seq_lens=seq_lens,
             block_table_tensor=block_table_tensor,
             workspace_buffer=self._workspace_buffer,
@@ -681,6 +703,7 @@ def forward(
             assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap
                                                         or 0.0)
             assert prefill_wrapper._sm_scale == self.scale
+
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache.permute(*stride_order),
@@ -688,6 +711,48 @@ def forward(
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
             )
+        elif num_prefill_tokens > 0 and is_cudnn_supported(
+                attn_metadata.head_dim):
+            (total_num_pages, _, page_size, num_kv_heads,
+             head_dim) = kv_cache.shape
+
+            # Validate dimensions match expected head_dim
+            assert head_dim == self.head_size, (
+                f"KV cache head_dim {head_dim} != expected {self.head_size}")
+
+            k_cache = kv_cache[:, 0].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            v_cache = kv_cache[:, 1].as_strided(
+                (total_num_pages, num_kv_heads, page_size, head_dim), (
+                    page_size * num_kv_heads * head_dim,
+                    head_dim,
+                    num_kv_heads * head_dim,
+                    1,
+                ))
+            output[num_decode_tokens:], _ = cudnn_batch_prefill_with_kv_cache(
+                q=query[num_decode_tokens:],
+                k_cache=k_cache,
+                v_cache=v_cache,
+                scale=self.scale,
+                workspace_buffer=attn_metadata.workspace_buffer,
+                max_token_per_sequence=attn_metadata.max_query_len,
+                max_sequence_kv=attn_metadata.max_seq_len,
+                block_tables=attn_metadata.block_table_tensor[
+                    num_decode_tokens:],
+                actual_seq_lens_q=attn_metadata.actual_seq_lens_q[
+                    num_decode_tokens:].view(-1, 1, 1, 1),
+                actual_seq_lens_kv=attn_metadata.seq_lens[
+                    num_decode_tokens:].view(-1, 1, 1, 1),
+                causal=True,
+                return_lse=True,
+                is_cuda_graph_compatible=True,
+            )
+
         if decode_wrapper := attn_metadata.decode_wrapper:
             decode_query = query[:num_decode_tokens]
             assert decode_query.shape[0] == num_decode_tokens
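Not part of the diff: a minimal sketch of how the metadata added above is meant to be consumed by the cuDNN prefill path. `actual_seq_lens_q` is the per-request query length recovered from `qo_indptr`, and because `reorder_batch()` places the single-token decode requests at the front of the batch, `num_decode_tokens` also equals the number of decode requests, which is why per-request tensors can be sliced with `[num_decode_tokens:]` in `forward()`. The batch sizes below are made up for illustration.

```python
import torch

# Hypothetical reordered batch: 3 decode requests (1 token each) first,
# then 2 prefill requests with 4 and 7 new tokens respectively.
num_decodes = 3
query_lens = torch.tensor([1, 1, 1, 4, 7], dtype=torch.int32)

# qo_indptr is the exclusive prefix sum of per-request query lengths:
# [0, 1, 2, 3, 7, 14].
qo_indptr = torch.zeros(query_lens.numel() + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(query_lens, dim=0)

# Mirrors `actual_seq_lens_q=qo_indptr[1:] - qo_indptr[:-1]` in build().
actual_seq_lens_q = qo_indptr[1:] - qo_indptr[:-1]
assert torch.equal(actual_seq_lens_q, query_lens)

# Each decode contributes exactly one token, so the token offset
# num_decode_tokens coincides with the number of decode requests and can
# slice per-request tensors such as actual_seq_lens_q and seq_lens.
num_decode_tokens = num_decodes
prefill_seq_lens_q = actual_seq_lens_q[num_decode_tokens:]

# The cuDNN prefill call in forward() reshapes these per-request lengths
# to a (batch, 1, 1, 1) layout via .view(-1, 1, 1, 1).
print(prefill_seq_lens_q.view(-1, 1, 1, 1).shape)  # torch.Size([2, 1, 1, 1])
```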