[Compile][Platform] Make PiecewiseBackend pluggable and extendable #18076

Merged

merged 1 commit on May 22, 2025
206 changes: 6 additions & 200 deletions vllm/compilation/backends.py
@@ -6,23 +6,21 @@
import pprint
import time
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname

from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor)
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
from .pass_manager import PostGradPassManager

logger = init_logger(__name__)
@@ -297,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target,
num_graphs=len(self.compile_submod_names),
runtime_shape=None)

self.module.__dict__[target] = PiecewiseBackend(
piecewise_backend = resolve_obj_by_qualname(
current_platform.get_piecewise_backend_cls())
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend)
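
For context on this hunk: the backend class is no longer hard-coded; the current platform returns a fully qualified class name, which `resolve_obj_by_qualname` turns into the class instantiated here. A minimal sketch of what an out-of-tree platform hook might look like (the plugin module path and class names below are hypothetical, and the exact placement of the hook on the Platform class is an assumption):

```python
# Hypothetical out-of-tree platform hook; in vLLM proper this would live on a
# Platform subclass. The returned dotted path is resolved by
# vllm.utils.resolve_obj_by_qualname at the call site shown above.
class MyAcceleratorPlatform:

    @classmethod
    def get_piecewise_backend_cls(cls) -> str:
        # Must name a class implementing the AbstractPiecewiseBackend
        # protocol (see vllm/compilation/base_piecewise_backend.py below).
        return "my_plugin.compilation.MyPiecewiseBackend"
```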
@@ -341,7 +341,7 @@ def __init__(
):
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
global_graph_pool = current_platform.graph_pool_handle()

# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
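
The graph memory pool is obtained through the same platform layer. A sketch of a CUDA-like platform hook that mirrors the line this hunk replaces (its placement and signature are assumptions):

```python
import torch


class MyCudaLikePlatform:

    @classmethod
    def graph_pool_handle(cls):
        # Mirrors the previous hard-coded call in backends.py: a shared
        # memory pool handle used when capturing CUDA graphs.
        return torch.cuda.graph_pool_handle()
```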
@@ -558,197 +558,3 @@ def copy_and_call(*args):
return self.split_gm(*list_args)

return copy_and_call


@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_cudagraph: bool # the size is in cudagraph_capture_sizes

compiled: bool = False
runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
cudagraph: Optional[torch.cuda.CUDAGraph] = None
output: Optional[Any] = None

# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[list[int]] = None


class PiecewiseBackend:

def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and cudagraph capturing.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
Independently, we will capture cudagraph for different shapes.
If a shape needs both compilation and cudagraph, we will
compile it first, and then capture cudagraph.
"""
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend

self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)

self.compile_sizes: set[int] = set(
self.compilation_config.compile_sizes)
self.cudagraph_capture_sizes: set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()

self.first_run_finished = False

self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa

self.sym_shape_indices = sym_shape_indices

self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

# the entries for different shapes that we need to either
# compile or capture cudagraph
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}

# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_cudagraph=shape in self.cudagraph_capture_sizes,
)

def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)

runtime_shape = args[self.sym_shape_indices[0]]
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)

entry = self.concrete_size_entries[runtime_shape]

if entry.runnable is None:
entry.runnable = self.compiled_graph_for_general_shape

if entry.need_to_compile and not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape)

# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()

if not entry.use_cudagraph:
return entry.runnable(*args)

if entry.cudagraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)

if self.is_first_graph:
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of cudagraphs (roughly one per layer).
# running gc again and again across layers will
# make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.cuda.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
# `output` is managed by pytorch's cudagraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other cuda graph.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

compilation_counter.num_cudagraph_caputured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during cuda graph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for cudagraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)

entry.cudagraph.replay()
return entry.output
71 changes: 71 additions & 0 deletions vllm/compilation/base_piecewise_backend.py
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Protocol

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig


class AbstractPiecewiseBackend(Protocol):
"""
PiecewiseBackend interface that allows platforms to extend
piecewise static graph.
"""

def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend, **kwargs):
"""
Initializes the PiecewiseBackend class with compilation and
execution-related configurations.
This class handles piecewise compilation, graph capturing,
and dispatching for specific input shapes.
Args:
graph (fx.GraphModule): The graph represented in fx.
vllm_config (VllmConfig): Global configuration for vLLM.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
piecewise_compile_index (int):
Index of the current piecewise subgraph.
total_piecewise_compiles (int):
Total number of piecewise-compiled graphs.
sym_shape_indices (list[int]):
Indices of symbolic shape.
compiled_graph_for_general_shape (Callable):
Callable that executes the graph compiled for general shapes.
vllm_backend (VllmBackend):
Backend compiler that manages compilation and graph runtime
for vLLM.
Keyword Args:
kwargs: Additional keyword arguments reserved for future
extensions or custom platforms.
"""
raise NotImplementedError

def __call__(self, *args) -> Any:
"""Executes the compiled graph for given input args.
If this is the first invocation, executes the general compiled graph
and initiates the compilation process tracking. For subsequent calls,
dynamically dispatches execution to either a compiled graph or a static
graph based on the input shape.
Args:
*args: Variable length input arguments to be passed into the
graph. The symbolic shape is expected to be in position
`sym_shape_indices[0]`.
Returns:
Any: Output of the executed graph. This can be from the general
compiled graph, a specialized compiled version for the given shape,
or a replayed static graph.
"""
raise NotImplementedError
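
To illustrate the contract, here is a deliberately minimal, protocol-conforming sketch (not part of this PR) that skips per-shape compilation and static graph capture and always runs the general-shape compiled graph:

```python
# SPDX-License-Identifier: Apache-2.0
# Minimal sketch of a class satisfying AbstractPiecewiseBackend: no per-shape
# compilation and no static graph capture, only the general-shape graph.
from typing import Any, Callable

import torch.fx as fx

from vllm.compilation.backends import VllmBackend
from vllm.config import VllmConfig


class NoopPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend, **kwargs):
        # Keep only what __call__ needs; the remaining arguments are accepted
        # solely to match the protocol's constructor signature.
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape

    def __call__(self, *args) -> Any:
        # Every input shape falls back to the graph compiled for general
        # shapes; a real backend would dispatch per shape here.
        return self.compiled_graph_for_general_shape(*args)
```

A real platform backend would additionally compile entries for `compile_sizes` and capture/replay static graphs for `cudagraph_capture_sizes`, as the implementation removed from backends.py above does.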