[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059
Changes from all commits
@@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Protocol

from vllm.config import CUDAGraphRuntimeStyle, VllmConfig


class AbstractStaticGraphWrapper(Protocol):
    """
    StaticGraphWrapper interface that allows platforms to wrap a callable
    to be captured as a static graph.
    """

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any,
                 **kwargs):
        """
        Initializes the StaticGraphWrapper class with graph capturing and
        execution-related configurations.

        Args:
            runnable (Callable): The callable to be wrapped and captured.
            vllm_config (VllmConfig): Global configuration for vLLM.
            runtime_style (CUDAGraphRuntimeStyle): The style of the static
                graph runtime. See CUDAGraphRuntimeStyle in vllm/config.py.
            graph_pool (Any): Graph memory pool handle, e.g.,
                `torch.cuda.graph_pool_handle()`.

        Keyword Args:
            kwargs: Additional keyword arguments for platform-specific
                configurations.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs) -> Any:
        """
        Executes the wrapped callable.

        If the current CUDAGraphRuntimeStyle in the ForwardContext
        matches the runtime style of this instance, it replays the CUDAGraph
        or captures it using the callable if it hasn't been captured yet.
        Otherwise, it calls the original callable directly.

        Args:
            *args: Variable length input arguments to be passed into the
                callable.
            **kwargs: Keyword arguments to be passed into the callable.

        Returns:
            Any: Output of the executed callable.
        """
        raise NotImplementedError
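Since this is a `Protocol`, platform wrapper classes only need to match the signatures above structurally; no inheritance is required. As a hedged illustration (not part of this PR's diff), a trivial class that satisfies the interface by never capturing and always running eagerly could look like the sketch below; the name `EagerGraphWrapper` is hypothetical.

```python
# Hypothetical sketch, not part of this PR: a minimal class that
# structurally satisfies AbstractStaticGraphWrapper by simply forwarding
# every call instead of capturing a static graph.
from typing import Any, Callable

from vllm.config import CUDAGraphRuntimeStyle, VllmConfig


class EagerGraphWrapper:
    """Matches the Protocol's signatures; useful as a mental model for
    what a platform without static-graph support could provide."""

    def __init__(self, runnable: Callable, vllm_config: VllmConfig,
                 runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any,
                 **kwargs):
        self.runnable = runnable

    def __call__(self, *args, **kwargs) -> Any:
        # No capture, no replay: always execute the wrapped callable.
        return self.runnable(*args, **kwargs)
```

The CUDA implementation of this interface, `CUDAGraphWrapper`, follows in the next file.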
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch

import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import CUDAGraphRuntimeStyle, VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors

logger = init_logger(__name__)


@dataclasses.dataclass
class CUDAGraphEntry:
    runtime_shape: int
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[list[int]] = None


@dataclasses.dataclass
class CUDAGraphOptions:
    debug_log_enable: bool = True
    gc_disable: bool = False
    weak_ref_output: bool = True
    usage_str: Optional[str] = None  # For debug logging only

class CUDAGraphWrapper:
    """
    This class simply wraps a runnable for cudagraph functionality,
    taking responsibility for capturing the cudagraph and running the
    replay.
    """

    def __init__(self,
                 runnable: Callable,
                 vllm_config: VllmConfig,
                 runtime_style: CUDAGraphRuntimeStyle,
                 graph_pool: Any = None,
                 cudagraph_options: Optional[CUDAGraphOptions] = None):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.graph_pool = graph_pool
        self.runtime_style = runtime_style
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_style is not NONE (no cudagraph); otherwise we
        # don't need to initialize a CUDAGraphWrapper.
        assert self.runtime_style != CUDAGraphRuntimeStyle.NONE
        if self.graph_pool is None:
            self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.cudagraph_options = cudagraph_options

        self.cudagraph_capture_sizes: set[int] = set(
            self.compilation_config.cudagraph_capture_sizes)
        # the entries for different shapes that we need to capture cudagraph
        self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {}

        for shape in self.cudagraph_capture_sizes:
            self.concrete_cudagraph_entries[shape] = CUDAGraphEntry(
                runtime_shape=shape)

    def __call__(self, *args, **kwargs):
        forward_context = get_forward_context()
        runtime_shape = forward_context.num_tokens
        cudagraph_runtime_style = forward_context.cudagraph_runtime_style

        if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or \
                runtime_shape is None:
            # This could mean the profile run, a warmup run, or running
            # without cudagraphs.
            return self.runnable(*args, **kwargs)
        if cudagraph_runtime_style != self.runtime_style:
[Inline review thread on the line above]
Reviewer: This seems like a strange case. Is it valid for us to end up here? Do you think it warrants a log?
fhl2000: Thanks for the reviews. Actually, the answer can be found in the earlier history of this PR. I initially designed this mechanism on purpose to avoid double capturing and to ensure extra safety, in case we have nested cudagraph callables in this PR (i.e., a full cudagraph wrapper wrapped around piecewise cudagraphs). And thanks @ProExpertProg for proposing the current implementation.
fhl2000: Yep, I'll add more comments later to make it clear.
Reviewer: Yeah, I think this is right: it basically only triggers capture/replay if the style matches. Style might not match if it's…
            # Only trigger capture/replay if the runtime style matches;
            # otherwise, fall back to the original runnable.
            # This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with
            # different runtime styles.
            return self.runnable(*args, **kwargs)

        if runtime_shape not in self.concrete_cudagraph_entries:
            # we don't need to do anything for this shape.
            return self.runnable(*args, **kwargs)

        entry = self.concrete_cudagraph_entries[runtime_shape]

        if entry.cudagraph is None:
            if self.cudagraph_options.debug_log_enable:
                # Since we capture cudagraphs for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug("Capturing a cudagraph of %s usage for shape %s",
                             self.cudagraph_options.usage_str,
                             entry.runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if self.cudagraph_options.gc_disable:
                    # during every model forward in piecewise cudagraph
                    # mode, we capture many pieces of cudagraphs
                    # (roughly one per layer). Running gc again and again
                    # across layers makes cudagraph capture very slow.
                    # Therefore, we only run gc for the first graph
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.cuda.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's cudagraph pool
                    output = self.runnable(*args, **kwargs)
                    if self.cudagraph_options.weak_ref_output:
                        # by converting it to a weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise cudagraph mode, because
                        # the output of the last graph will not be used by
                        # any other cuda graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs of "
                f"{self.cudagraph_options.usage_str} are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}")

        entry.cudagraph.replay()
        return entry.output
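To illustrate the nesting and dispatch behavior discussed in the review thread above, here is a hedged sketch: an outer full-graph wrapper around an inner piecewise-style wrapper, where each instance only captures or replays when `forward_context.cudagraph_runtime_style` equals its own `runtime_style`. The member names `FULL` and `PIECEWISE` are assumptions (this excerpt only shows `NONE`), and `piecewise_runnable` and `vllm_config` are placeholders.

```python
# Hypothetical sketch of nested wrappers, not part of this PR's diff.
# `piecewise_runnable` and `vllm_config` are placeholders; the FULL and
# PIECEWISE member names are assumed.
inner = CUDAGraphWrapper(piecewise_runnable, vllm_config,
                         runtime_style=CUDAGraphRuntimeStyle.PIECEWISE,
                         cudagraph_options=CUDAGraphOptions(
                             usage_str="piecewise"))
outer = CUDAGraphWrapper(lambda *a, **kw: inner(*a, **kw), vllm_config,
                         runtime_style=CUDAGraphRuntimeStyle.FULL,
                         cudagraph_options=CUDAGraphOptions(usage_str="full"))

# If the forward context says FULL, `outer` captures/replays while `inner`
# falls through to its runnable (style mismatch), so each graph is captured
# exactly once and there is no double capture. If it says PIECEWISE, `outer`
# falls through and `inner` does the capture/replay. If it says NONE (e.g.
# a profile or warmup run), both fall through and the model runs eagerly.
```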