
[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059

Open

wants to merge 43 commits into base: main
Commits (43)
92b1733
FA2 and FlashInfer Full cuda graph support
fhl2000 Jun 25, 2025
58ce477
fix the arch support in CMakeLists.txt to include 8.9
fhl2000 Jun 25, 2025
c2c5fea
Refactors
fhl2000 Jun 25, 2025
1606880
refactors
fhl2000 Jun 25, 2025
806432a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 25, 2025
7c5df45
refactor
fhl2000 Jun 25, 2025
c7a9424
Add check for separate_attention_routine flag
fhl2000 Jun 25, 2025
e8b9296
fix typo error
fhl2000 Jun 26, 2025
94d0b79
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 27, 2025
a67c698
refactors and rearchitect cuda graph logic
fhl2000 Jun 28, 2025
da110af
Refactors
fhl2000 Jun 28, 2025
deaf0fe
Delect one commit
fhl2000 Jun 28, 2025
02ca154
Add support for force_no_split_graph
fhl2000 Jun 28, 2025
fa0d25c
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 1, 2025
5108bef
Huge refactors to separete cudagraph logic from vllm compilation
fhl2000 Jul 5, 2025
1c1873d
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 5, 2025
7d4667a
refactors
fhl2000 Jul 5, 2025
fedff47
fix errors
fhl2000 Jul 5, 2025
833ac56
fix small error by lazy import
fhl2000 Jul 5, 2025
d57257d
handle lint-and-deploy errors for cpu execution
fhl2000 Jul 5, 2025
8b7ea7a
remove redundents
fhl2000 Jul 5, 2025
328615d
Clear
fhl2000 Jul 6, 2025
debc682
Big refactors
fhl2000 Jul 9, 2025
cad6c39
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 9, 2025
dc455ee
cleanup
fhl2000 Jul 10, 2025
620a728
fix warmup
fhl2000 Jul 10, 2025
b1e6978
Commit suggestion: Update vllm/config.py
fhl2000 Jul 10, 2025
beee69a
commit suggestion2: Update vllm/config.py
fhl2000 Jul 10, 2025
21b1a8d
fix enforce_eager
fhl2000 Jul 10, 2025
ec79af7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 10, 2025
210359a
small cleanup for pre-commit
fhl2000 Jul 10, 2025
11263e0
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 11, 2025
9a38a4e
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 12, 2025
699aff3
refactors
fhl2000 Jul 13, 2025
ef3d9d9
resolve yapf conflicts with isort
fhl2000 Jul 13, 2025
658565e
fixes
fhl2000 Jul 13, 2025
15e2b4a
fix global graph pool issue
fhl2000 Jul 13, 2025
4253dbf
fix refactors
fhl2000 Jul 13, 2025
2783e26
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 14, 2025
1b54962
more refactors
fhl2000 Jul 14, 2025
fb2a3c7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 17, 2025
d6269bd
refactors for and more
fhl2000 Jul 17, 2025
2e1304c
fix pre-commit
fhl2000 Jul 17, 2025
44 changes: 32 additions & 12 deletions vllm/compilation/backends.py
@@ -15,7 +15,8 @@
from torch._dispatch.python import enable_python_dispatcher

import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import (CompilationConfig, CUDAGraphMode,
CUDAGraphRuntimeStyle, VllmConfig)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
@@ -277,9 +278,6 @@ def split_graph(graph: fx.GraphModule,
return split_gm, outputs


# we share the global graph pool among all the backends
global_graph_pool = None

compilation_start_time = 0.0


@@ -339,14 +337,38 @@ def call_module(self, target: torch.fx.node.Target,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend

piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)

if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
# Resolve the static graph wrapper class (e.g., the CUDAGraphWrapper
# class), which is platform-dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())

# Always assign the PIECEWISE runtime style to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime style, no matter
# whether it wraps a full or a piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_style=CUDAGraphRuntimeStyle.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
usage_str="piecewise"))
else:
self.module.__dict__[target] = piecewise_backend

compilation_counter.num_piecewise_capturable_graphs_seen += 1

return output
@@ -413,9 +435,7 @@ def __init__(
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag

global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
global_graph_pool = current_platform.get_global_graph_pool()

# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
@@ -585,7 +605,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

self._called = True

if not self.compilation_config.use_cudagraph or \
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
not self.compilation_config.cudagraph_copy_inputs:
return self.split_gm

72 changes: 0 additions & 72 deletions vllm/compilation/base_piecewise_backend.py

This file was deleted.

53 changes: 53 additions & 0 deletions vllm/compilation/base_static_graph.py
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Protocol

from vllm.config import CUDAGraphRuntimeStyle, VllmConfig


class AbstractStaticGraphWrapper(Protocol):
"""
StaticGraphWrapper interface that allows platforms to wrap a callable
to be captured as a static graph.
"""

def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any,
**kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
Args:
runnable (Callable): The callable to be wrapped and captured.
vllm_config (VllmConfig): Global configuration for vLLM.
runtime_style (CUDAGraphRuntimeStyle): The style of the static
graph runtime. See CUDAGraphRuntimeStyle in vllm/config.py.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
"""
raise NotImplementedError

def __call__(self, *args, **kwargs) -> Any:
"""
Executes the wrapped callable.
If the current CUDAGraphRuntimeStyle in the ForwardContext
matches the runtime style of this instance, it replays the CUDAGraph
or captures it using the callable if it hasn't been captured yet.
Otherwise, it calls the original callable directly.
Args:
*args: Variable length input arguments to be passed into the
callable.
**kwargs: Keyword arguments to be passed into the callable.
Returns:
Any: Output of the executed callable.
"""
raise NotImplementedError
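Since AbstractStaticGraphWrapper is a Protocol, a platform class conforms by structure rather than by inheritance. As a hypothetical illustration (not part of this PR; the class name is made up), a platform that cannot capture static graphs could still satisfy the interface by always running the callable eagerly:

```python
from typing import Any, Callable

from vllm.config import CUDAGraphRuntimeStyle, VllmConfig


class EagerStaticGraphWrapper:
    """Hypothetical wrapper that satisfies AbstractStaticGraphWrapper
    structurally but performs no capture; it always runs eagerly."""

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any,
                 **kwargs):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_style = runtime_style
        self.graph_pool = graph_pool

    def __call__(self, *args, **kwargs) -> Any:
        # No capture or replay: just execute the wrapped callable.
        return self.runnable(*args, **kwargs)
```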
169 changes: 169 additions & 0 deletions vllm/compilation/cuda_graph.py
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch

import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import CUDAGraphRuntimeStyle, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors

logger = init_logger(__name__)


@dataclasses.dataclass
class CUDAGraphEntry:
runtime_shape: int
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None

# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None


@dataclasses.dataclass
class CUDAGraphOptions:
debug_log_enable: bool = True
gc_disable: bool = False
weak_ref_output: bool = True
usage_str: Optional[str] = None # For debug logging only


class CUDAGraphWrapper:
"""
This class simply wraps a runnable for cudagraph functionality,
taking responsibility for capturing the cudagraph and running the replay.
"""

def __init__(self,
runnable: Callable,
vllm_config: VllmConfig,
runtime_style: CUDAGraphRuntimeStyle,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_style = runtime_style
self.compilation_config = vllm_config.compilation_config

self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

# assert runtime_style is not NONE (no cudagraph); otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_style != CUDAGraphRuntimeStyle.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()

if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()
self.cudagraph_options = cudagraph_options

self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes)
# the entries for different shapes that we need to capture cudagraph
self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {}

for shape in self.cudagraph_capture_sizes:
self.concrete_cudagraph_entries[shape] = CUDAGraphEntry(
runtime_shape=shape)

def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
runtime_shape = forward_context.num_tokens
cudagraph_runtime_style = forward_context.cudagraph_runtime_style

if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\
runtime_shape is None:
# This could mean the profile run, a warmup run, or running
# without cudagraphs.
return self.runnable(*args, **kwargs)
if cudagraph_runtime_style != self.runtime_style:
Contributor
This seems like a strange case. Is it valid for us to end up here? Do you think it warrants a log?

Author
Thanks for the reviews. Actually, the answer can be found in the history of this PR above.

I initially designed this mechanism to avoid double capturing and ensure extra safety, since we may have nested cudagraph callables in this PR (i.e., a full cudagraph wrapper wrapped around piecewise cudagraphs). Thanks @ProExpertProg for proposing the current implementation.

Instead of multiple flags for CUDA graphs on the forward context, add a single field of a new enum type CUDAGraphRuntimeStyle with options NONE, FULL, and PIECEWISE. That enum value tells the CUDAGraphWrapper instances whether they should "turn on" or not (for both capture and replay). Each wrapper instance is initialized with a value (asserted not NONE) that it then compares to the value in the forward context, and it turns on if they match.

Yep, I'll add more comments later to make this clear.

Collaborator
Yeah I think this is right: it basically only triggers capture/replay if the style matches. Style might not match if it's NONE (during profile, warmup, or just turned off), or if it is handled by a different wrapper.

# Only trigger capture/replay if the runtime style matches;
# otherwise, we fall back to the original runnable.
# This enables properly dispatching to the correct CUDAGraphWrapper
# when nesting multiple instances with different runtime styles.
return self.runnable(*args, **kwargs)

if runtime_shape not in self.concrete_cudagraph_entries:
# we don't need to do anything for this shape.
return self.runnable(*args, **kwargs)

entry = self.concrete_cudagraph_entries[runtime_shape]

if entry.cudagraph is None:
if self.cudagraph_options.debug_log_enable:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. E.g. we only log it for the first subgraph in
# piecewise mode.
logger.debug("Capturing a cudagraph of %s usage for shape %s",
self.cudagraph_options.usage_str,
entry.runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
if self.cudagraph_options.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

compilation_counter.num_cudagraph_captured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs of "
f"{self.cudagraph_options.usage_str} are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}")

entry.cudagraph.replay()
return entry.output
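To make the dispatch behavior discussed in the review thread concrete, here is a minimal, self-contained sketch (toy classes only, no vLLM or CUDA dependency; all names below are illustrative): each wrapper only "turns on" when the style set for the current forward pass matches its own, so a FULL wrapper around the whole model and PIECEWISE wrappers around submodules can be nested without double capture.

```python
import enum
from typing import Callable


class RuntimeStyle(enum.Enum):
    # Toy analogue of CUDAGraphRuntimeStyle (NONE / FULL / PIECEWISE).
    NONE = 0
    FULL = 1
    PIECEWISE = 2


# In vLLM the active style lives on the forward context; a module-level
# variable is enough for this sketch.
current_style = RuntimeStyle.NONE


class ToyGraphWrapper:
    def __init__(self, runnable: Callable, runtime_style: RuntimeStyle):
        assert runtime_style != RuntimeStyle.NONE
        self.runnable = runnable
        self.runtime_style = runtime_style
        self.captured = False

    def __call__(self, *args, **kwargs):
        if current_style != self.runtime_style:
            # Style mismatch (e.g. NONE during profiling/warmup, or the
            # call is handled by another wrapper): fall through.
            return self.runnable(*args, **kwargs)
        if not self.captured:
            print(f"capture by {self.runtime_style.name} wrapper")
            self.captured = True  # stand-in for recording a CUDA graph
        return self.runnable(*args, **kwargs)


def submodule(x):
    return x + 1


# PIECEWISE wrapper around a submodule, FULL wrapper around the whole model.
inner = ToyGraphWrapper(submodule, RuntimeStyle.PIECEWISE)
outer = ToyGraphWrapper(lambda x: inner(x) * 2, RuntimeStyle.FULL)

current_style = RuntimeStyle.PIECEWISE
outer(1)  # only the inner (PIECEWISE) wrapper "captures"
current_style = RuntimeStyle.FULL
outer(1)  # only the outer (FULL) wrapper "captures"
```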