
[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059

Open. Wants to merge 47 commits into base: main.

Commits (47); the diffs below show changes from 2 of these commits.
92b1733
FA2 and FlashInfer Full cuda graph support
fhl2000 Jun 25, 2025
58ce477
fix the arch support in CMakeLists.txt to include 8.9
fhl2000 Jun 25, 2025
c2c5fea
Refactors
fhl2000 Jun 25, 2025
1606880
refactors
fhl2000 Jun 25, 2025
806432a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 25, 2025
7c5df45
refactor
fhl2000 Jun 25, 2025
c7a9424
Add check for separate_attention_routine flag
fhl2000 Jun 25, 2025
e8b9296
fix typo error
fhl2000 Jun 26, 2025
94d0b79
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 27, 2025
a67c698
refactors and rearchitect cuda graph logic
fhl2000 Jun 28, 2025
da110af
Refactors
fhl2000 Jun 28, 2025
deaf0fe
Delect one commit
fhl2000 Jun 28, 2025
02ca154
Add support for force_no_split_graph
fhl2000 Jun 28, 2025
fa0d25c
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 1, 2025
5108bef
Huge refactors to separete cudagraph logic from vllm compilation
fhl2000 Jul 5, 2025
1c1873d
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 5, 2025
7d4667a
refactors
fhl2000 Jul 5, 2025
fedff47
fix errors
fhl2000 Jul 5, 2025
833ac56
fix small error by lazy import
fhl2000 Jul 5, 2025
d57257d
handle lint-and-deploy errors for cpu execution
fhl2000 Jul 5, 2025
8b7ea7a
remove redundents
fhl2000 Jul 5, 2025
328615d
Clear
fhl2000 Jul 6, 2025
debc682
Big refactors
fhl2000 Jul 9, 2025
cad6c39
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 9, 2025
dc455ee
cleanup
fhl2000 Jul 10, 2025
620a728
fix warmup
fhl2000 Jul 10, 2025
b1e6978
Commit suggestion: Update vllm/config.py
fhl2000 Jul 10, 2025
beee69a
commit suggestion2: Update vllm/config.py
fhl2000 Jul 10, 2025
21b1a8d
fix enforce_eager
fhl2000 Jul 10, 2025
ec79af7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 10, 2025
210359a
small cleanup for pre-commit
fhl2000 Jul 10, 2025
11263e0
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 11, 2025
9a38a4e
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 12, 2025
699aff3
refactors
fhl2000 Jul 13, 2025
ef3d9d9
resolve yapf conflicts with isort
fhl2000 Jul 13, 2025
658565e
fixes
fhl2000 Jul 13, 2025
15e2b4a
fix global graph pool issue
fhl2000 Jul 13, 2025
4253dbf
fix refactors
fhl2000 Jul 13, 2025
2783e26
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 14, 2025
1b54962
more refactors
fhl2000 Jul 14, 2025
fb2a3c7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 17, 2025
d6269bd
refactors for and more
fhl2000 Jul 17, 2025
2e1304c
fix pre-commit
fhl2000 Jul 17, 2025
db22ca5
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 18, 2025
72d40e6
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 20, 2025
0c79e53
change cudagraph dispatching logics; runtime style->runtime mode
fhl2000 Jul 21, 2025
75db3a6
pass pre-commit
fhl2000 Jul 21, 2025
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)

#
@@ -684,7 +684,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)

#
16 changes: 12 additions & 4 deletions vllm/compilation/backends.py
@@ -563,10 +563,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

self._called = True

if not self.compilation_config.use_cudagraph or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()
@@ -585,6 +581,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
any(is_symbolic(d) for d in x.size())
]

if self.compilation_config.full_cuda_graph:
assert self.compilation_config.use_cudagraph, \
"full_cuda_graph mode requires use_cudagraph to be True"
fullgraph_wrapper = resolve_obj_by_qualname(
current_platform.get_fullgraph_wrapper_cls())
self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
self.graph_pool, self.sym_tensor_indices)

if not self.compilation_config.use_cudagraph or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

# compiler managed cudagraph input buffers
# we assume the first run with symbolic shapes
# has the maximum size among all the tensors
43 changes: 43 additions & 0 deletions vllm/compilation/base_piecewise_backend.py
@@ -70,3 +70,46 @@ def __call__(self, *args) -> Any:
or a replayed static graph.
"""
raise NotImplementedError


class AbstractFullgraphWrapper(Protocol):
"""
FullgraphWrapper interface that allows platforms to wrap the piecewise graph
to be viewed or captured as a full graph.
"""

def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, sym_shape_indices: list[int], **kwargs):
"""
Initializes the FullgraphWrapper class with compilation and
execution-related configurations.
Args:
graph (fx.GraphModule): The graph represented in fx.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
sym_shape_indices (list[int]):
Indices of symbolic shape.
Keyword Args:
kwargs: Additional keyword arguments reserved for future
extensions or custom platforms.
"""
raise NotImplementedError

def __call__(self, *args) -> Any:
"""
Executes the wrapped graph for given input args.
Args:
*args: Variable length input arguments to be passed into the
graph. The symbolic shape is expected to be in position
`sym_shape_indices[0]`.
Returns:
Any: Output of the executed wrapped graph.
"""
raise NotImplementedError
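
To make the protocol above concrete, here is a minimal sketch (hypothetical, not part of this PR) of a wrapper that satisfies it by simply running the wrapped graph eagerly; a real implementation, such as the FullCudagraphWrapper added below, would capture and replay a full CUDA graph instead.

```python
# Minimal sketch of a class satisfying AbstractFullgraphWrapper (illustrative
# only): it runs the piecewise graph eagerly instead of capturing a cudagraph.
from typing import Any

import torch.fx as fx

from vllm.config import VllmConfig


class EagerFullgraphWrapper:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, sym_shape_indices: list[int], **kwargs):
        self.graph = graph
        self.sym_shape_indices = sym_shape_indices

    def __call__(self, *args) -> Any:
        # The runtime batch size is args[self.sym_shape_indices[0]].shape[0];
        # an eager wrapper does not need it, but a capturing one does.
        return self.graph(*args)
```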
152 changes: 147 additions & 5 deletions vllm/compilation/cuda_piecewise_backend.py
@@ -37,6 +37,8 @@
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None

usage_type: Optional[str] = None


class CUDAPiecewiseBackend:

@@ -96,6 +98,7 @@
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
usage_type="piecewise(general)", # for logging only
)

def check_for_ending_compilation(self):
@@ -139,27 +142,32 @@
self.check_for_ending_compilation()

# Skip CUDA graphs if this entry doesn't use them OR
# if we're supposed to skip them globally
skip_cuda_graphs = get_forward_context().skip_cuda_graphs
if not entry.use_cudagraph or skip_cuda_graphs:
# if we're supposed to treat the piecewise graphs as a whole,
# which implies forward_context.skip_attention_cuda_graphs is False.
# In the latter case, we rely on a wrapper class to capture
# the full cudagraph outside the fx graph.
skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
if not entry.use_cudagraph or not skip_attention_cuda_graphs:
return entry.runnable(*args)

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
"Warming up %s/%s of %s usage for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
entry.usage_type,
runtime_shape)
return entry.runnable(*args)

if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
logger.debug("Capturing a cudagraph of %s usage for shape %s",
entry.usage_type,
runtime_shape)

input_addresses = [
@@ -216,3 +224,137 @@

entry.cudagraph.replay()
return entry.output


class FullCudagraphWrapper:
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, sym_shape_indices: list[int],
):
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.sym_shape_indices = sym_shape_indices

self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine

self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

self.first_run_finished = False


self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()

self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}


for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=False,
use_cudagraph=True,
usage_type="general",
)
if self.separate_attention_routine:
self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=False,
use_cudagraph=True,
usage_type="decode",
)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
return self.graph(*args)
list_args = list(args)
runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
forward_context = get_forward_context()

if forward_context.skip_attention_cuda_graphs:
# turn back to piecewise cudagraphs backend, which is responsible
# for capturing and running the piecewise cudagraphs.
return self.graph(*args)

# if not skip, the fx graph and its sub-graphs will only be supposed to
# eagerly run the compiled graphs, which should be cudagraph capturable
# as a whole.

concrete_size_entries = self.concrete_size_entries # default as general usage
if self.separate_attention_routine and forward_context.is_pure_decoding:
concrete_size_entries = self.concrete_size_entries_decode

if runtime_shape not in concrete_size_entries:
# we don't need to do anything for this shape.
return self.graph(*args)

entry = concrete_size_entries[runtime_shape]

if entry.runnable is None:
entry.runnable = self.graph

if not entry.use_cudagraph:
return entry.runnable(*args)

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
logger.debug(
"Warming up %s/%s of %s usage for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
entry.usage_type,
runtime_shape)
return entry.runnable(*args)


# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.

logger.debug("Capturing a cudagraph of %s usage for shape %s",
entry.usage_type,
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
# mind-exploding: carefully manage the reference and memory.

with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

compilation_counter.num_cudagraph_captured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)

entry.cudagraph.replay()
return entry.output
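
The runtime dispatch implemented by FullCudagraphWrapper.__call__ can be condensed as follows. This is an illustrative sketch only: warmup counting, graph capture, and the debug address checks are elided, and the helper name dispatch_full_cudagraph is invented for this note.

```python
# Condensed view of the dispatch in FullCudagraphWrapper.__call__ above
# (illustrative only; warmup, capture, and debug checks are omitted).
def dispatch_full_cudagraph(wrapper, forward_context, runtime_shape, args):
    if forward_context.skip_attention_cuda_graphs:
        # Piecewise path: CUDAPiecewiseBackend handles per-subgraph
        # cudagraphs around the attention ops.
        return wrapper.graph(*args)

    # Full-cudagraph path. With separate_attention_routine enabled, pure
    # decode batches use their own entry table, because backends like FA2
    # and FlashInfer take a different kernel branch for pure decode than
    # for mixed prefill-decode.
    entries = wrapper.concrete_size_entries
    if wrapper.separate_attention_routine and forward_context.is_pure_decoding:
        entries = wrapper.concrete_size_entries_decode

    entry = entries.get(runtime_shape)
    if entry is None:
        # Shape not in cudagraph_capture_sizes: run without a cudagraph.
        return wrapper.graph(*args)

    if entry.cudagraph is None:
        # The first few calls warm up and then capture a cudagraph for this
        # shape (elided here).
        return wrapper.graph(*args)

    entry.cudagraph.replay()
    return entry.output
```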
23 changes: 17 additions & 6 deletions vllm/config.py
@@ -3974,13 +3974,21 @@
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
separate_attention_routine: bool = False
[Review thread on the separate_attention_routine field]

Collaborator: I think this should be named better. Perhaps split_attn_cudagraph? I also don't understand why this has to be a flag and we can't just ask the attention backend what it wants.

Contributor (author): I think we must leave such a flag in the global config, which tells the compiler backend to do the right thing. Otherwise, how is the attention backend supposed to communicate its requirements to the compiler? At least for now, the force_separate_routine flag of an attention backend can enforce its preference during the initialize_attn_backend phase of the GPU model runner.

Contributor (author): > I think this should be named better. Perhaps split_attn_cudagraph?

I am not sure what name would be better, but I'm afraid split_attn_cudagraph is not a good one: it sounds like splitting the full graph into a piecewise graph with the attention ops as the splitting ops, which is what we already do.

Collaborator: Good call on the name. It also makes sense that we use this to communicate from the attention backend to the compiler. Let's make sure that happens inside set_splitting_ops_for_v1, or somewhere else in config initialization, if we can.

Collaborator: I think we should figure out a different name for this; the current name doesn't indicate any relation to cudagraphs.

Collaborator: I'm not as zoned into this PR as you folks are, but I have no clue what this flag is from the name.

Contributor (author): > I think we should figure out a different name for this; the current name doesn't indicate any relation to cudagraphs.

How about cudagraph_separate_routine? Cutting the "attention" out doesn't change its meaning, and while it is currently intended for distinct attention routines, in the future it may cover more than just attention ops.

Contributor (author): Just changed it to cudagraph_separate_routine. It should be better.

"""
Enable a distinct attention calls routine under an attention backend for full
cuda graph capturing. This is because some attention backends like FlashMLA,
FlashInfer, FA2, etc. implement different branches for mix prefill-decode and
pure decode cases. This flag enables us to potentially capture the cudagraph
separately for each branch.
"""

pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""

max_capture_size: int = field(default=None, init=False) # type: ignore

"""not configurable, computed after init"""
local_cache_dir: str = field(default=None, init=False) # type: ignore

"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
@@ -4172,13 +4180,16 @@

def set_splitting_ops_for_v1(self):
# NOTE: this function needs to be called
if self.splitting_ops and self.full_cuda_graph:
raise ValueError("full_cuda_graph cannot be used together with "
"splitting_ops, as Full CUDA graph will override "
f"the splitting_ops: {self.splitting_ops}")

# NOTE: When full_cuda_graph is True, instead of setting an empty
# list and capturing the full cudagraph inside the flattened fx graph,
# we keep the piecewise fx graph structure but capture the full
# cudagraph outside the fx graph. This reduces some cpu overhead when
# the runtime batch_size is not cudagraph captured. This is only
# supported for separate_attention_routine.
if self.separate_attention_routine:
assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
if not self.splitting_ops:
self.splitting_ops = [] if self.full_cuda_graph else [
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
@@ -4186,7 +4197,7 @@

@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:

"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
"""
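
As a usage sketch of the new flag (field names as they appear in this diff; how the config object reaches the engine is omitted here):

```python
# Usage sketch for the new flag (field names as in this diff; engine wiring
# is omitted). separate_attention_routine only makes sense together with
# full_cuda_graph, which the assert in set_splitting_ops_for_v1 enforces.
from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    full_cuda_graph=True,
    # Capture distinct full cudagraphs for pure-decode vs. mixed
    # prefill-decode batches (needed by FA2/FlashInfer-style backends).
    separate_attention_routine=True,
)
```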
12 changes: 9 additions & 3 deletions vllm/forward_context.py
@@ -94,7 +94,11 @@ class ForwardContext:
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: Optional[DPMetadata] = None
skip_cuda_graphs: bool = False
# Determines whether to run piecewise cudagraphs that skip the attention
# part, or a full cudagraph that includes attention. Defaults to True,
# i.e. use piecewise cudagraphs.
skip_attention_cuda_graphs: bool = True
is_pure_decoding: bool = False


_forward_context: Optional[ForwardContext] = None
@@ -115,7 +119,8 @@ def set_forward_context(
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False,
skip_attention_cuda_graphs: bool = True,
is_pure_decoding: bool = False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@@ -140,7 +145,8 @@
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
dp_metadata=dp_metadata,
skip_cuda_graphs=skip_cuda_graphs,
skip_attention_cuda_graphs=skip_attention_cuda_graphs,
is_pure_decoding=is_pure_decoding,
)

try:
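
A schematic of how a caller such as the GPU model runner (whose changes are not part of this file's diff) might populate the new fields; the surrounding variable names are illustrative assumptions:

```python
# Schematic caller of set_forward_context with the new fields (illustrative;
# the real gpu_model_runner wiring is not shown in this diff).
from vllm.forward_context import set_forward_context


def run_model(model, input_ids, attn_metadata, vllm_config, num_tokens: int,
              full_cudagraph_supported: bool, num_decode_tokens: int):
    # skip_attention_cuda_graphs=True keeps the default piecewise behaviour;
    # False hands the whole graph to the fullgraph wrapper.
    skip_attention_cuda_graphs = not full_cudagraph_supported
    is_pure_decoding = num_decode_tokens == num_tokens
    with set_forward_context(
            attn_metadata,
            vllm_config,
            num_tokens=num_tokens,
            skip_attention_cuda_graphs=skip_attention_cuda_graphs,
            is_pure_decoding=is_pure_decoding):
        return model(input_ids)
```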
4 changes: 4 additions & 0 deletions vllm/platforms/cuda.py
@@ -370,6 +370,10 @@ def use_custom_allreduce(cls) -> bool:
@classmethod
def get_piecewise_backend_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa

@classmethod
def get_fullgraph_wrapper_cls(cls) -> str:
return "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper" # noqa

@classmethod
def stateless_init_device_torch_dist_pg(
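
The platform hook returns a dotted path rather than a class so that the import stays lazy. vLLM resolves it with its resolve_obj_by_qualname helper; the importlib version below is an illustrative stand-in for the same mechanism:

```python
# Sketch of the string-to-class indirection used by the compiler backend
# (vLLM's own helper is resolve_obj_by_qualname; this is only illustrative).
import importlib


def resolve_by_qualname(qualname: str):
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


wrapper_cls = resolve_by_qualname(
    "vllm.compilation.cuda_piecewise_backend.FullCudagraphWrapper")
```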
8 changes: 8 additions & 0 deletions vllm/platforms/interface.py
@@ -531,6 +531,14 @@ def get_piecewise_backend_cls(cls) -> str:
Get piecewise backend class for piecewise graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa

@classmethod
def get_fullgraph_wrapper_cls(cls) -> str:
"""
Get fullgraph wrapper class for fullgraph static graph.
"""
return "vllm.compilation.base_piecewise_backend.AbstractFullgraphWrapper" # noqa


@classmethod
def stateless_init_device_torch_dist_pg(
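
A hypothetical out-of-tree platform (not part of this PR) could override the new hook the same way get_piecewise_backend_cls is overridden today; the plugin module path below is made up for illustration:

```python
# Hypothetical platform override of the new hook (illustrative only; the
# plugin module path is invented).
from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):

    @classmethod
    def get_fullgraph_wrapper_cls(cls) -> str:
        # Must point at a class implementing AbstractFullgraphWrapper.
        return "my_plugin.compilation.MyFullgraphWrapper"
```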