Skip to content

[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 51 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
92b1733
FA2 and FlashInfer Full cuda graph support
fhl2000 Jun 25, 2025
58ce477
fix the arch support in CMakeLists.txt to include 8.9
fhl2000 Jun 25, 2025
c2c5fea
Refactors
fhl2000 Jun 25, 2025
1606880
refactors
fhl2000 Jun 25, 2025
806432a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 25, 2025
7c5df45
refactor
fhl2000 Jun 25, 2025
c7a9424
Add check for separate_attention_routine flag
fhl2000 Jun 25, 2025
e8b9296
fix typo error
fhl2000 Jun 26, 2025
94d0b79
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 27, 2025
a67c698
refactors and rearchitect cuda graph logic
fhl2000 Jun 28, 2025
da110af
Refactors
fhl2000 Jun 28, 2025
deaf0fe
Delect one commit
fhl2000 Jun 28, 2025
02ca154
Add support for force_no_split_graph
fhl2000 Jun 28, 2025
fa0d25c
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 1, 2025
5108bef
Huge refactors to separete cudagraph logic from vllm compilation
fhl2000 Jul 5, 2025
1c1873d
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 5, 2025
7d4667a
refactors
fhl2000 Jul 5, 2025
fedff47
fix errors
fhl2000 Jul 5, 2025
833ac56
fix small error by lazy import
fhl2000 Jul 5, 2025
d57257d
handle lint-and-deploy errors for cpu execution
fhl2000 Jul 5, 2025
8b7ea7a
remove redundents
fhl2000 Jul 5, 2025
328615d
Clear
fhl2000 Jul 6, 2025
debc682
Big refactors
fhl2000 Jul 9, 2025
cad6c39
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 9, 2025
dc455ee
cleanup
fhl2000 Jul 10, 2025
620a728
fix warmup
fhl2000 Jul 10, 2025
b1e6978
Commit suggestion: Update vllm/config.py
fhl2000 Jul 10, 2025
beee69a
commit suggestion2: Update vllm/config.py
fhl2000 Jul 10, 2025
21b1a8d
fix enforce_eager
fhl2000 Jul 10, 2025
ec79af7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 10, 2025
210359a
small cleanup for pre-commit
fhl2000 Jul 10, 2025
11263e0
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 11, 2025
9a38a4e
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 12, 2025
699aff3
refactors
fhl2000 Jul 13, 2025
ef3d9d9
resolve yapf conflicts with isort
fhl2000 Jul 13, 2025
658565e
fixes
fhl2000 Jul 13, 2025
15e2b4a
fix global graph pool issue
fhl2000 Jul 13, 2025
4253dbf
fix refactors
fhl2000 Jul 13, 2025
2783e26
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 14, 2025
1b54962
more refactors
fhl2000 Jul 14, 2025
fb2a3c7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 17, 2025
d6269bd
refactors for and more
fhl2000 Jul 17, 2025
2e1304c
fix pre-commit
fhl2000 Jul 17, 2025
db22ca5
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 18, 2025
72d40e6
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 20, 2025
0c79e53
change cudagraph dispatching logics; runtime style->runtime mode
fhl2000 Jul 21, 2025
75db3a6
pass pre-commit
fhl2000 Jul 21, 2025
0bca4c4
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 23, 2025
9d2f148
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 24, 2025
60bdc61
fix bug when cudagraph_separate_routine==False
fhl2000 Jul 24, 2025
9036bd2
recover FlashInfer from main branch
fhl2000 Jul 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
from vllm.utils import is_torch_equal_or_newer

from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor)
Expand Down Expand Up @@ -258,6 +258,14 @@ def split_graph(graph: fx.GraphModule,
# we share the global graph pool among all the backends
global_graph_pool = None


def get_global_graph_pool():
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
return global_graph_pool


compilation_start_time = 0.0


Expand Down Expand Up @@ -317,10 +325,9 @@ def call_module(self, target: torch.fx.node.Target,
graph_index=index,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)

piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
# Lazy import here to avoid circular import
from .piecewise_backend import PiecewiseBackend
self.module.__dict__[target] = PiecewiseBackend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend)
Expand Down Expand Up @@ -391,9 +398,7 @@ def __init__(
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag

global global_graph_pool
if global_graph_pool is None:
global_graph_pool = current_platform.graph_pool_handle()
global_graph_pool = get_global_graph_pool()

# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
Expand Down
72 changes: 0 additions & 72 deletions vllm/compilation/base_piecewise_backend.py

This file was deleted.

57 changes: 57 additions & 0 deletions vllm/compilation/base_static_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Protocol

from vllm.config import VllmConfig


class AbstractStaticGraphWrapper(Protocol):
"""
StaticGraphWrapper interface that allows platforms to wrap a callable
to be captured as a static graph.
"""

def __init__(self, runnable: Callable, vllm_config: VllmConfig,
graph_pool: Any, runtime_style: int, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.

Args:
runnable (Callable): The callable to be wrapped and captured.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
runtime_style (Any): The style of the static
graph runtime.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.
"""
raise NotImplementedError

def maybe_replace_runnable(self, shape: int, runnable: Any):
"""
Replaces the runnable with a new one for a specific compiled shape.
"""
raise NotImplementedError

def __call__(self, *args, **kwargs) -> Any:
"""
Executes the wrapped callable.

This may involve replaying a captured static graph if the conditions
are met, or running the original callable eagerly and potentially
capturing it.

Args:
*args: Variable length input arguments to be passed into the
callable.
**kwargs: Keyword arguments to be passed into the callable.

Returns:
Any: Output of the executed callable.
"""
raise NotImplementedError
185 changes: 185 additions & 0 deletions vllm/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch

import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import CUDAGraphRuntimeStyle, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors

logger = init_logger(__name__)


@dataclasses.dataclass
class CUDAGraphEntry:
runtime_shape: int
num_finished_warmup: int = 0
runnable: Callable = None # type: ignore
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None

# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None

usage_type: Optional[str] = None # For debug logging only


class CUDAGraphWrapper:
"""
This class simply wrap a runnable for cudagraph functionality,
taking responsibility of capturing cudagraph and running the replay.
"""

def __init__(self,
runnable: Any,
vllm_config: VllmConfig,
runtime_style: int,
graph_pool: Any = None,
cudagraph_specific_config: Optional[dict[str, Any]] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_style = runtime_style
self.compilation_config = vllm_config.compilation_config

self.first_run_finished = False
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

assert self.runtime_style >= CUDAGraphRuntimeStyle.PIECEWISE
if self.graph_pool is None:
# lazy import to avoid triggering some import issues.
from vllm.compilation.backends import get_global_graph_pool
self.graph_pool = get_global_graph_pool()

if cudagraph_specific_config is None:
cudagraph_specific_config = {}
self.debug_capturing = cudagraph_specific_config.get(
"debug_capturing", True)
self.gc_disable = cudagraph_specific_config.get("gc_disable", False)
self.weak_ref_output = cudagraph_specific_config.get(
"weak_ref_output", True)
usage_type = cudagraph_specific_config.get("usage_type")
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes)
# the entries for different shapes that we need to capture cudagraph
self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {}

for shape in self.cudagraph_capture_sizes:

self.concrete_cudagraph_entries[shape] = CUDAGraphEntry(
runtime_shape=shape,
runnable=self.runnable,
usage_type=usage_type, # for debug logging only
)

def maybe_replace_runnable(self, shape: int, runnable: Callable):
# this is a hack to replace a general shape runnable with a compiled
# runnable of a specific shape.
if shape not in self.concrete_cudagraph_entries:
return
entry = self.concrete_cudagraph_entries[shape]
assert entry.cudagraph is None, "Cudagraph is already captured"
entry.runnable = runnable

def __call__(self, *args, **kwargs):
forward_context = get_forward_context()
runtime_shape = forward_context.num_tokens
cudagraph_runtime_style = forward_context.cudagraph_runtime_style

if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\
runtime_shape is None:
# TODO: make sure here is on profile running or eager running
return self.runnable(*args, **kwargs)
if cudagraph_runtime_style != self.runtime_style:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a strange case. Is it valid for us to end up here? Do you think it warrants a log?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the reviews. Actually, the answer can be found in the above history of this PR.

I initially designed this mechanism on purpose to avoid double capturing and ensure extra safety, in case we may have nested cudagraph callable in this PR (i.e., a full cudagraph wrapper wrapped on piecewise cudagraphs). And thanks @ProExpertProg for the proposal of the current implementation.

Instead of multiple flags for CUDA Graphs on the forward context, add a single field of a new enum type CUDAGraphStyle with options NONE, FULL, and PIECEWISE. That enum value tells the CUDAGraphWrapper instances whether they should "turn on" or not (both running and capture). Each wrapper instance can be initialized with a value (assert not NONE) it then compares to the value in fwd_context), and turns on if equal.

Yep, I'll add more comments later to make it clear.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think this is right: it basically only triggers capture/replay if the style matches. Style might not match if it's NONE (during profile, warmup, or just turned off), or if it is handled by a different wrapper.

# CUDAGraph runtime style don't match the current
# configuration, so directly call runnable eagerly
# as it's always safe.
return self.runnable(*args, **kwargs)

if runtime_shape not in self.concrete_cudagraph_entries:
# we don't need to do anything for this shape.
return self.runnable(*args, **kwargs)

entry = self.concrete_cudagraph_entries[runtime_shape]

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.debug_capturing:
logger.debug(
"Warming up %s/%s of %s usage for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
entry.usage_type, entry.runtime_shape)
return entry.runnable(*args, **kwargs)

if self.debug_capturing:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every
# shape. We only log it in the debug mode.
logger.debug("Capturing a cudagraph of %s usage for shape %s",
entry.usage_type, entry.runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
if self.gc_disable:
# during every model forward for piecewise cudagraph
# mode, we will capture many pieces of cudagraphs
# (roughly one per layer). running gc again and again
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args, **kwargs)
if self.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last
# graph will not be used by any other cuda graph.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

compilation_counter.num_cudagraph_captured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during "
f"replay. Expected {entry.input_addresses}, got "
f"{new_input_addresses}")

entry.cudagraph.replay()
return entry.output
Loading