[V1][CUDA] Full cudagraph support for FlashInfer #21367

Open · wants to merge 6 commits into base: main
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -144,7 +145,9 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
141 changes: 124 additions & 17 deletions vllm/v1/attention/backends/flashinfer.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -21,8 +21,8 @@
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -219,25 +219,51 @@


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._decode_wrapper = None # Wrapper for decode (general shape)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
max_num_blocks_per_request = cdiv(
vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
self.block_table_arange = torch.arange(max_num_blocks_per_request,
# Prepare persistent buffers
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)

self.block_table_arange = torch.arange(max_num_pages_per_req,
dtype=torch.int32,
device=self.device)

@@ -261,20 +287,49 @@
self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper

if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(

if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper

return decode_wrapper

def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
@@ -356,13 +411,24 @@
)

if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
pure_decode = num_prefills == 0
# padding may be required for cudagraph replay
use_cudagraph = (self.enable_cuda_graph and pure_decode and
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens_decode = (
self.vllm_config.pad_for_cudagraph(num_decodes))
else:
num_input_tokens_decode = num_decodes

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens_decode, use_cudagraph)
if not FlashInferBackend.use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
attn_metadata.decode_wrapper.plan(
decode_fast_plan(attn_metadata.decode_wrapper,

Check failure on line 431 (GitHub Actions / pre-commit): Unexpected keyword arguments "window_left", "sm_scale", "pos_encoding_mode" and too many arguments for "decode_fast_plan" [call-arg]
attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
@@ -384,7 +450,8 @@
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens

Check failure on line 454 (GitHub Actions / pre-commit): Ruff F841, local variable `num_reqs` is assigned to but never used
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)

@@ -430,6 +497,22 @@
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[:, :max_num_blocks][mask]
# num_actual_pages = paged_kv_indices.size(0)
# self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
# non_blocking=True)

# self.paged_kv_indices[num_actual_pages:].fill_(-1)

# self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
#                                           non_blocking=True)
# # make sure self.paged_kv_indptr is not decreasing
# self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])

# self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
#                                              non_blocking=True)
# # Fill the remaining paged_kv_last_page_len with 1. This is because
# # flashinfer treats 0 as a full page instead of empty.
# self.paged_kv_last_page_len[num_reqs:].fill_(1)

Check failure on lines 502, 503, 511, 512, 513 (GitHub Actions / pre-commit): Ruff E501, line too long (> 80 characters)


paged_kv_indptr_cpu = torch.zeros(len(block_table_bounds_cpu) + 1,
dtype=torch.int32,
@@ -480,6 +563,26 @@

return attn_metadata

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only
Reviewer (Contributor): Nit: You shouldn't need to set this. You can add it to your decode-only assert on the previous line.

Author: I think this is a common practice now (also see this part for FlashMLA), as the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.

return self.build(0, m)

def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
@@ -690,3 +793,7 @@
v_scale=layer._v_scale_float,
))
return output_padded

# TODO: implement a fast planning path for the cudagraph decode wrapper.
# This stub is a placeholder and does not yet accept the arguments passed
# at the call site above (hence the pre-commit call-arg failures).
def decode_fast_plan():
    pass
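Taken together, the FlashInfer changes above do two things: the builder keeps one cudagraph-compatible decode wrapper per padded decode batch size (each planned against fixed-size persistent buffers), and every pure-decode batch is padded up to the nearest capture size before replay. The standalone sketch below illustrates only that caching-plus-padding pattern under stated assumptions; `DummyDecodeWrapper`, `DecodeWrapperCache`, and the `capture_sizes` list are hypothetical stand-ins, not vLLM or FlashInfer APIs.

```python
import bisect
from typing import Dict, List


class DummyDecodeWrapper:
    """Hypothetical stand-in for a cudagraph-compatible decode wrapper that
    is planned once against fixed-size persistent buffers."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    def plan(self, indptr: List[int], last_page_len: List[int]) -> None:
        # A real wrapper would record the plan into its persistent buffers;
        # here we only check that the metadata matches the captured size.
        assert len(indptr) == self.batch_size + 1
        assert len(last_page_len) == self.batch_size


class DecodeWrapperCache:
    """One wrapper per padded decode batch size (a sketch of the idea,
    not the vLLM implementation)."""

    def __init__(self, capture_sizes: List[int]):
        self.capture_sizes = sorted(capture_sizes)
        self._wrappers: Dict[int, DummyDecodeWrapper] = {}

    def pad_for_cudagraph(self, num_decodes: int) -> int:
        # Round the decode batch up to the nearest captured size.
        idx = bisect.bisect_left(self.capture_sizes, num_decodes)
        return self.capture_sizes[idx]

    def get(self, num_decodes: int) -> DummyDecodeWrapper:
        padded = self.pad_for_cudagraph(num_decodes)
        if padded not in self._wrappers:
            self._wrappers[padded] = DummyDecodeWrapper(padded)
        return self._wrappers[padded]


if __name__ == "__main__":
    cache = DecodeWrapperCache(capture_sizes=[1, 2, 4, 8, 16])
    wrapper = cache.get(num_decodes=5)  # padded up to batch size 8
    # Pad the per-request metadata to the captured size as well; unused
    # slots get a last_page_len of 1 so the kernel sees valid entries.
    indptr = [0, 1, 2, 3, 4, 5, 5, 5, 5]
    last_page_len = [1] * 8
    wrapper.plan(indptr, last_page_len)
    print(wrapper.batch_size)  # 8
```

The padded plan keeps graph shapes stable across replays, at the cost of a little wasted work for the padded slots.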
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
18 changes: 17 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass
@@ -63,9 +64,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
Reviewer (Contributor): Nit: How do you feel about changing the name to CudaGraphSupportLevel or CudaGraphCompatibility?

Author: That would generally be a better name if other operators may also not support cudagraphs; in that case we could move it out of attention. I think we can keep the name for now, and this part also serves as a stepping stone for the upcoming #20059.

""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""

NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to run without
cudagraph for mixed prefill-decode batches"""
ALWAYS = 2
"""Cudagraph always supported"""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
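As a reference for how the new flag is meant to be consumed, here is a minimal standalone sketch. The enum mirrors the one added above; `ExampleFA2Builder`, `ExampleFlashInferBuilder`, and `check_full_cudagraph_support` are illustrative names only, not vLLM APIs.

```python
import enum
from typing import ClassVar


class AttentionCGSupport(enum.Enum):
    """Mirrors the enum added in vllm/v1/attention/backends/utils.py."""
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


class ExampleFA2Builder:
    # A backend that cannot run attention inside a cudagraph at all.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER


class ExampleFlashInferBuilder:
    # FlashInfer-style backend: cudagraphs only for pure-decode batches.
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY


def check_full_cudagraph_support(builder_cls) -> None:
    """Raise if full cudagraphs were requested for an unsupported backend,
    mirroring the check added to gpu_model_runner.py below."""
    if builder_cls.attn_cudagraph_support == AttentionCGSupport.NEVER:
        raise ValueError(
            f"Full CUDAGraph not supported for {builder_cls.__name__}. "
            "Turn off CompilationConfig.full_cuda_graph or use a "
            "different attention backend.")


check_full_cudagraph_support(ExampleFlashInferBuilder)  # passes silently
try:
    check_full_cudagraph_support(ExampleFA2Builder)
except ValueError as e:
    print(e)
```

Declaring the level as a ClassVar keeps the check cheap and lets it run once per backend at initialization time.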
22 changes: 15 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -47,7 +47,7 @@
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2527,12 +2527,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
self.device,
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
Reviewer (Contributor): Nit: Could you add a comment here explaining what's going on and why it's necessary? Something like # Limit the max cudagraph size to the max number of sequences.

AttentionCGSupport.PURE_DECODE_ONLY:
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]

self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
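To make the effect of the capture-size trimming concrete, here is a small standalone example; the capture list and scheduler limit are made-up numbers for illustration.

```python
# Hypothetical compilation capture list and scheduler limit.
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_seqs = 128

# For a PURE_DECODE_ONLY backend, every token in a captured batch is a decode
# token, so the captured batch size can never exceed max_num_seqs. Larger
# sizes are dropped before capture, exactly as in the diff above.
cudagraph_batch_sizes = [
    size for size in cudagraph_batch_sizes if size <= max_num_seqs
]
print(cudagraph_batch_sizes)  # [1, 2, 4, 8, 16, 32, 64, 128]
```

Dropping the oversized entries avoids capturing graphs that a decode-only batch could never fill.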
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -300,11 +300,16 @@ def compile_or_warm_up_model(self) -> None:
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
# enable building attn_metadata for this dummy run to avoid
# potential illegal memory access for full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\
Reviewer (Contributor): I'm not sure I understand why you need this. As I understand it, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the GPUModelRunner? I see there's some discussion about a hang when you don't pass attention metadata into the dummy_run?

Author (@fhl2000, Jul 24, 2025): Hey @SageMoore, thank you for the questions! I don't think the two are related. To explain further: this line of code runs after all cudagraph shapes from the modified list have been captured in gpu_model_runner, and this dummy_run with num_tokens=max_num_reqs is <= the max captured size of that list. Recall that a dummy_run for attn_cudagraph_support=PURE_DECODE_ONLY only attempts pure-decode batches, so it only replays a captured decode graph when it lands exactly on a size from the list; otherwise no cudagraph is used. However, when it does hit a replay, FlashInfer may get trapped in an infinite loop if the contents of the persistent buffers are incorrect.

not self.model_config.enforce_eager

# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(
num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True,
)
if self.model_runner.is_pooling_model:
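The discussion above boils down to a simple decision: when full cudagraphs are enabled (and eager mode is not forced), the final warm-up dummy run must build real attention metadata, because it may replay a captured decode graph and must not do so with uninitialized persistent planning buffers. The sketch below only restates that logic under stated assumptions; `WarmupConfig` and `warmup_plan` are hypothetical names, not vLLM classes.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WarmupConfig:
    """Hypothetical condensed view of the flags involved in the final
    warm-up dummy run (not a vLLM class)."""
    full_cuda_graph: bool
    enforce_eager: bool
    capture_sizes: List[int]
    max_num_reqs: int


def warmup_plan(cfg: WarmupConfig) -> dict:
    # Mirrors the condition added in gpu_worker.py: build real attention
    # metadata for the warm-up run whenever full cudagraphs are enabled.
    build_attn_metadata = cfg.full_cuda_graph and not cfg.enforce_eager
    # The warm-up run only replays a captured decode graph when its batch
    # size lands on a captured size; replaying with stale persistent buffers
    # is what the author reports can hang FlashInfer.
    replays_captured_graph = (build_attn_metadata
                              and cfg.max_num_reqs in cfg.capture_sizes)
    return {
        "build_attn_metadata": build_attn_metadata,
        "replays_captured_graph": replays_captured_graph,
    }


print(warmup_plan(WarmupConfig(True, False, [1, 2, 4, 8, 16, 32], 32)))
# -> {'build_attn_metadata': True, 'replays_captured_graph': True}
```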