diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 673fb586623..e48d38fe3dc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,7 +15,8 @@ from torch._dispatch.python import enable_python_dispatcher import vllm.envs as envs -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import (CompilationConfig, CUDAGraphMode, + CUDAGraphRuntimeStyle, VllmConfig) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname @@ -277,9 +278,6 @@ def split_graph(graph: fx.GraphModule, return split_gm, outputs -# we share the global graph pool among all the backends -global_graph_pool = None - compilation_start_time = 0.0 @@ -339,14 +337,38 @@ def call_module(self, target: torch.fx.node.Target, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None) + # Lazy import here to avoid circular import + from .cuda_graph import CUDAGraphOptions + from .cuda_piecewise_backend import PiecewiseBackend - piecewise_backend = resolve_obj_by_qualname( - current_platform.get_piecewise_backend_cls()) - self.module.__dict__[target] = piecewise_backend( - submod, self.vllm_config, self.graph_pool, index, + piecewise_backend = PiecewiseBackend( + submod, self.vllm_config, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_dynamic_shape, self.vllm_backend) + if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + # resolve the static graph wrapper class (e.g. CUDAGraphWrapper + # class) as platform dependent. + static_graph_wrapper_class = resolve_obj_by_qualname( + current_platform.get_static_graph_wrapper_cls()) + + # Always assign PIECEWISE runtime style to the + # CUDAGraphWrapper for piecewise_backend, to distinguish + # it from the FULL cudagraph runtime style, no matter it + # is wrapped on a full or piecewise fx graph. + self.module.__dict__[target] = static_graph_wrapper_class( + runnable=piecewise_backend, + vllm_config=self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.PIECEWISE, + graph_pool=self.graph_pool, + cudagraph_options=CUDAGraphOptions( + debug_log_enable=piecewise_backend.is_first_graph, + gc_disable=not piecewise_backend.is_first_graph, + weak_ref_output=piecewise_backend.is_last_graph, + usage_str="piecewise")) + else: + self.module.__dict__[target] = piecewise_backend + compilation_counter.num_piecewise_capturable_graphs_seen += 1 return output @@ -413,9 +435,7 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() + global_graph_pool = current_platform.get_global_graph_pool() # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. 
@@ -585,7 +605,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - if not self.compilation_config.use_cudagraph or \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \ not self.compilation_config.cudagraph_copy_inputs: return self.split_gm diff --git a/vllm/compilation/base_piecewise_backend.py b/vllm/compilation/base_piecewise_backend.py deleted file mode 100644 index 4d7aeeb4d03..00000000000 --- a/vllm/compilation/base_piecewise_backend.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Callable, Protocol - -import torch.fx as fx - -from vllm.compilation.backends import VllmBackend -from vllm.config import VllmConfig - - -class AbstractPiecewiseBackend(Protocol): - """ - PiecewiseBackend interface that allows platforms to extend - piecewise static graph. - """ - - def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], - compiled_graph_for_general_shape: Callable, - vllm_backend: VllmBackend, **kwargs): - """ - Initializes the PiecewiseBackend class with compilation and - execution-related configurations. - - This class handles piecewise compilation, graph capturing, - and dispatching for specific input shapes. - - Args: - graph (fx.GraphModule): The graph represented in fx. - vllm_config (VllmConfig): Global configuration for vLLM. - graph_pool (Any): - Graph memory pool handle, e.g., - `torch.cuda.graph_pool_handle()`. - piecewise_compile_index (int): - Index of the current piecewise subgraph. - total_piecewise_compiles (int): - Total number of piecewise-compiled graphs. - sym_shape_indices (list[int]): - Indices of symbolic shape. - compiled_graph_for_general_shape (Callable): - Callable that executes the graph compiled for general shapes. - vllm_backend (VllmBackend): - Backend compiler that manages compilation and graph runtime - for vLLM. - - Keyword Args: - kwargs: Additional keyword arguments reserved for future - extensions or custom platforms. - """ - raise NotImplementedError - - def __call__(self, *args) -> Any: - """Executes the compiled graph for given input args. - - If this is the first invocation, executes the general compiled graph - and initiates the compilation process tracking. For subsequent calls, - dynamically dispatches execution to either a compiled graph or a static - graph based on the input shape. - - Args: - *args: Variable length input arguments to be passed into the - graph. The symbolic shape is expected to be in position - `sym_shape_indices[0]`. - - Returns: - Any: Output of the executed graph. This can be from the general - compiled graph, a specialized compiled version for the given shape, - or a replayed static graph. - """ - raise NotImplementedError diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py new file mode 100644 index 00000000000..ae98603fc6c --- /dev/null +++ b/vllm/compilation/base_static_graph.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any, Callable, Protocol + +from vllm.config import CUDAGraphRuntimeStyle, VllmConfig + + +class AbstractStaticGraphWrapper(Protocol): + """ + StaticGraphWrapper interface that allows platforms to wrap a callable + to be captured as a static graph. 
+ """ + + def __init__(self, runnable: Callable, vllm_config: VllmConfig, + runtime_style: CUDAGraphRuntimeStyle, graph_pool: Any, + **kwargs): + """ + Initializes the StaticGraphWrapper class with graph capturing and + execution-related configurations. + + Args: + runnable (Callable): The callable to be wrapped and captured. + vllm_config (VllmConfig): Global configuration for vLLM. + runtime_style (CUDAGraphRuntimeStyle): The style of the static + graph runtime. See CUDAGraphRuntimeStyle in vllm/config.py. + graph_pool (Any): + Graph memory pool handle, e.g., + `torch.cuda.graph_pool_handle()`. + Keyword Args: + kwargs: Additional keyword arguments for platform-specific + configurations. + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> Any: + """ + Executes the wrapped callable. + + If the current CUDAGraphRuntimeStyle in the ForwardContext + matches the runtime style of this instance, it replays the CUDAGraph + or captures it using the callable if it hasn't been captured yet. + Otherwise, it calls the original callable directly. + + Args: + *args: Variable length input arguments to be passed into the + callable. + **kwargs: Keyword arguments to be passed into the callable. + + Returns: + Any: Output of the executed callable. + """ + raise NotImplementedError diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py new file mode 100644 index 00000000000..313eebeb24a --- /dev/null +++ b/vllm/compilation/cuda_graph.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +from contextlib import ExitStack +from typing import Any, Callable, Optional +from unittest.mock import patch + +import torch + +import vllm.envs as envs +from vllm.compilation.counter import compilation_counter +from vllm.config import CUDAGraphRuntimeStyle, VllmConfig +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import weak_ref_tensors + +logger = init_logger(__name__) + + +@dataclasses.dataclass +class CUDAGraphEntry: + runtime_shape: int + cudagraph: Optional[torch.cuda.CUDAGraph] = None + output: Optional[Any] = None + + # for cudagraph debugging, track the input addresses + # during capture, and check if they are the same during replay + input_addresses: Optional[list[int]] = None + + +@dataclasses.dataclass +class CUDAGraphOptions: + debug_log_enable: bool = True + gc_disable: bool = False + weak_ref_output: bool = True + usage_str: Optional[str] = None # For debug logging only + + +class CUDAGraphWrapper: + """ + This class simply wrap a runnable for cudagraph functionality, + taking responsibility of capturing cudagraph and running the replay. + """ + + def __init__(self, + runnable: Callable, + vllm_config: VllmConfig, + runtime_style: CUDAGraphRuntimeStyle, + graph_pool: Any = None, + cudagraph_options: Optional[CUDAGraphOptions] = None): + self.runnable = runnable + self.vllm_config = vllm_config + self.graph_pool = graph_pool + self.runtime_style = runtime_style + self.compilation_config = vllm_config.compilation_config + + self.first_run_finished = False + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # assert runtime_style is not NONE(no cudagraph), otherwise, we don't + # need to initialize a CUDAGraphWrapper. 
+ assert self.runtime_style != CUDAGraphRuntimeStyle.NONE + if self.graph_pool is None: + self.graph_pool = current_platform.get_global_graph_pool() + + if cudagraph_options is None: + cudagraph_options = CUDAGraphOptions() + self.cudagraph_options = cudagraph_options + + self.cudagraph_capture_sizes: set[int] = set( + self.compilation_config.cudagraph_capture_sizes) + # the entries for different shapes that we need to capture cudagraph + self.concrete_cudagraph_entries: dict[int, CUDAGraphEntry] = {} + + for shape in self.cudagraph_capture_sizes: + self.concrete_cudagraph_entries[shape] = CUDAGraphEntry( + runtime_shape=shape) + + def __call__(self, *args, **kwargs): + forward_context = get_forward_context() + runtime_shape = forward_context.num_tokens + cudagraph_runtime_style = forward_context.cudagraph_runtime_style + + if cudagraph_runtime_style == CUDAGraphRuntimeStyle.NONE or\ + runtime_shape is None: + # This could mean the profile run, a warmup run, or running + # without cudagraphs. + return self.runnable(*args, **kwargs) + if cudagraph_runtime_style != self.runtime_style: + # Only trigger capture/replay if the runtime style matches; + # otherwise, we fall back to the original runnable. + # This enables properly dispatching to the correct CUDAGraphWrapper + # when nesting multiple instances with different runtime styles. + return self.runnable(*args, **kwargs) + + if runtime_shape not in self.concrete_cudagraph_entries: + # we don't need to do anything for this shape. + return self.runnable(*args, **kwargs) + + entry = self.concrete_cudagraph_entries[runtime_shape] + + if entry.cudagraph is None: + if self.cudagraph_options.debug_log_enable: + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every + # shape. E.g. we only log it for the first subgraph in + # piecewise mode. + logger.debug("Capturing a cudagraph of %s usage for shape %s", + self.cudagraph_options.usage_str, + entry.runtime_shape) + + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + cudagraph = torch.cuda.CUDAGraph() + + with ExitStack() as stack: + if self.cudagraph_options.gc_disable: + # during every model forward for piecewise cudagraph + # mode, we will capture many pieces of cudagraphs + # (roughly one per layer). running gc again and again + # across layers will make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = self.runnable(*args, **kwargs) + if self.cudagraph_options.weak_ref_output: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph in piecewise cudagraph mode, because + # the output of the last graph will not be used by + # any other cuda graph.
+ output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = cudagraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.is_debugging_mode: + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + f"Input addresses for cudagraphs of " + f"{self.cudagraph_options.usage_str} are different " + f"during replay. Expected {entry.input_addresses}, " + f"got {new_input_addresses}") + + entry.cudagraph.replay() + return entry.output diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/cuda_piecewise_backend.py index 8c49ea6cc10..aad1293e317 100644 --- a/vllm/compilation/cuda_piecewise_backend.py +++ b/vllm/compilation/cuda_piecewise_backend.py @@ -2,21 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from contextlib import ExitStack -from typing import Any, Callable, Optional -from unittest.mock import patch +from typing import Any, Callable -import torch import torch.fx as fx import vllm.envs as envs from vllm.compilation.backends import VllmBackend -from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors logger = init_logger(__name__) @@ -24,44 +19,28 @@ @dataclasses.dataclass class ConcreteSizeEntry: runtime_shape: int - need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in cudagraph_capture_sizes - compiled: bool = False runnable: Callable = None # type: ignore - num_finished_warmup: int = 0 - cudagraph: Optional[torch.cuda.CUDAGraph] = None - output: Optional[Any] = None - - # for cudagraph debugging, track the input addresses - # during capture, and check if they are the same during replay - input_addresses: Optional[list[int]] = None -class CUDAPiecewiseBackend: +class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, - graph_pool: Any, piecewise_compile_index: int, - total_piecewise_compiles: int, sym_shape_indices: list[int], + piecewise_compile_index: int, total_piecewise_compiles: int, + sym_shape_indices: list[int], compiled_graph_for_general_shape: Callable, vllm_backend: VllmBackend): """ The backend for piecewise compilation. - It mainly handles the compilation and cudagraph capturing. + It mainly handles the compilation. We will compile `self.graph` once for the general shape, and then compile for different shapes specified in `compilation_config.compile_sizes`. - - Independently, we will capture cudagraph for different shapes. - - If a shape needs both compilation and cudagraph, we will - compile it first, and then capture cudagraph. 
""" self.graph = graph self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config - self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles self.vllm_backend = vllm_backend @@ -70,11 +49,10 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_last_graph = ( piecewise_compile_index == total_piecewise_compiles - 1) + self.is_full_graph = total_piecewise_compiles == 1 + self.compile_sizes: set[int] = set( self.compilation_config.compile_sizes) - self.cudagraph_capture_sizes: set[int] = set( - self.compilation_config.cudagraph_capture_sizes - ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -84,18 +62,18 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" - # the entries for different shapes that we need to either - # compile or capture cudagraph + # the entries for different shapes that we need to compile self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {} # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + + # We only keep compilation management inside this class directly. + for shape in self.compile_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, - need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.cudagraph_capture_sizes, + runnable=self.compiled_graph_for_general_shape, ) def check_for_ending_compilation(self): @@ -112,16 +90,16 @@ def __call__(self, *args) -> Any: return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] + if self.is_debugging_mode: + assert runtime_shape == get_forward_context().num_tokens + if runtime_shape not in self.concrete_size_entries: # we don't need to do anything for this shape return self.compiled_graph_for_general_shape(*args) entry = self.concrete_size_entries[runtime_shape] - if entry.runnable is None: - entry.runnable = self.compiled_graph_for_general_shape - - if entry.need_to_compile and not entry.compiled: + if not entry.compiled: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments @@ -138,81 +116,4 @@ def __call__(self, *args) -> Any: if self.is_last_graph and not self.to_be_compiled_sizes: self.check_for_ending_compilation() - # Skip CUDA graphs if this entry doesn't use them OR - # if we're supposed to skip them globally - skip_cuda_graphs = get_forward_context().skip_cuda_graphs - if not entry.use_cudagraph or skip_cuda_graphs: - return entry.runnable(*args) - - if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa - entry.num_finished_warmup += 1 - if self.is_first_graph: - logger.debug( - "Warming up %s/%s for shape %s", - entry.num_finished_warmup, - self.compilation_config.cudagraph_num_of_warmups, - runtime_shape) - return entry.runnable(*args) - - if self.is_first_graph: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every shape. - # We only log it in the debug mode. 
- logger.debug("Capturing a cudagraph for shape %s", - runtime_shape) - - input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() - - with ExitStack() as stack: - if not self.is_first_graph: - # during every model forward, we will capture - # many pieces of cudagraphs (roughly one per layer). - # running gc again and again across layers will - # make the cudagraph capture very slow. - # therefore, we only run gc for the first graph, - # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context( - patch("torch.cuda.empty_cache", lambda: None)) - - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) - - # here we always use weak ref for the output - # to save memory - entry.output = weak_ref_tensors(output) - entry.cudagraph = cudagraph - - compilation_counter.num_cudagraph_captured += 1 - - # important: we need to return the output, rather than - # the weak ref of the output, so that pytorch can correctly - # manage the memory during cuda graph capture - return output - - if self.is_debugging_mode: - # check if the input addresses are the same - new_input_addresses = [ - x.data_ptr() for x in args if isinstance(x, torch.Tensor) - ] - assert new_input_addresses == entry.input_addresses, ( - "Input addresses for cudagraphs are different during replay." - f" Expected {entry.input_addresses}, got {new_input_addresses}" - ) - - entry.cudagraph.replay() - return entry.output + return entry.runnable(*args) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 8d5df1061ed..96d4eae2ee9 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -11,7 +11,8 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.config import (CompilationLevel, CUDAGraphMode, + get_current_vllm_config) from vllm.logger import init_logger logger = init_logger(__name__) @@ -115,8 +116,8 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): except Exception: pass - if self.vllm_config.compilation_config.use_cudagraph and \ - "update" in new_code.co_names: + if self.vllm_config.compilation_config.cudagraph_mode != \ + CUDAGraphMode.NONE and "update" in new_code.co_names: import depyf src = depyf.decompile(new_code) msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa diff --git a/vllm/config.py b/vllm/config.py index 22f74017136..006b6296acb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3960,6 +3960,21 @@ class CompilationLevel: PIECEWISE = 3 +class CUDAGraphMode(enum.Enum): + # constants for the config of the cudagraph mode. 
+ NONE = 0 + PIECEWISE = 1 + FULL = 2 + + +class CUDAGraphRuntimeStyle(enum.Enum): + # constants for concrete cudagraph runtime style, used for + # runtime dispatching. + NONE = 0 + PIECEWISE = 1 + FULL = 2 + + @config @dataclass class PassConfig: @@ -4020,10 +4035,13 @@ class CompilationConfig: - [`splitting_ops`][vllm.config.CompilationConfig.splitting_ops] - CudaGraph capture: - [`use_cudagraph`][vllm.config.CompilationConfig.use_cudagraph] + - [`cudagraph_mode`][vllm.config.CompilationConfig.cudagraph_mode] - [`cudagraph_capture_sizes`] [vllm.config.CompilationConfig.cudagraph_capture_sizes] - [`cudagraph_num_of_warmups`] [vllm.config.CompilationConfig.cudagraph_num_of_warmups] + - [`cudagraph_separate_routine`] + [vllm.config.CompilationConfig.cudagraph_separate_routine] - [`cudagraph_copy_inputs`] [vllm.config.CompilationConfig.cudagraph_copy_inputs] - [`full_cuda_graph`][vllm.config.CompilationConfig.full_cuda_graph] @@ -4046,7 +4064,7 @@ class CompilationConfig: certain small batchsizes, where inductor is good at optimizing. """ # Top-level Compilation control - level: int = 0 + level: Optional[int] = None """The level of compilation: - 0: no compilation. @@ -4084,9 +4102,11 @@ class CompilationConfig: By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor: level>=PIECEWISE and use_inductor=True. Inductor generates (fused) Triton kernels for disabled custom ops.""" - splitting_ops: list[str] = field(default_factory=list) + splitting_ops: Optional[list[str]] = None """A list of ops to split the full graph into subgraphs, used in piecewise compilation.""" + is_attention_splitting: bool = False + """A flag to indicate if the splitting_ops contains all attention ops.""" # Inductor capture use_inductor: bool = True @@ -4114,7 +4134,34 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - use_cudagraph: bool = field(default_factory=lambda: envs.VLLM_USE_V1) + cudagraph_mode: CUDAGraphMode = field( + default_factory=lambda: CUDAGraphMode.PIECEWISE + if envs.VLLM_USE_V1 else CUDAGraphMode.NONE) + """ + The mode of the cudagraph. + - NONE, no cudagraph capture. + - PIECEWISE. (v1 default) + - FULL. + For cudagraph_mode != CUDAGraphMode.NONE, it requires that all input + buffers have fixed addresses and all splitting ops write their outputs + to input buffers. + + PIECEWISE mode builds piecewise cudagraphs only, keeping the cudagraph + incompatible ops (i.e. some attention ops) outside the cudagraph + for general flexibility. + + FULL mode instead tries to capture a full cudagraph for fully compatible + routines (i.e. most attention backends support pure decode batches), + and may fall back to piecewise cudagraphs for partially incompatible + routines. This may provide performance benefits for smaller models. + + Currently, the cudagraph mode is only used for the v1 engine. + Note that the cudagraph logic is generally orthogonal to the + compilation logic. While piecewise cudagraphs require piecewise + compilation (level=PIECEWISE and non-empty splitting_ops), full + cudagraphs are supported with and without compilation. + """ + use_cudagraph: Optional[bool] = None """Whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. - True: cudagraph inside compilation is used. It requires @@ -4124,8 +4171,9 @@ class CompilationConfig: CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future.""" + Warning: This flag is deprecated and will be removed in future releases. + Please use cudagraph_mode instead. + """ cudagraph_num_of_warmups: int = 0 """Number of warmup runs for cudagraph. It means the first several runs will be treated as warmup runs. @@ -4140,13 +4188,25 @@ class CompilationConfig: cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False.""" - full_cuda_graph: bool = False + internally managed buffer. Default is False. + Note that this flag is only effective when cudagraph_mode is PIECEWISE. + """ + full_cuda_graph: Optional[bool] = None """whether to use a full cuda graph for the entire forward pass rather than splitting certain operations such as attention into subgraphs. Thus this flag cannot be used together with splitting_ops. This may provide - performance benefits for smaller models.""" - + performance benefits for smaller models. + Warning: This flag is deprecated and will be removed in future releases. + Please use cudagraph_mode instead. + """ + cudagraph_separate_routine: bool = False + """ + Enable distinct attention routines for mixed and pure-decode batches during + full cuda graph capturing. This is because some attention backends like + FlashMLA, FlashInfer, FA2, etc. implement different branches for mixed + prefill-decode and pure decode cases. This flag enables capturing separate + cudagraphs for each branch. + """ pass_config: PassConfig = field(default_factory=PassConfig) """Custom inductor passes, see PassConfig for more details""" @@ -4233,6 +4293,16 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """ return TypeAdapter(CompilationConfig).validate_json(cli_value) + @field_validator("cudagraph_mode", mode="before") + @classmethod + def validate_cudagraph_mode_before(cls, value: Any) -> Any: + """ + enable parse the `cudagraph_mode` enum type from string + """ + if isinstance(value, str): + return CUDAGraphMode[value.upper()] + return value + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") @@ -4344,16 +4414,29 @@ def init_with_cudagraph_sizes(self, def set_splitting_ops_for_v1(self): # NOTE: this function needs to be called - if self.splitting_ops and self.full_cuda_graph: - raise ValueError("full_cuda_graph cannot be used together with " - "splitting_ops, as Full CUDA graph will override " - f"the splitting_ops: {self.splitting_ops}") - - if not self.splitting_ops: - self.splitting_ops = [] if self.full_cuda_graph else [ + if self.cudagraph_separate_routine: + assert self.cudagraph_mode == CUDAGraphMode.FULL, ( + "cudagraph_separate_routine requires " + "cudagraph_mode be CUDAGraphMode.FULL") + + if self.splitting_ops is None: + # NOTE: When using full cudagraph, instead of setting an empty + # list and capture the full cudagraph inside the flattened fx + # graph, we keep the piecewise fx graph structure but capture the + # full cudagraph outside the fx graph. This reduces some cpu + # overhead when the runtime batch_size is not cudagraph captured. + # see https://github.com/vllm-project/vllm/pull/20059 for details. 
+ self.splitting_ops = [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] + elif len(self.splitting_ops) == 0: + assert self.cudagraph_mode != CUDAGraphMode.PIECEWISE, ( + "Cannot use piecewise CUDAGraph without splitting ops.") + self.splitting_ops = [] + self.is_attention_splitting = all( + op in self.splitting_ops for op in + ["vllm.unified_attention", "vllm.unified_attention_with_output"]) @config @@ -4622,12 +4705,50 @@ def __post_init__(self): self.compilation_config.custom_ops.append("+rms_norm") if envs.VLLM_USE_V1 and self.model_config is not None and \ not self.model_config.enforce_eager: - # By default, V1 uses piecewise CUDA graphs. If full_cuda_graph - # is set to True, full CUDA graphs will be used. + # By default, V1 uses piecewise CUDA graphs. If cudagraph_mode + # is set to `FULL`, full CUDA graphs will be used. self.compilation_config.cudagraph_num_of_warmups = 1 - self.compilation_config.level = CompilationLevel.PIECEWISE + if self.compilation_config.level is None: + self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() + if self.compilation_config.use_cudagraph is not None: + logger.warning( + "`use_cudagraph` is deprecated and will be removed in the " + "future release. Switch to use `cudagraph_mode` instead.") + if self.compilation_config.use_cudagraph: + if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: + self.compilation_config.cudagraph_mode =\ + CUDAGraphMode.PIECEWISE + # otherwise, keep the cudagraph_mode as is + else: + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + + if self.compilation_config.full_cuda_graph is not None: + logger.warning( + "`full_cuda_graph` is deprecated and will be removed in the " + "future release. Switch to use `cudagraph_mode` instead.") + if self.compilation_config.use_cudagraph is not None and \ + self.compilation_config.full_cuda_graph: + assert self.compilation_config.use_cudagraph, ( + "`use_cudagraph` must be True when `full_cuda_graph` " + "is True.") + self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL + self.compilation_config.use_cudagraph = True + elif self.compilation_config.full_cuda_graph: + self.compilation_config.cudagraph_mode = CUDAGraphMode.FULL + self.compilation_config.use_cudagraph = True + # other cases, keep the cudagraph_mode as is + + # For V0 or other cases, default to level 0 with no compilation + if self.compilation_config.level is None: + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + # disable cudagraph if enforce eager execution + if self.model_config is not None and self.model_config.enforce_eager: + logger.info("Cudagraph is disabled under eager mode.") + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE + self._set_cudagraph_sizes() if self.cache_config.cpu_offload_gb > 0 and \ @@ -4646,10 +4767,11 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.compilation_config.full_cuda_graph and \ + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL and \ not self.model_config.disable_cascade_attn: - logger.info("full_cuda_graph is not supported with " - "cascade attention. Disabling cascade attention.") + logger.info("CUDAGraphMode.FULL is not supported with " + "cascade attention currently. 
Disabling cascade " + "attention.") self.model_config.disable_cascade_attn = True disable_chunked_prefill_reasons: list[str] = [] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feea..0de5f1fc93a 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -11,7 +11,7 @@ import torch.distributed as dist import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig +from vllm.config import CUDAGraphRuntimeStyle, ParallelConfig, VllmConfig from vllm.logger import init_logger if TYPE_CHECKING: @@ -92,9 +92,12 @@ class ForwardContext: attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass + num_tokens: Optional[int] = None # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None - skip_cuda_graphs: bool = False + # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. + # by default NONE, no cudagraph is used. + cudagraph_runtime_style: CUDAGraphRuntimeStyle = CUDAGraphRuntimeStyle.NONE _forward_context: Optional[ForwardContext] = None @@ -109,14 +112,13 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - skip_cuda_graphs: bool = False, -): +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( + CUDAGraphRuntimeStyle.NONE)): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -137,10 +139,11 @@ def set_forward_context( _forward_context = ForwardContext( no_compile_layers=vllm_config.compilation_config. static_forward_context, + num_tokens=num_tokens, virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, - skip_cuda_graphs=skip_cuda_graphs, + cudagraph_runtime_style=cudagraph_runtime_style, ) try: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 03f0c15270b..23ac5f73f48 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -173,19 +173,23 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: logger.info("Forcing kv cache block size to 128 for " "CUTLASS_MLA_VLLM_V1 backend.") + # lazy import to avoid circular import + from vllm.config import CUDAGraphMode + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and compilation_config.use_cudagraph): + and compilation_config.cudagraph_mode != CUDAGraphMode.NONE): logger.info( "Data Parallel: Forcing enforce eager to be True since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - compilation_config.use_cudagraph = False + compilation_config.cudagraph_mode = CUDAGraphMode.NONE if model_config is not None: model_config.enforce_eager = True + # TODO (varun): Turning this ON gives incorrect results for the # Deepseek-V2-lite model.
vllm_config.compilation_config.use_inductor = False @@ -426,8 +430,8 @@ def use_custom_allreduce(cls) -> bool: return True @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ae675bcc8d2..881f069f923 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import sys from datetime import timedelta from platform import uname -from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import numpy as np import torch @@ -137,6 +137,8 @@ class Platform: additional_env_vars: list[str] = [] + _global_graph_pool: Optional[Any] = None + @property def supported_dtypes(self) -> list[torch.dtype]: """Returns the supported dtypes for the current platform.""" @@ -521,6 +523,15 @@ def __getattr__(self, key: str): " attribute.", self.device_type, key) return None + def get_global_graph_pool(self) -> Any: + """ + Return the global graph pool for this platform. + """ + cls = self.__class__ + if cls._global_graph_pool is None: + cls._global_graph_pool = self.graph_pool_handle() + return cls._global_graph_pool + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: """ @@ -529,11 +540,11 @@ def get_cu_count(cls, device_id: int = 0) -> int: raise NotImplementedError @classmethod - def get_piecewise_backend_cls(cls) -> str: + def get_static_graph_wrapper_cls(cls) -> str: """ - Get piecewise backend class for piecewise graph. + Get static graph wrapper class for static graph.
""" - return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa + return "vllm.compilation.base_static_graph.AbstractStaticGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 04637f5c7aa..044b9f0a3c7 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -423,8 +423,8 @@ def is_navi(cls) -> bool: return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName @classmethod - def get_piecewise_backend_cls(cls) -> str: - return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa + def get_static_graph_wrapper_cls(cls) -> str: + return "vllm.compilation.cuda_graph.CUDAGraphWrapper" @classmethod def stateless_init_device_torch_dist_pg( diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 4224d807c2b..e56a0d93fea 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -22,12 +22,12 @@ get_scheduler_metadata, reshape_and_cache_flash) -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - make_local_attention_virtual_batches) + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + get_kv_cache_layout, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -156,7 +156,12 @@ def _get_sliding_window_configs( class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 + # FA2 launches separte routines for prefill-decode and pure decode batches, + # while FA3 launches a unified varlen fwd kernel for both prefill-decode + # and pure decode batches. + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.ALWAYS_SEPARATE if get_flash_attn_version() == 2 \ + else AttentionCGSupport.ALWAYS_UNIFIED def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): @@ -176,17 +181,13 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = self.compilation_config.full_cuda_graph - if self.use_full_cuda_graph: - if not self.aot_schedule: - raise ValueError( - "AoT scheduling is required for full cuda graph.") - capture_sizes = self.compilation_config.cudagraph_capture_sizes - if not capture_sizes: - raise ValueError( - "cudagraph_capture_sizes should not be None when " - "full_cuda_graph is True.") - self.max_cudagraph_size = max(capture_sizes) + + self.use_full_cuda_graph = ( + self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL) + + if self.use_full_cuda_graph and self.aot_schedule: + self.max_cudagraph_size = self.compilation_config.max_capture_size + if self.max_cudagraph_size > 992: # This condition derives from FA3's internal heuristic. # TODO(woosuk): Support larger cudagraph sizes. 
@@ -336,9 +337,9 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, seqlens=seq_lens, max_seq_len=max_seq_len, causal=True) - - if self.use_full_cuda_graph: - assert scheduler_metadata is not None + # For FA3 + full cudagraph + max_num_splits = 0 + if self.use_full_cuda_graph and scheduler_metadata is not None: n = scheduler_metadata.shape[0] self.scheduler_metadata[:n] = scheduler_metadata # NOTE(woosuk): We should zero out the rest of the scheduler @@ -348,14 +349,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, self.scheduler_metadata[n:] = 0 scheduler_metadata = self.scheduler_metadata[:n] - max_num_splits = 0 - if (self.use_full_cuda_graph - and num_actual_tokens <= self.max_cudagraph_size): - # NOTE(woosuk): Setting num_splits > 1 may increase the memory - # usage, because the intermediate buffers of size [num_splits, - # num_heads, num_tokens, head_size] are allocated. Therefore, - # we only set num_splits when using cuda graphs. - max_num_splits = self.max_num_splits + if num_actual_tokens <= self.max_cudagraph_size: + # NOTE(woosuk): Setting num_splits > 1 may increase the memory + # usage, because the intermediate buffers of size [num_splits, + # num_heads, num_tokens, head_size] are allocated. Therefore, + # we only set num_splits when using cuda graphs. + max_num_splits = self.max_num_splits attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -379,7 +378,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: - # Full CUDA Graph always supported (FA2 support checked separately) + # Full CUDA Graph always supported return True def use_cascade_attention(self, *args, **kwargs) -> bool: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1eb27d57acf..5e3d6540079 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, @@ -15,13 +15,14 @@ import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) -from vllm.config import VllmConfig +from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import cdiv from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, - get_kv_cache_layout, get_per_layer_parameters, + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + PerLayerParameters, get_kv_cache_layout, get_per_layer_parameters, infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec @@ -223,13 +224,31 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): self.device = device self._workspace_buffer = None self._prefill_wrapper = 
None # Wrapper for prefill/append - self._decode_wrapper = None # Wrapper for decode + self._decode_wrapper = None # Wrapper for decode (general shape) + + compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + self.enable_cuda_graph = ( + compilation_config.cudagraph_mode == CUDAGraphMode.FULL) + if self.enable_cuda_graph: + # For full cudagraph capture, one `decode_wrapper` for each batch + # size is needed for FlashInfer. + self._decode_wrappers_cudagraph: dict[ + int, BatchDecodeWithPagedKVCacheWrapper] = {} + self._decode_cudagraph_max_bs = min( + max_num_reqs, compilation_config.max_capture_size) + self._cascade_wrapper = None # Wrapper for cascade attention # Global hyperparameters shared by all attention layers @@ -239,6 +258,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, self.cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec + # Preparing persistent buffers + self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + self.paged_kv_indices = torch.zeros( + max_num_pages, # max num pages possible + dtype=torch.int32, + device=self.device) + self.paged_kv_last_page_len = torch.zeros(max_num_reqs, + dtype=torch.int32, + device=self.device) + def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: return reorder_batch_to_split_decodes_and_prefills(input_batch, @@ -259,8 +290,16 @@ def _get_prefill_wrapper(self): self._get_workspace_buffer(), get_kv_cache_layout()) return self._prefill_wrapper - def _get_decode_wrapper(self): - if self._decode_wrapper is None: + def _get_decode_wrapper(self, + batch_size: int, + use_cudagraph: bool = False): + if use_cudagraph: + decode_wrapper = self._decode_wrappers_cudagraph.get( + batch_size, None) + else: + decode_wrapper = self._decode_wrapper + + if decode_wrapper is None: num_qo_heads = ( self.vllm_config.model_config.get_num_attention_heads( self.vllm_config.parallel_config)) @@ -268,11 +307,32 @@ def _get_decode_wrapper(self): self.vllm_config.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + + if use_cudagraph: + paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1] + paged_kv_indices = self.paged_kv_indices + paged_kv_last_page_len = self.paged_kv_last_page_len[: + batch_size] + else: + paged_kv_indptr = None + paged_kv_indices = None + paged_kv_last_page_len = None + decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), get_kv_cache_layout(), + use_cuda_graph=use_cudagraph, + paged_kv_indptr_buffer=paged_kv_indptr, + paged_kv_indices_buffer=paged_kv_indices, + paged_kv_last_page_len_buffer=paged_kv_last_page_len, use_tensor_cores=use_tensor_cores) - return self._decode_wrapper + + # save the decode wrapper + if use_cudagraph: + self._decode_wrappers_cudagraph[batch_size] = decode_wrapper + else: + self._decode_wrapper = decode_wrapper + + return decode_wrapper def _get_cascade_wrapper(self): if self._cascade_wrapper is None: @@ -350,15 +410,33 @@ def _plan(self, num_prefills: int, num_decodes: int, ) if num_decodes > 0: - attn_metadata.decode_wrapper = self._get_decode_wrapper() + pure_decode = self._num_prefills == 0 + # possible required 
padding for cudagraph replay + use_cudagraph = (self.enable_cuda_graph and pure_decode and \ + self._num_decodes <= self._decode_cudagraph_max_bs) + if use_cudagraph: + num_input_tokens_decode = ( + self.vllm_config.pad_for_cudagraph(self._num_decodes)) + else: + num_input_tokens_decode = self._num_decodes + + attn_metadata.decode_wrapper = self._get_decode_wrapper( + num_input_tokens_decode, use_cudagraph) if not FlashInferBackend.use_trtllm_decode_attention( num_decodes, attn_metadata.max_seq_len, attn_metadata.kv_data_type, attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim): + # TODO: Override flashinfer's plan function to avoid some + # host-to-device copy overhead. attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:num_decodes + 1], - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:num_decodes], + # NOTE: Use the persistent buffer with padding length, + # instead of the same address but chunked length buffers + # in the atten_metadata. This is to be compatible with + # FlashInfer's decode_wrapper when using cudagraph. + self.paged_kv_indptr[:num_input_tokens_decode + 1], + self.paged_kv_indices if use_cudagraph else \ + attn_metadata.paged_kv_indices, + self.paged_kv_last_page_len[:num_input_tokens_decode], attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, attn_metadata.head_dim, @@ -377,6 +455,7 @@ def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> FlashInferMetadata: + num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ split_decodes_and_prefills(common_attn_metadata) @@ -420,6 +499,10 @@ def build(self, device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] + num_actual_pages = paged_kv_indices.size(0) + self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices, + non_blocking=True) + self.paged_kv_indices[num_actual_pages:].fill_(-1) paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -427,10 +510,20 @@ def build(self, device=block_table_bounds.device), block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) + self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr, + non_blocking=True) + # make sure self.paged_kv_indptr is not decreasing + self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1]) paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) + self.paged_kv_last_page_len[:num_reqs].copy_(paged_kv_last_page_len, + non_blocking=True) + # Fill the remaining paged_kv_last_page_len with 1. This is because + # flashinfer treats 0 as a full page instead of empty. 
+ self.paged_kv_last_page_len[num_reqs:].fill_(1) + cache_dtype = self.cache_config.cache_dtype if cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -440,9 +533,9 @@ def build(self, attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, - paged_kv_indptr=paged_kv_indptr, - paged_kv_indices=paged_kv_indices, - paged_kv_last_page_len=paged_kv_last_page_len, + paged_kv_indptr=self.paged_kv_indptr[:1 + num_reqs], + paged_kv_indices=self.paged_kv_indices[:num_actual_pages], + paged_kv_last_page_len=self.paged_kv_last_page_len[:num_reqs], num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( self.vllm_config.parallel_config), num_kv_heads=self.kv_cache_spec.num_kv_heads, @@ -470,6 +563,31 @@ def build(self, return attn_metadata + def build_for_cudagraph_capture( + self, common_attn_metadata: CommonAttentionMetadata): + """ + This method builds the metadata for full cudagraph capture. + Currently, only decode is supported for full cudagraphs with FlashInfer. + """ + m = common_attn_metadata + + assert m.num_reqs == m.num_actual_tokens, \ + "FlashInfer only supports decode-only full CUDAGraph capture. " \ + "Make sure all cudagraph capture sizes <= max_num_seq." + + assert m.max_query_len == 1 # decode-only + + # Update state usually set in reorder_batch. + self._num_decodes = m.num_reqs + self._num_decode_tokens = m.num_actual_tokens + self._num_prefills = 0 + self._num_prefill_tokens = 0 + return self.build(0, m) + + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + return common_attn_metadata.max_query_len == 1 + def use_cascade_attention(self, *args, **kwargs) -> bool: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 93c8156b16a..d6ed98f8a13 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -581,7 +581,7 @@ def build_for_cudagraph_capture( "MLA only supports decode-only full CUDAGraph capture. " \ "Make sure all cudagraph capture sizes <= max_num_seq." 
- m.max_query_len = 1 # decode-only + assert m.max_query_len == 1 # decode-only return self.build(0, m) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 935311aacc3..8a1acc32b92 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -18,6 +18,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) @@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - full_cudagraph_supported: ClassVar[bool] = True # Decode-only + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 42a04258361..7f84786410a 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -17,6 +17,7 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.attention.backends.utils import AttentionCGSupport from vllm.v1.kv_cache_interface import AttentionSpec # yapf: enable @@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - full_cudagraph_supported: ClassVar[bool] = True # decode only + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.PURE_DECODE_ONLY def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ee95b5af6e4..4577a60f2ff 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -19,7 +19,7 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec @@ -70,7 +70,8 @@ class LocalAttentionMetadata: class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): - full_cudagraph_supported: ClassVar[bool] = True + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.ALWAYS_SEPARATE def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, device: torch.device): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index db6eaa55864..c3e3a2d0353 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import abc +import enum import functools from abc import abstractmethod from dataclasses import dataclass @@ -68,9 +69,27 @@ def __post_init__(self): M = TypeVar("M") +class AttentionCGSupport(enum.Enum): + """ Constants for the cudagraph support of the attention backend + Here we do not consider the cascade attention, as currently 
+    it is never supported by cudagraphs."""
+
+    NEVER = 0
+    """No cudagraph support."""
+    PURE_DECODE_ONLY = 1
+    """Cudagraph supported for pure-decode batches only; mixed
+    prefill-decode batches must fall back to piecewise cudagraphs
+    or run without cudagraphs."""
+    ALWAYS_UNIFIED = 2
+    """Cudagraph always supported, using a single unified routine."""
+    ALWAYS_SEPARATE = 3
+    """Cudagraph always supported, with better performance when separate
+    routines are used for mixed prefill-decode and pure decode batches."""
+
+
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # Does this backend/builder support CUDA Graphs for attention.
-    full_cudagraph_supported: ClassVar[bool] = False
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.NEVER

     @abstractmethod
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
new file mode 100644
index 00000000000..37d4e8da50e
--- /dev/null
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -0,0 +1,142 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Callable, NamedTuple, Optional
+
+from vllm.compilation.cuda_graph import CUDAGraphOptions, CUDAGraphWrapper
+from vllm.config import (CompilationLevel, CUDAGraphMode,
+                         CUDAGraphRuntimeStyle, VllmConfig)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CudagraphKey(NamedTuple):
+    """
+    Key for dispatching cudagraphs.
+    """
+    cudagraph_runtime_style: CUDAGraphRuntimeStyle
+    # Note that uniform_batch should default to None for both
+    # piecewise cudagraphs and no cudagraphs.
+    uniform_batch: Optional[bool] = None
+
+
+class CudagraphDispatcher:
+    """
+    Runtime cudagraph dispatcher to switch between multiple cudagraphs.
+    """
+
+    def __init__(self, vllm_config: VllmConfig):
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.cudagraph_mode = self.compilation_config.cudagraph_mode
+
+        self.model: Callable = None  # type: ignore
+        # self.model is lazily initialized once the runner has finished
+        # loading the model.
+
+        # Dict to store cudagraph candidates for runtime dispatching.
+        self.cudagraph_candidates: dict[CudagraphKey, Any] = {}
+
+        # Check whether piecewise compilation covering attention is enabled.
+        piecewise_compilation = not vllm_config.model_config.enforce_eager\
+            and self.compilation_config.level == CompilationLevel.PIECEWISE
+        self.piecewise_attn_compilation = piecewise_compilation and\
+            self.compilation_config.is_attention_splitting
+
+    def after_load_model(self, model: Callable):
+        # Add the original model to cudagraph_candidates for the profile run.
+        assert model is not None, "model should not be None"
+        self.model = model
+        self.cudagraph_candidates.update(
+            {CudagraphKey(CUDAGraphRuntimeStyle.NONE): self.model})
+        logger.debug("Cudagraph candidates for NONE style initialized")
+
+    def maybe_initialize_cudagraph(self, create_mixed_batch_full_cg: bool):
+        assert self.model is not None, (
+            "No model has been assigned to the cudagraph dispatcher")
+        # This should be called only after the attention backend is
+        # initialized.
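For reference, the AttentionCGSupport contract introduced above is consumed by attention backends roughly as in the following minimal sketch. This is illustrative only: `MyDecodeOnlyMetadataBuilder` is an invented name and other builder methods are omitted; `AttentionCGSupport`, `AttentionMetadataBuilder`, `AttentionSpec`, and `VllmConfig` come from the modules touched in this diff.

    from typing import ClassVar

    import torch

    from vllm.config import VllmConfig
    from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                                  AttentionMetadataBuilder)
    from vllm.v1.kv_cache_interface import AttentionSpec


    class MyDecodeOnlyMetadataBuilder(AttentionMetadataBuilder):
        # Full cudagraphs are only captured for uniform (pure-decode)
        # batches; mixed prefill-decode batches fall back to piecewise
        # cudagraphs or eager execution at runtime.
        attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.PURE_DECODE_ONLY

        def __init__(self, kv_cache_spec: AttentionSpec,
                     vllm_config: VllmConfig, device: torch.device):
            ...

        # Other abstract builder methods (e.g. metadata construction)
        # are omitted in this sketch.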
+ + if self.compilation_config.level == CompilationLevel.PIECEWISE\ + and len(self.compilation_config.splitting_ops)>0: + self.cudagraph_candidates.update( + {CudagraphKey(CUDAGraphRuntimeStyle.PIECEWISE): self.model}) + logger.debug("Piecewise cudagraph initialized") + + if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL: + # create full cudagraph for mix prefill-decode/general batches + if create_mixed_batch_full_cg: + self.cudagraph_candidates.update({ + CudagraphKey(CUDAGraphRuntimeStyle.FULL, False): + CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/mixed")) + }) + logger.debug("Full cudagraph for mixed batches initialized") + # always create full cudagraph for uniform batches if cudagraph + # separate routine is enabled. + if self.compilation_config.cudagraph_separate_routine: + self.cudagraph_candidates.update({ + CudagraphKey(CUDAGraphRuntimeStyle.FULL, True): + CUDAGraphWrapper(self.model, + self.vllm_config, + runtime_style=CUDAGraphRuntimeStyle.FULL, + cudagraph_options=CUDAGraphOptions( + usage_str="full/uniform")) + }) + logger.debug("Full cudagraph for uniform batches initialized") + + def get_cudagraph_runtime_style( + self, attn_cuda_graphs: bool) -> CUDAGraphRuntimeStyle: # noqa + + if self.cudagraph_mode == CUDAGraphMode.NONE: + return CUDAGraphRuntimeStyle.NONE + + if self.cudagraph_mode == CUDAGraphMode.PIECEWISE: + # safe to direct return as we have already checked + # CUDAGraphMode.PIECEWISE is compatible only when + # enable vllm compilation. + return CUDAGraphRuntimeStyle.PIECEWISE + + # Otherwise, for modes that enable full cudagraph. + assert self.cudagraph_mode == CUDAGraphMode.FULL + # If attention backend supports full cudagraphs for current batch, + # run with full cudagraphs. + if attn_cuda_graphs: + return CUDAGraphRuntimeStyle.FULL + + # Fall back to piecewise cudagraphs if possible + if self.piecewise_attn_compilation: + return CUDAGraphRuntimeStyle.PIECEWISE + + # Otherwise, fall back to running entirely without cudagraphs + return CUDAGraphRuntimeStyle.NONE + + def dispatch(self, cudagraph_runtime_style: CUDAGraphRuntimeStyle, + uniform_batch: bool) -> Any: + assert self.model is not None, ("No model have been assigned" + "to cudagraph dispatcher") + # if no cudagraph candidates, just skip dispatching. + if not self.cudagraph_candidates: + logger.warning_once("cudagraphs are not initialized." 
+ " No cudagraph will be used.") + return self.model + + # select between no cudagraph and piecewise cudagraph + if cudagraph_runtime_style in [ + CUDAGraphRuntimeStyle.NONE, CUDAGraphRuntimeStyle.PIECEWISE + ]: + key = CudagraphKey(cudagraph_runtime_style) + selected_model = self.cudagraph_candidates.get(key, None) + else: + # for full cudagraph, select between mixed batches + # or uniform batches + uniform_batch = uniform_batch and\ + self.compilation_config.cudagraph_separate_routine + key = CudagraphKey(cudagraph_runtime_style, uniform_batch) + selected_model = self.cudagraph_candidates.get(key, None) + assert selected_model is not None, ( + f"cudagraph_candidates is not correctly initialized for key: " + f"{key}") + return selected_model diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466..f96d1f64892 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -61,6 +61,7 @@ def load_model(self) -> None: self.model = self.load_lora_model(self.model, self.model_config, self.scheduler_config, self.lora_config, self.device) + self.cudagraph_dispatcher.after_load_model(self.model) def warming_up_model(self) -> None: logger.info("Warming up model for the compilation...") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29f519393e4..c8a322654c8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -17,7 +17,8 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter -from vllm.config import (CompilationLevel, VllmConfig, +from vllm.config import (CompilationLevel, CUDAGraphMode, + CUDAGraphRuntimeStyle, VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -44,9 +45,11 @@ GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) @@ -211,11 +214,8 @@ def __init__( is_spec_decode=bool(self.vllm_config.speculative_config), ) - self.use_cuda_graph = ( - self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and self.vllm_config.compilation_config.use_cudagraph - and not self.model_config.enforce_eager) + self.cudagraph_mode = self.compilation_config.cudagraph_mode + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -223,7 +223,7 @@ def __init__( self.cudagraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) - self.full_cuda_graph = self.compilation_config.full_cuda_graph + self.full_cuda_graph = self.cudagraph_mode == CUDAGraphMode.FULL # Cache the device properties. self._init_device_properties() @@ -313,6 +313,19 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} + # We may disable capturing cudagraph for mixed batches when + # no support (e.g., no piecewise compilation) or want only capturing + # full cudagraph for pure decode batches. + self.capture_mixed_batches = self.cudagraph_mode != CUDAGraphMode.NONE + self.no_piecewise_compilation = self.compilation_config.level != \ + CompilationLevel.PIECEWISE or self.model_config.enforce_eager + + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens + + # Cudagraph dispatcher for runtime cudagraph dispatching. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -576,7 +589,7 @@ def _prepare_inputs( scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata]]: + np.ndarray, Optional[CommonAttentionMetadata], int]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -759,7 +772,7 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, - spec_decode_common_attn_metadata) + spec_decode_common_attn_metadata, max_num_scheduled_tokens) def _compute_cascade_attn_prefix_len( self, @@ -1281,12 +1294,12 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - spec_decode_common_attn_metadata) = ( - self._prepare_inputs(scheduler_output)) + spec_decode_common_attn_metadata, + max_query_len) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if (self.use_cuda_graph + if (self.cudagraph_mode != CUDAGraphMode.NONE and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. + # Use CUDA graphs. # Add padding to the batch size. num_input_tokens = self.vllm_config.pad_for_cudagraph( num_scheduled_tokens) @@ -1346,10 +1359,14 @@ def execute_model( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - # Some attention backends only support CUDA Graphs in pure decode. - # If attention doesn't support CUDA Graphs for this batch, but we - # compiled with full CUDA graphs, we have to skip them entirely. - skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs + cudagraph_runtime_style = self.cudagraph_dispatcher.\ + get_cudagraph_runtime_style(attention_cuda_graphs) + # Note: When cudagraph_mode is FULL and + # compilation_config.cudagraph_separate_routine is True, this + # flag helps to determine the correct cudagraph routine (optimized + # for attention ops). + uniform_batch = max_query_len == self.uniform_decode_query_len and \ + num_scheduled_tokens == self.input_batch.num_reqs*max_query_len # Run the model. # Use persistent buffers for CUDA graphs. 
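To make the dispatch decision above concrete, here is a small illustrative sketch with invented batch numbers; the key resolution mirrors the logic of CudagraphDispatcher.dispatch from this diff rather than calling it directly.

    from vllm.config import CUDAGraphRuntimeStyle

    # Hypothetical step: speculative decoding with 2 draft tokens, so
    # uniform_decode_query_len = 1 + 2 = 3; four pure-decode requests,
    # each scheduled for 3 tokens.
    uniform_decode_query_len = 3
    num_reqs, max_query_len, num_scheduled_tokens = 4, 3, 12
    uniform_batch = (max_query_len == uniform_decode_query_len
                     and num_scheduled_tokens == num_reqs * max_query_len)
    assert uniform_batch  # a single 17-token prefill would break this

    # Key resolution as in CudagraphDispatcher.dispatch: NONE/PIECEWISE
    # ignore uniform_batch, while FULL honors it only when the separate
    # routine is enabled.
    def resolve_key(style, uniform_batch, separate_routine):
        if style in (CUDAGraphRuntimeStyle.NONE,
                     CUDAGraphRuntimeStyle.PIECEWISE):
            return (style, None)
        return (style, uniform_batch and separate_routine)

    assert resolve_key(CUDAGraphRuntimeStyle.FULL, uniform_batch,
                       separate_routine=True) == (
                           CUDAGraphRuntimeStyle.FULL, True)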
@@ -1358,11 +1375,13 @@ def execute_model( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - ): + cudagraph_runtime_style=cudagraph_runtime_style): self.maybe_setup_kv_connector(scheduler_output) - model_output = self.model( + model = self.cudagraph_dispatcher.dispatch(cudagraph_runtime_style, + uniform_batch) + + model_output = model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, @@ -1759,6 +1778,8 @@ def load_model(self) -> None: self.device, self.parallel_config, ) + # immediately initialize the dispatcher for profile run + self.cudagraph_dispatcher.after_load_model(self.model) def save_tensorized_model( self, @@ -1918,7 +1939,10 @@ def rand_input_ids() -> torch.Tensor: def _dummy_run( self, num_tokens: int, - capture_attn_cudagraph: bool = False, + cudagraph_runtime_style: CUDAGraphRuntimeStyle = ( + CUDAGraphRuntimeStyle.NONE), + force_attention: bool = False, + uniform_batch: bool = False, skip_eplb: bool = False, is_profile: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1927,22 +1951,53 @@ def _dummy_run( num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_tokens += num_pad + # If cudagraph_separate_routine is enabled when use full cudagraph, + # we need to manually activate the correct routine of attention backend + # for mixed prefill-decode batches and uniform decode batches + # separately during capturing. Uniform batch means that all + # requests have identical query length, except possibly a single, + # shorter dummy request in the batch (account for padding when + # cudagraph capturing). + # An example of uniform batch is common pure decode, where + # max_query_len == 1. Another case is speculative decode, + # where max_query_len == 1 + num_spec_decode_tokens. + + # When setting max_query_len = 1, we switch to and capture the optimized + # routine of FA2 for pure decode, i.e., Flashdecode + an optimization + # for GQA/MQA. + max_query_len = self.uniform_decode_query_len if uniform_batch else \ + num_tokens + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += num_tokens % num_reqs + if uniform_batch: + num_reqs = (num_tokens+self.uniform_decode_query_len-1) // \ + self.uniform_decode_query_len + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + else: + num_reqs = min(num_tokens, max_num_reqs) + min_tokens_per_req = num_tokens // num_reqs + num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None - if capture_attn_cudagraph: + + # If force_attention is True, we always capture attention. 
Otherwise, + # it depends on the cudagraph_runtime_style to be FULL or PIECEWISE. + if force_attention or cudagraph_runtime_style == \ + CUDAGraphRuntimeStyle.FULL: attn_metadata = {} # Make sure max_model_len is used at the graph capture time. @@ -1977,7 +2032,6 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2006,7 +2060,10 @@ def _dummy_run( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_style=cudagraph_runtime_style): + model = self.cudagraph_dispatcher.dispatch( + cudagraph_runtime_style, uniform_batch) outputs = model( input_ids=input_ids, positions=positions, @@ -2240,11 +2297,10 @@ def profile_run(self) -> None: gc.collect() def capture_model(self) -> None: - if not self.use_cuda_graph: + if self.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "set -O %s and ensure `use_cudagraph` was not manually set to " - "False", CompilationLevel.PIECEWISE) + "ensure `cudagraph_mode` was not manually set to `NONE`") return compilation_counter.num_gpu_runner_capture_triggers += 1 @@ -2256,24 +2312,48 @@ def capture_model(self) -> None: # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): - full_cg = self.full_cuda_graph - # Only rank 0 should print progress bar during capture - compilation_cases = reversed(self.cudagraph_batch_sizes) - if is_global_first_rank(): - compilation_cases = tqdm( - list(compilation_cases), - disable=not self.load_config.use_tqdm_on_load, - desc="Capturing CUDA graph shapes") - for num_tokens in compilation_cases: - # We skip EPLB here since we don't want to record dummy metrics - for _ in range( - self.compilation_config.cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) - self._dummy_run(num_tokens, - capture_attn_cudagraph=full_cg, - skip_eplb=True) + if self.capture_mixed_batches: + # select between full cudagraph and piecewise cudagraph + # for mixed prefill-decode batches. + attn_cuda_graphs = self.cudagraph_mode == CUDAGraphMode.FULL \ + and self.attn_metadata_builders[0].attn_cudagraph_support \ + in [AttentionCGSupport.ALWAYS_UNIFIED, + AttentionCGSupport.ALWAYS_SEPARATE] + cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL if \ + attn_cuda_graphs else CUDAGraphRuntimeStyle.PIECEWISE + + # Skip capturing batch sizes of 1 in mix prefill-decode if + # cudagraph_separate_routine is on. As bs=1 can treat as a + # uniform batch. + start_idx = 0 + if self.compilation_config.cudagraph_separate_routine \ + and len(self.cudagraph_batch_sizes) > 0 \ + and self.cudagraph_batch_sizes[0] == 1: + start_idx = 1 + compilation_cases = list( + reversed(self.cudagraph_batch_sizes[start_idx:])) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_style=cudagraph_runtime_style, + uniform_batch=False) + + if self.compilation_config.cudagraph_separate_routine: + # Capture full cudagraph for uniform batches (pure decode/ + # speculative decode). 
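As shown in the capture code around this point, the capture plan splits into a mixed prefill-decode pass and a uniform-decode pass. A small worked example with invented sizes (cudagraph_separate_routine enabled, no speculative decoding):

    # Hypothetical configuration: capture sizes [1, 2, 4, 8, 16],
    # max_num_seqs = 4, uniform_decode_query_len = 1.
    cudagraph_batch_sizes = [1, 2, 4, 8, 16]
    max_num_seqs, uniform_decode_query_len = 4, 1

    # Mixed prefill-decode capture skips bs=1, since a batch of one
    # token is already a uniform batch.
    mixed_cases = list(reversed(cudagraph_batch_sizes[1:]))

    # Uniform-decode capture only covers sizes reachable by pure decode.
    max_decode_tokens = max_num_seqs * uniform_decode_query_len
    decode_cases = list(
        reversed([x for x in cudagraph_batch_sizes
                  if x <= max_decode_tokens]))

    assert mixed_cases == [16, 8, 4, 2]
    assert decode_cases == [4, 2, 1]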
+                cudagraph_runtime_style = CUDAGraphRuntimeStyle.FULL
+
+                max_num_tokens = self.scheduler_config.max_num_seqs * \
+                    self.uniform_decode_query_len
+                decode_cudagraph_batch_sizes = [
+                    x for x in self.cudagraph_batch_sizes
+                    if x <= max_num_tokens
+                ]
+                compilation_cases_decode = list(
+                    reversed(decode_cudagraph_batch_sizes))
+                self._capture_cudagraphs(
+                    compilation_cases=compilation_cases_decode,
+                    cudagraph_runtime_style=cudagraph_runtime_style,
+                    uniform_batch=True)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -2283,6 +2363,37 @@ def capture_model(self) -> None:
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, cuda_graph_size / (1 << 30))

+    def _capture_cudagraphs(self, compilation_cases: list[int],
+                            cudagraph_runtime_style: CUDAGraphRuntimeStyle,
+                            uniform_batch: bool):
+        # Only rank 0 should print progress bar during capture
+        if is_global_first_rank():
+            compilation_cases = tqdm(
+                compilation_cases,
+                disable=not self.load_config.use_tqdm_on_load,
+                desc="Capturing CUDA graphs ({})".format(
+                    "decode" if uniform_batch else "mixed prefill-decode"))
+        # We skip EPLB here since we don't want to record dummy metrics
+        for num_tokens in compilation_cases:
+            for _ in range(self.compilation_config.cudagraph_num_of_warmups):
+                # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
+                # Note that warming up with `NONE` is orthogonal to whether
+                # we warm up attention. This differs from capture, where
+                # `FULL` implies capturing attention while `PIECEWISE`
+                # implies capturing without attention.
+                force_attention = (
+                    cudagraph_runtime_style == CUDAGraphRuntimeStyle.FULL)
+                self._dummy_run(
+                    num_tokens,
+                    cudagraph_runtime_style=CUDAGraphRuntimeStyle.NONE,
+                    force_attention=force_attention,
+                    uniform_batch=uniform_batch,
+                    skip_eplb=True)
+            self._dummy_run(num_tokens,
+                            cudagraph_runtime_style=cudagraph_runtime_style,
+                            uniform_batch=uniform_batch,
+                            skip_eplb=True)
+
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the attention backends and attention metadata builders.
@@ -2290,6 +2401,9 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         assert len(self.attn_backends) == 0 and len(
             self.attn_metadata_builders
         ) == 0, "Attention backends are already initialized"
+
+        # Record the attention cudagraph support of the first spec.
+        attn_cg: AttentionCGSupport = None  # type: ignore
         for i, kv_cache_group_spec in enumerate(
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
@@ -2325,16 +2439,104 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                 self.device,
             )

-            if (self.full_cuda_graph
-                    and not attn_metadata_builder_i.full_cudagraph_supported):
-                raise ValueError(
-                    f"Full CUDAGraph not supported for "
-                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
-                    f"full_cuda_graph or use a different attention backend.")
+            if self.cudagraph_mode == CUDAGraphMode.FULL:
+                if attn_cg is None:
+                    attn_cg = attn_metadata_builder_i.attn_cudagraph_support
+                else:
+                    if attn_cg != attn_metadata_builder_i.\
+                            attn_cudagraph_support:
+                        raise ValueError(
+                            "All attention backends must have the same "
+                            "AttentionCGSupport type when using full "
+                            "CUDAGraph. Set CompilationConfig.cudagraph_mode"
+                            " to `PIECEWISE` instead.")
+
+                if attn_cg == AttentionCGSupport.NEVER:
+                    raise ValueError(
+                        f"Full CUDAGraph for {attn_backend_i.__name__} is "
+                        f"not supported. Set CompilationConfig.cudagraph_mode "
+                        f"to `NONE` or `PIECEWISE`, or use a different "
+                        f"attention backend.")
+
+                if len(self.compilation_config.splitting_ops) == 0:
+                    assert attn_cg in [
+                        AttentionCGSupport.ALWAYS_UNIFIED,
+                        AttentionCGSupport.ALWAYS_SEPARATE,
+                    ], (f"Full CUDAGraph not supported for "
+                        f"{attn_backend_i.__name__} with "
+                        f"CompilationConfig.splitting_ops = []. "
+                        f"Set it to None (the default value) "
+                        f"or use a different attention backend.")
+
+                # Check whether the attention backend is compatible with
+                # CompilationConfig.cudagraph_separate_routine.
+                if attn_cg == AttentionCGSupport.ALWAYS_UNIFIED and \
+                        self.compilation_config.cudagraph_separate_routine:
+                    logger.warning_once(
+                        f"Full CUDAGraph support for "
+                        f"{attn_backend_i.__name__} is "
+                        f"{AttentionCGSupport.ALWAYS_UNIFIED}, which expects "
+                        f"CompilationConfig.cudagraph_separate_routine to be "
+                        f"False. Setting it to False now.")
+                    self.compilation_config.cudagraph_separate_routine = \
+                        False
+
+                if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY and \
+                        not self.compilation_config.\
+                        cudagraph_separate_routine:
+                    logger.warning_once(
+                        f"Full CUDAGraph support for "
+                        f"{attn_backend_i.__name__} is "
+                        f"{AttentionCGSupport.PURE_DECODE_ONLY}, which "
+                        f"expects CompilationConfig.cudagraph_separate_routine"
+                        f" to be True. Setting it to True now.")
+                    self.compilation_config.cudagraph_separate_routine = \
+                        True
+
+                # For AttentionCGSupport.ALWAYS_SEPARATE we do not change the
+                # cudagraph_separate_routine flag, but we inform the user that
+                # turning it on may yield better performance.
+                if attn_cg == AttentionCGSupport.ALWAYS_SEPARATE and \
+                        not self.compilation_config.\
+                        cudagraph_separate_routine:
+                    logger.warning_once(
+                        f"{attn_backend_i.__name__} generally performs better "
+                        f"when full cudagraphs are captured separately for "
+                        f"mixed prefill-decode and pure decode batches. To "
+                        f"enable this behavior, turn on "
+                        f"CompilationConfig.cudagraph_separate_routine.")
+
+                # If attn_cg is pure-decode-only and piecewise compilation is
+                # disabled, skip capturing mixed prefill-decode (general)
+                # batches.
+                if attn_cg == AttentionCGSupport.PURE_DECODE_ONLY:
+                    if self.no_piecewise_compilation:
+                        logger.warning_once(
+                            f"Skipping capture of mixed prefill-decode "
+                            f"batches, since backend "
+                            f"{attn_backend_i.__name__} supports full "
+                            f"cudagraph for pure decode only and vLLM "
+                            f"piecewise compilation is disabled.")
                        self.capture_mixed_batches = False
+                    else:
+                        assert self.compilation_config.\
+                            is_attention_splitting, (
+                            "Invalid splitting_ops for piecewise compilation "
+                            "with cudagraph_mode `FULL` for backend "
+                            f"{attn_backend_i.__name__}, which supports "
+                            "cudagraph only for pure decode. Please include "
+                            "attention ops in compilation_config."
+                            "splitting_ops.")

             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)

+        # Trigger cudagraph initialization here (after initializing the
+        # attention backends).
+        # TODO: move this to a better place.
+
+        # Whether we need to capture full cudagraphs for mixed
+        # prefill-decode batches.
+        create_mixed_batch_full_cg = attn_cg in [
+            AttentionCGSupport.ALWAYS_UNIFIED,
+            AttentionCGSupport.ALWAYS_SEPARATE] and \
+            self.capture_mixed_batches
+        self.cudagraph_dispatcher.maybe_initialize_cudagraph(
+            create_mixed_batch_full_cg)
+
     def may_reinitialize_input_batch(self,
                                      kv_cache_config: KVCacheConfig) -> None:
         """