[V1][CUDA] Full cudagraph support for FlashInfer #21367

Open · wants to merge 6 commits into base: main

Changes from 3 commits
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -144,7 +145,9 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
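Switching from the boolean `full_cudagraph_supported` flag to an `AttentionCGSupport` level (defined later in `utils.py`) lets each backend declare how much full-cudagraph support it provides. Below is a minimal, self-contained sketch of that pattern, assuming a simplified `get_flash_attn_version` stub and a hypothetical consumer check; it is an illustration, not vLLM's actual wiring.

```python
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    """Level of full-cudagraph support an attention backend declares."""
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def get_flash_attn_version() -> int:
    # Stand-in for vLLM's real version probe; assume FA3 here.
    return 3


class FlashAttentionMetadataBuilder:
    # FA2 cannot run under a full cudagraph; FA3 always can.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.NEVER if get_flash_attn_version() == 2
        else AttentionCGSupport.ALWAYS)


# A consumer now checks the support level instead of a plain boolean.
if FlashAttentionMetadataBuilder.attn_cudagraph_support != AttentionCGSupport.NEVER:
    print("full cudagraph capture is (at least partially) supported")
```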
147 changes: 129 additions & 18 deletions vllm/v1/attention/backends/flashinfer.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -18,10 +18,11 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -223,21 +224,49 @@ def __post_init__(self):


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._decode_wrapper = None # Wrapper for decode (general shape)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
# Preparing persistent buffers
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
@@ -259,20 +288,49 @@ def _get_prefill_wrapper(self):
self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper

if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(

if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper

return decode_wrapper

def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
@@ -350,16 +408,34 @@ def _plan(self, num_prefills: int, num_decodes: int,
)

if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
pure_decode = num_prefills == 0
# Padding may be required for cudagraph replay
use_cudagraph = (self.enable_cuda_graph and pure_decode and \
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens_decode = (
self.vllm_config.pad_for_cudagraph(num_decodes))
else:
num_input_tokens_decode = num_decodes

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens_decode, use_cudagraph)
if not FlashInferBackend.use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
# TODO: Override flashinfer's plan function to avoid some
# host-to-device copy overhead.
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:num_decodes],
# NOTE: Use the persistent buffers with padded length,
# instead of the chunked-length buffers in attn_metadata
# (which share the same addresses). This is required for
# compatibility with FlashInfer's decode_wrapper when
# using cudagraph.
self.paged_kv_indptr[:num_input_tokens_decode + 1],
self.paged_kv_indices if use_cudagraph else \
attn_metadata.paged_kv_indices,
self.paged_kv_last_page_len[:num_input_tokens_decode],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -378,6 +454,7 @@ def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
@@ -421,17 +498,31 @@ def build(self,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)

paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
non_blocking=True)
# make sure self.paged_kv_indptr is not decreasing
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
non_blocking=True)
# Fill the remaining paged_kv_last_page_len with 1. This is
# because flashinfer treats 0 as a full page instead of an empty one.
self.paged_kv_last_page_len[num_reqs:].fill_(1)

cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -441,9 +532,9 @@ def build(self,
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -471,6 +562,26 @@ def build(self,

return attn_metadata
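The build() changes above stage the per-batch indexing tensors into preallocated persistent buffers and pad their tails with harmless values, so a captured graph can replay against fixed device addresses. Here is a standalone sketch of that padding scheme, using made-up sizes and plain CPU tensors rather than the builder's CUDA buffers:

```python
import torch

MAX_NUM_REQS, MAX_NUM_PAGES = 8, 64

# Persistent buffers allocated once (device="cuda" in the real builder).
paged_kv_indptr = torch.zeros(MAX_NUM_REQS + 1, dtype=torch.int32)
paged_kv_indices = torch.zeros(MAX_NUM_PAGES, dtype=torch.int32)
paged_kv_last_page_len = torch.zeros(MAX_NUM_REQS, dtype=torch.int32)


def fill_persistent_buffers(indptr: torch.Tensor, indices: torch.Tensor,
                            last_page_len: torch.Tensor) -> None:
    num_reqs = indptr.numel() - 1
    num_pages = indices.numel()

    paged_kv_indices[:num_pages].copy_(indices)
    paged_kv_indices[num_pages:].fill_(-1)            # unused page slots

    paged_kv_indptr[:num_reqs + 1].copy_(indptr)
    paged_kv_indptr[num_reqs + 1:].fill_(indptr[-1])  # keep indptr non-decreasing

    paged_kv_last_page_len[:num_reqs].copy_(last_page_len)
    # FlashInfer treats 0 as a full page, so pad the tail with 1 instead of 0.
    paged_kv_last_page_len[num_reqs:].fill_(1)


# Two requests using 3 pages in total.
fill_persistent_buffers(torch.tensor([0, 2, 3], dtype=torch.int32),
                        torch.tensor([10, 11, 12], dtype=torch.int32),
                        torch.tensor([5, 7], dtype=torch.int32))
print(paged_kv_indptr.tolist())  # [0, 2, 3, 3, 3, 3, 3, 3, 3]
```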

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only
Contributor:
Nit: You shouldn't need to set this. You can add it to your decode_only assert on the previous line.

Contributor Author:

I think this is a common practice now (also see this part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.


return self.build(0, m)

def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
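To capture full cudagraphs, the builder keeps one decode wrapper per padded batch size (`_decode_wrappers_cudagraph`) and pads `num_decodes` up to the nearest capture size before planning. The sketch below illustrates just that caching-and-padding idea with toy stand-ins; `CAPTURE_SIZES`, `pad_for_cudagraph`, and `ToyDecodeWrapper` are illustrative, not the FlashInfer or vLLM APIs.

```python
from dataclasses import dataclass, field
from typing import Optional

CAPTURE_SIZES = [1, 2, 4, 8, 16]  # assumed cudagraph capture sizes


def pad_for_cudagraph(batch_size: int) -> int:
    """Round a decode batch size up to the nearest capture size."""
    for size in CAPTURE_SIZES:
        if batch_size <= size:
            return size
    return batch_size  # too large: handled outside the cudagraph path


@dataclass
class ToyDecodeWrapper:
    """Stand-in for BatchDecodeWithPagedKVCacheWrapper."""
    batch_size: int
    use_cuda_graph: bool


@dataclass
class DecodeWrapperCache:
    max_capture_size: int = CAPTURE_SIZES[-1]
    _cudagraph_wrappers: dict[int, ToyDecodeWrapper] = field(default_factory=dict)
    _general_wrapper: Optional[ToyDecodeWrapper] = None

    def get(self, num_decodes: int, pure_decode: bool) -> ToyDecodeWrapper:
        use_cudagraph = pure_decode and num_decodes <= self.max_capture_size
        if not use_cudagraph:
            # One general-shape wrapper serves all non-captured batches.
            if self._general_wrapper is None:
                self._general_wrapper = ToyDecodeWrapper(-1, False)
            return self._general_wrapper
        padded = pad_for_cudagraph(num_decodes)
        # One wrapper per padded batch size, created lazily and reused on replay.
        if padded not in self._cudagraph_wrappers:
            self._cudagraph_wrappers[padded] = ToyDecodeWrapper(padded, True)
        return self._cudagraph_wrappers[padded]


cache = DecodeWrapperCache()
print(cache.get(num_decodes=3, pure_decode=True))   # padded to batch size 4
print(cache.get(num_decodes=40, pure_decode=True))  # exceeds capture sizes -> general wrapper
```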
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
18 changes: 17 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
Contributor:
Nit: How do you feel about changing the name to CudaGraphSupportLevel or CudaGraphCompatibility?

Contributor Author:
That would generally be a better name if we had other operators that may also not support cudagraph; in that case, we could move it out of attention. I think we can keep the name for now, and this part also serves as a stepping stone for the upcoming #20059.

""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""

NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to use piecewise
cudagraph or no cudagraph for mixed prefill-decode batches"""
Member:
Regarding "need to use piecewise cudagraph or no cudagraph for mixed prefill-decode batches": in the mixed prefill-decode case, when will piecewise cudagraph be used and when will no cudagraph be used?

Contributor Author:
I am sorry, this comment is taken directly from #20059; I didn't realize that no piecewise cudagraph exists in this PR when enabling full_cuda_graph. That is only allowed after that PR, which introduces a new cudagraph_mode config that decouples the cudagraph logic from vllm compilation. In that PR, when cudagraph_mode is FULL, it runs a full cudagraph for pure decode if the attention support is PURE_DECODE_ONLY, and falls back to piecewise cudagraph for other situations that are incompatible with full cudagraph when vllm compilation is on. However, if vllm compilation is disabled, it simply falls back to no cudagraph.

Contributor Author:
It's fixed now

ALWAYS = 2
"""Cudagraph always supported"""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
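For context on how this enum is expected to be consumed (per the review discussion above about #20059, not code in this PR), here is a rough, assumption-laden sketch of the per-batch dispatch a model runner could perform; `cudagraph_mode`, `compilation_enabled`, and the returned labels are all hypothetical:

```python
import enum


class AttentionCGSupport(enum.Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def choose_cudagraph_runtime(cudagraph_mode: str,
                             support: AttentionCGSupport,
                             pure_decode: bool,
                             compilation_enabled: bool) -> str:
    """Pick how to run one batch; mirrors the behavior described in the review."""
    if cudagraph_mode != "FULL" or support is AttentionCGSupport.NEVER:
        return "piecewise" if compilation_enabled else "eager"
    if support is AttentionCGSupport.ALWAYS:
        return "full"
    # PURE_DECODE_ONLY: full graphs only for decode-only batches ...
    if pure_decode:
        return "full"
    # ... otherwise fall back to piecewise when compilation is on, else eager.
    return "piecewise" if compilation_enabled else "eager"


assert choose_cudagraph_runtime("FULL", AttentionCGSupport.PURE_DECODE_ONLY,
                                pure_decode=True, compilation_enabled=True) == "full"
assert choose_cudagraph_runtime("FULL", AttentionCGSupport.PURE_DECODE_ONLY,
                                pure_decode=False, compilation_enabled=False) == "eager"
```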