[V1][CUDA] Full cudagraph support for FlashInfer #21367

Status: Open. Wants to merge 6 commits into base: main. Showing changes from 2 commits.
7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -144,7 +145,9 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
147 changes: 129 additions & 18 deletions vllm/v1/attention/backends/flashinfer.py
@@ -4,7 +4,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional

import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
@@ -18,10 +18,11 @@
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -223,21 +224,49 @@ def __post_init__(self):


class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
self._decode_wrapper = None # Wrapper for decode (general shape)

self.compilation_config = vllm_config.compilation_config
max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size)
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req
self.enable_cuda_graph = self.compilation_config.full_cuda_graph
if self.enable_cuda_graph:
# For full cudagraph capture, one `decode_wrapper` for each batch
# size is needed for FlashInfer.
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper] = {}
self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size)

self._cascade_wrapper = None # Wrapper for cascade attention

# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
# Preparing persistent buffers
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.paged_kv_indices = torch.zeros(
max_num_pages, # max num pages possible
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=self.device)

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
@@ -259,20 +288,49 @@ def _get_prefill_wrapper(self):
self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
if self._decode_wrapper is None:
def _get_decode_wrapper(self,
batch_size: int,
use_cudagraph: bool = False):
if use_cudagraph:
decode_wrapper = self._decode_wrappers_cudagraph.get(
batch_size, None)
else:
decode_wrapper = self._decode_wrapper

if decode_wrapper is None:
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(

if use_cudagraph:
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
paged_kv_indices = self.paged_kv_indices
paged_kv_last_page_len = self.paged_kv_last_page_len[:
batch_size]
else:
paged_kv_indptr = None
paged_kv_indices = None
paged_kv_last_page_len = None
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
get_kv_cache_layout(),
use_cuda_graph=use_cudagraph,
paged_kv_indptr_buffer=paged_kv_indptr,
paged_kv_indices_buffer=paged_kv_indices,
paged_kv_last_page_len_buffer=paged_kv_last_page_len,
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

# save the decode wrapper
if use_cudagraph:
self._decode_wrappers_cudagraph[batch_size] = decode_wrapper
else:
self._decode_wrapper = decode_wrapper

return decode_wrapper

def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
@@ -350,16 +408,34 @@ def _plan(self, num_prefills: int, num_decodes: int,
)

if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
pure_decode = num_prefills == 0
# Padding may be required for cudagraph replay.
use_cudagraph = (self.enable_cuda_graph and pure_decode and \
num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph:
num_input_tokens_decode = (
self.vllm_config.pad_for_cudagraph(num_decodes))
else:
num_input_tokens_decode = num_decodes

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens_decode, use_cudagraph)
if not FlashInferBackend.use_trtllm_decode_attention(
num_decodes, attn_metadata.max_seq_len,
self.cache_config.cache_dtype,
attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
attn_metadata.head_dim):
# TODO: Override flashinfer's plan function to avoid some
# host-to-device copy overhead.
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:num_decodes],
# NOTE: Use the persistent buffers with the padded length,
# instead of the same-address but chunked-length buffers
# in the attn_metadata. This is for compatibility with
# FlashInfer's decode_wrapper when using cudagraph.
self.paged_kv_indptr[:num_input_tokens_decode + 1],
self.paged_kv_indices if use_cudagraph else \
attn_metadata.paged_kv_indices,
self.paged_kv_last_page_len[:num_input_tokens_decode],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -378,6 +454,7 @@ def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
@@ -421,17 +498,31 @@ def build(self,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
num_actual_pages = paged_kv_indices.size(0)
self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)

paged_kv_indptr = torch.cat([
torch.zeros(1,
dtype=block_table_bounds.dtype,
device=block_table_bounds.device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32)
])
self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
non_blocking=True)
# make sure self.paged_kv_indptr is not decreasing
self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len,
non_blocking=True)
# Fill the remaining paged_kv_last_page_len with 1. This is because
# flashinfer treats 0 as a full page instead of empty.
self.paged_kv_last_page_len[num_reqs:].fill_(1)

cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
@@ -441,9 +532,9 @@ def build(self,
attn_metadata = FlashInferMetadata(
num_actual_tokens=num_actual_tokens,
qo_indptr=qo_indptr,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs],
paged_kv_indices=self.paged_kv_indices[:num_actual_pages],
paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs],
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
@@ -471,6 +562,26 @@ def build(self,

return attn_metadata

def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata

assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."

m.max_query_len = 1 # decode-only
Reviewer (Contributor): Nit: You shouldn't need to set this. You can add it to your decode-only assert on the previous line.

Author: I think this is common practice now (also see this part for FlashMLA), since the attn_metadata passed from the dummy run currently has max_query_len=num_tokens.


return self.build(0, m)

def can_run_in_cudagraph(
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == 1

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
16 changes: 15 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass
@@ -62,10 +63,23 @@ class CommonAttentionMetadata:

M = TypeVar("M")

class AttentionCGSupport(enum.Enum):
Reviewer (Contributor): Nit: How do you feel about changing the name to CudaGraphSupportLevel or CudaGraphCompatibility?

Author: That would generally be a better name if other operators may also lack cudagraph support; in that case, we could move it out of attention. I think we can keep the name for now, and this part also serves as a stepping stone for the upcoming #20059.

""" Constants for the cudagraph support of the attention backend
Here we do not consider the cascade attention, as currently
it is never cudagraph supported."""

NEVER = 0
"""NO cudagraph support"""
PURE_DECODE_ONLY = 1
"""Cudagraph supported for pure decode, need to use piecewise
cudagraph or no cudagraph for mixed prefill-decode batches"""
Reviewer (Member): Regarding "need to use piecewise cudagraph or no cudagraph for mixed prefill-decode batches": in the mixed prefill-decode case, when will piecewise cudagraph be used and when will no cudagraph be used?

Author: Sorry, this comment comes directly from #20059; I didn't realize that no piecewise cudagraph exists in this PR when enabling full_cuda_graph. That is only possible after that PR, which introduces a new cudagraph_mode config option that decouples the cudagraph logic from vLLM compilation. In that PR, when cudagraph_mode is FULL, it runs a full cudagraph for pure decode if the attention support is PURE_DECODE_ONLY, and falls back to piecewise cudagraph for other situations that are incompatible with full cudagraph when vLLM compilation is on. If vLLM compilation is disabled, it falls back to no cudagraph at all.

Author: It's fixed now.
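For readers following along, here is a minimal, self-contained sketch of the dispatch the author describes for the upcoming #20059 (the function name, string mode values, and runtime labels below are illustrative assumptions, not vLLM's actual code):

```python
from enum import Enum


class AttentionCGSupport(Enum):
    NEVER = 0
    PURE_DECODE_ONLY = 1
    ALWAYS = 2


def choose_cudagraph_runtime(cudagraph_mode: str,
                             support: AttentionCGSupport,
                             pure_decode: bool,
                             compilation_enabled: bool) -> str:
    """Illustrative only: how a runner could pick an execution mode."""
    if cudagraph_mode != "FULL" or support is AttentionCGSupport.NEVER:
        return "piecewise" if compilation_enabled else "eager"
    if support is AttentionCGSupport.PURE_DECODE_ONLY and not pure_decode:
        # Mixed prefill-decode batch: full cudagraph is not possible, so
        # fall back to piecewise cudagraph (if compilation is on) or eager.
        return "piecewise" if compilation_enabled else "eager"
    return "full_cudagraph"


# A mixed batch on a PURE_DECODE_ONLY backend falls back to piecewise.
assert choose_cudagraph_runtime("FULL", AttentionCGSupport.PURE_DECODE_ONLY,
                                pure_decode=False,
                                compilation_enabled=True) == "piecewise"
```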

ALWAYS = 2
"""Cudagraph always supported"""

class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
14 changes: 12 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -47,7 +47,7 @@
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -2369,6 +2369,15 @@
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
full_cg = self.full_cuda_graph
# For full cudagraph on pure decode only, do not capture sizes larger
# than max_num_seqs.
if full_cg and self.attn_metadata_builders[0].attn_cudagraph_support\

[CI annotation] GitHub Actions / pre-commit: Ruff (E501) at vllm/v1/worker/gpu_model_runner.py:2374:81: Line too long (81 > 80)
== AttentionCGSupport.PURE_DECODE_ONLY:
max_num_seqs = self.scheduler_config.max_num_seqs
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= max_num_seqs]
Reviewer (Contributor): critical

This change correctly filters self.cudagraph_batch_sizes to prevent capturing graphs for sizes larger than max_num_seqs for PURE_DECODE_ONLY backends. However, the pad_for_cudagraph method, which is used at runtime to determine the padded graph size, relies on a mapping (bs_to_padded_graph_size) that was initialized with the original, unfiltered cudagraph_batch_sizes.

This discrepancy can lead to a KeyError at runtime. For example, if a batch with num_decodes is processed, pad_for_cudagraph might return a padded size that was filtered out and for which no CUDA graph was captured. This will cause a lookup failure in _decode_wrappers_cudagraph.

To fix this, you should re-initialize the padding map after filtering self.cudagraph_batch_sizes.

Suggested change, replacing

self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs]

with

self.cudagraph_batch_sizes = [
    size for size in self.cudagraph_batch_sizes
    if size <= max_num_seqs
]
self.vllm_config.compilation_config.init_with_cudagraph_sizes(
    self.cudagraph_batch_sizes)
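A small, self-contained sketch of the mismatch described in this comment, assuming a padding map built from the unfiltered capture sizes (the names below are illustrative stand-ins, not vLLM's actual pad_for_cudagraph or bs_to_padded_graph_size implementation):

```python
import bisect

# Sizes for which CUDA graphs were actually captured after filtering by
# max_num_seqs, versus the unfiltered list the padding map was built from.
captured_sizes = [1, 2, 4, 8]
unfiltered_sizes = [1, 2, 4, 8, 16, 32]


def pad_for_cudagraph(batch_size: int) -> int:
    # Stand-in for a padding map built from the *unfiltered* sizes.
    return unfiltered_sizes[bisect.bisect_left(unfiltered_sizes, batch_size)]


decode_wrappers = {bs: f"wrapper_bs{bs}" for bs in captured_sizes}

num_decodes = 9
padded = pad_for_cudagraph(num_decodes)  # -> 16, a size that was filtered out
try:
    decode_wrappers[padded]  # no graph/wrapper was captured for this size
except KeyError:
    print(f"KeyError: no captured decode wrapper for batch size {padded}")
```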

Author (fhl2000, Jul 22, 2025): Good catch for pad_for_cudagraph, though I think it would not affect the final correctness.

Reviewer (Member): I find that the following hangs:

VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve models/Llama-3.1-8B-Instruct --no-enable-prefix-caching --compilation-config='{"full_cuda_graph": true}' --max-num-seqs 2

at this point:

Capturing CUDA graph shapes: 100%| 2/2 [00:00<00:00, 2.42it/s]
INFO 07-22 10:32:54 [gpu_model_runner.py:2404] Graph capturing finished in 1 secs, took 0.42 GiB

It just hangs here. Could this be related?

Reviewer (Member): If I remove the --max-num-seqs then it works fine, so I think it is indeed related.

Author: I find that overriding the GPU model runner's cudagraph_batch_sizes is enough. Also, the vllm_config.compilation_config.init_with_cudagraph_sizes method does not actually override the compilation config's cudagraph_batch_sizes after its first call.

Reviewer (Member): Hmm, unfortunately it doesn't happen on main + FlashAttention. The hang is 100% reproducible using the code from this PR.

Author: I will try my best to figure it out.

Author (fhl2000, Jul 22, 2025): I have tested --max-num-seqs values of [2, 4, 8, 16, 24, 32, 40], which all lead to hangs, while [1, 48, 56, ...] work normally. The hang occurs in a final dummy_run after all capturing, in gpu_worker.py around lines 285-292, which runs into a cudagraph replay (num_tokens = max_num_seqs) without creating attn_metadata. I guess something weird is happening in FlashInfer.

Author: @tdoublep Take the new fix! It should be fine now. Could you please also test whether it works for you?

Reviewer (Member): Yes, it seems to work now. Thanks!
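As an illustration of the diagnosis above, a self-contained sketch of the kind of guard that avoids replaying a full cudagraph for a shape whose attention metadata was never planned (hypothetical runner logic, not the actual fix applied in this PR):

```python
def dummy_run(num_tokens: int,
              captured_graphs: dict[int, str],
              planned_sizes: set[int]) -> str:
    """Replay a captured full cudagraph only when attention metadata has been
    planned for this exact padded size; otherwise run eagerly, so a replay
    never consumes decode-wrapper buffers that were never filled."""
    if num_tokens in captured_graphs and num_tokens in planned_sizes:
        return f"replay {captured_graphs[num_tokens]}"
    return "eager run (skip full-cudagraph replay)"


graphs = {1: "g1", 2: "g2"}
print(dummy_run(2, graphs, planned_sizes={1, 2}))  # safe: metadata was planned
print(dummy_run(2, graphs, planned_sizes=set()))   # falls back, avoiding a hang
```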


# Only rank 0 should print progress bar during capture
compilation_cases = reversed(self.cudagraph_batch_sizes)
if is_global_first_rank():
@@ -2438,7 +2447,8 @@
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
and attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."