Commit 24dd986

MengqingCao authored and youkaichao committed

[Compile][Platform] Make PiecewiseBackend pluggable and extendable (vllm-project#18076)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: minpeter <kali2005611@gmail.com>

1 parent 54f60ba commit 24dd986

File tree

6 files changed: +305 -200 lines changed

vllm/compilation/backends.py

Lines changed: 6 additions & 200 deletions
@@ -6,23 +6,21 @@
 import pprint
 import time
 from collections.abc import Sequence
-from contextlib import ExitStack
 from typing import Any, Callable, Optional
-from unittest.mock import patch

 import torch
 import torch.fx as fx

 import vllm.envs as envs
 from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import weak_ref_tensors
+from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname

 from .compiler_interface import (CompilerInterface, EagerAdaptor,
                                  InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
-from .monitor import end_monitoring_torch_compile
 from .pass_manager import PostGradPassManager

 logger = init_logger(__name__)
@@ -297,7 +295,9 @@ def call_module(self, target: torch.fx.node.Target,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)

-            self.module.__dict__[target] = PiecewiseBackend(
+            piecewise_backend = resolve_obj_by_qualname(
+                current_platform.get_piecewise_backend_cls())
+            self.module.__dict__[target] = piecewise_backend(
                 submod, self.vllm_config, self.graph_pool, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_general_shape, self.vllm_backend)
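
This hunk is the heart of the change: instead of hard-coding PiecewiseBackend, the splitting pass asks the active platform for a class name via current_platform.get_piecewise_backend_cls() and turns that string into a class with resolve_obj_by_qualname. As a rough illustration of how such a qualified-name resolver typically behaves (a sketch, not vLLM's exact implementation; the helper name below is mine):

# A minimal sketch of how a "resolve object by qualified name" helper
# typically works; vLLM's actual resolve_obj_by_qualname may differ.
import importlib


def resolve_obj_by_qualname_sketch(qualname: str):
    # "pkg.module.ClassName" -> import "pkg.module", then grab "ClassName".
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Hypothetical usage mirroring the hunk above:
#   cls = resolve_obj_by_qualname_sketch("my_plugin.compilation.MyPiecewiseBackend")
#   backend = cls(submod, vllm_config, graph_pool, ...)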
@@ -341,7 +341,7 @@ def __init__(
     ):
         global global_graph_pool
         if global_graph_pool is None:
-            global_graph_pool = torch.cuda.graph_pool_handle()
+            global_graph_pool = current_platform.graph_pool_handle()

         # TODO: in the future, if we want to use multiple
         # streams, it might not be safe to share a global pool.
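
Here the global graph pool is obtained from the platform object rather than directly from torch.cuda, so non-CUDA platforms can hand back a pool handle suited to their own static-graph mechanism. A hedged sketch of what the platform-side hooks could look like; the class name and the returned qualified name are hypothetical, only the two method names come from this diff:

# Hedged sketch of platform-side hooks matching the two methods used in
# this change. The class name and returned qualname are hypothetical.
import torch


class CudaLikePlatformSketch:

    def graph_pool_handle(self):
        # On a CUDA-like platform this can simply delegate to PyTorch's
        # CUDA graph memory pool handle.
        return torch.cuda.graph_pool_handle()

    def get_piecewise_backend_cls(self) -> str:
        # Qualified name of the piecewise backend class to instantiate
        # (hypothetical path).
        return "my_plugin.compilation.MyPiecewiseBackend"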
@@ -558,197 +558,3 @@ def copy_and_call(*args):
         return self.split_gm(*list_args)

     return copy_and_call
-
-
-@dataclasses.dataclass
-class ConcreteSizeEntry:
-    runtime_shape: int
-    need_to_compile: bool  # the size is in compile_sizes
-    use_cudagraph: bool  # the size is in cudagraph_capture_sizes
-
-    compiled: bool = False
-    runnable: Callable = None  # type: ignore
-    num_finished_warmup: int = 0
-    cudagraph: Optional[torch.cuda.CUDAGraph] = None
-    output: Optional[Any] = None
-
-    # for cudagraph debugging, track the input addresses
-    # during capture, and check if they are the same during replay
-    input_addresses: Optional[list[int]] = None
-
-
-class PiecewiseBackend:
-
-    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
-                 graph_pool: Any, piecewise_compile_index: int,
-                 total_piecewise_compiles: int, sym_shape_indices: list[int],
-                 compiled_graph_for_general_shape: Callable,
-                 vllm_backend: VllmBackend):
-        """
-        The backend for piecewise compilation.
-        It mainly handles the compilation and cudagraph capturing.
-
-        We will compile `self.graph` once for the general shape,
-        and then compile for different shapes specified in
-        `compilation_config.compile_sizes`.
-
-        Independently, we will capture cudagraph for different shapes.
-
-        If a shape needs both compilation and cudagraph, we will
-        compile it first, and then capture cudagraph.
-        """
-        self.graph = graph
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
-        self.graph_pool = graph_pool
-        self.piecewise_compile_index = piecewise_compile_index
-        self.total_piecewise_compiles = total_piecewise_compiles
-        self.vllm_backend = vllm_backend
-
-        self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = (
-            piecewise_compile_index == total_piecewise_compiles - 1)
-
-        self.compile_sizes: set[int] = set(
-            self.compilation_config.compile_sizes)
-        self.cudagraph_capture_sizes: set[int] = set(
-            self.compilation_config.cudagraph_capture_sizes
-        ) if self.compilation_config.use_cudagraph else set()
-
-        self.first_run_finished = False
-
-        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
-
-        self.sym_shape_indices = sym_shape_indices
-
-        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
-
-        # the entries for different shapes that we need to either
-        # compile or capture cudagraph
-        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
-
-        # to_be_compiled_sizes tracks the remaining sizes to compile,
-        # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
-        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
-            self.concrete_size_entries[shape] = ConcreteSizeEntry(
-                runtime_shape=shape,
-                need_to_compile=shape in self.compile_sizes,
-                use_cudagraph=shape in self.cudagraph_capture_sizes,
-            )
-
-    def check_for_ending_compilation(self):
-        if self.is_last_graph and not self.to_be_compiled_sizes:
-            # no specific sizes to compile
-            # save the hash of the inductor graph for the next run
-            self.vllm_backend.compiler_manager.save_to_file()
-            end_monitoring_torch_compile(self.vllm_config)
-
-    def __call__(self, *args) -> Any:
-        if not self.first_run_finished:
-            self.first_run_finished = True
-            self.check_for_ending_compilation()
-            return self.compiled_graph_for_general_shape(*args)
-
-        runtime_shape = args[self.sym_shape_indices[0]]
-        if runtime_shape not in self.concrete_size_entries:
-            # we don't need to do anything for this shape
-            return self.compiled_graph_for_general_shape(*args)
-
-        entry = self.concrete_size_entries[runtime_shape]
-
-        if entry.runnable is None:
-            entry.runnable = self.compiled_graph_for_general_shape
-
-        if entry.need_to_compile and not entry.compiled:
-            entry.compiled = True
-            self.to_be_compiled_sizes.remove(runtime_shape)
-            # args are real arguments
-            entry.runnable = self.vllm_backend.compiler_manager.compile(
-                self.graph,
-                args,
-                self.compilation_config.inductor_compile_config,
-                self.compilation_config,
-                graph_index=self.piecewise_compile_index,
-                num_graphs=self.total_piecewise_compiles,
-                runtime_shape=runtime_shape)
-
-            # finished compilations for all required shapes
-            if self.is_last_graph and not self.to_be_compiled_sizes:
-                self.check_for_ending_compilation()
-
-        if not entry.use_cudagraph:
-            return entry.runnable(*args)
-
-        if entry.cudagraph is None:
-            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
-                entry.num_finished_warmup += 1
-                if self.is_first_graph:
-                    logger.debug(
-                        "Warming up %s/%s for shape %s",
-                        entry.num_finished_warmup,
-                        self.compilation_config.cudagraph_num_of_warmups,
-                        runtime_shape)
-                return entry.runnable(*args)
-
-            if self.is_first_graph:
-                # Since we capture cudagraph for many different shapes and
-                # capturing is fast, we don't need to log it for every shape.
-                # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
-                             runtime_shape)
-
-            input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            entry.input_addresses = input_addresses
-            cudagraph = torch.cuda.CUDAGraph()
-
-            with ExitStack() as stack:
-                if not self.is_first_graph:
-                    # during every model forward, we will capture
-                    # many pieces of cudagraphs (roughly one per layer).
-                    # running gc again and again across layers will
-                    # make the cudagraph capture very slow.
-                    # therefore, we only run gc for the first graph,
-                    # and disable gc for the rest of the graphs.
-                    stack.enter_context(patch("gc.collect", lambda: None))
-                    stack.enter_context(
-                        patch("torch.cuda.empty_cache", lambda: None))
-
-                # mind-exploding: carefully manage the reference and memory.
-                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
-                    # `output` is managed by pytorch's cudagraph pool
-                    output = entry.runnable(*args)
-                    if self.is_last_graph:
-                        # by converting it to weak ref,
-                        # the original `output` will immediately be released
-                        # to save memory. It is only safe to do this for
-                        # the last graph, because the output of the last graph
-                        # will not be used by any other cuda graph.
-                        output = weak_ref_tensors(output)
-
-            # here we always use weak ref for the output
-            # to save memory
-            entry.output = weak_ref_tensors(output)
-            entry.cudagraph = cudagraph
-
-            compilation_counter.num_cudagraph_caputured += 1
-
-            # important: we need to return the output, rather than
-            # the weak ref of the output, so that pytorch can correctly
-            # manage the memory during cuda graph capture
-            return output
-
-        if self.is_debugging_mode:
-            # check if the input addresses are the same
-            new_input_addresses = [
-                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
-            ]
-            assert new_input_addresses == entry.input_addresses, (
-                "Input addresses for cudagraphs are different during replay."
-                f" Expected {entry.input_addresses}, got {new_input_addresses}"
-            )
-
-        entry.cudagraph.replay()
-        return entry.output
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Callable, Protocol
+
+import torch.fx as fx
+
+from vllm.compilation.backends import VllmBackend
+from vllm.config import VllmConfig
+
+
+class AbstractPiecewiseBackend(Protocol):
+    """
+    PiecewiseBackend interface that allows platforms to extend
+    piecewise static graph.
+    """
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, piecewise_compile_index: int,
+                 total_piecewise_compiles: int, sym_shape_indices: list[int],
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend, **kwargs):
+        """
+        Initializes the PiecewiseBackend class with compilation and
+        execution-related configurations.
+
+        This class handles piecewise compilation, graph capturing,
+        and dispatching for specific input shapes.
+
+        Args:
+            graph (fx.GraphModule): The graph represented in fx.
+            vllm_config (VllmConfig): Global configuration for vLLM.
+            graph_pool (Any):
+                Graph memory pool handle, e.g.,
+                `torch.cuda.graph_pool_handle()`.
+            piecewise_compile_index (int):
+                Index of the current piecewise subgraph.
+            total_piecewise_compiles (int):
+                Total number of piecewise-compiled graphs.
+            sym_shape_indices (list[int]):
+                Indices of symbolic shape.
+            compiled_graph_for_general_shape (Callable):
+                Callable that executes the graph compiled for general shapes.
+            vllm_backend (VllmBackend):
+                Backend compiler that manages compilation and graph runtime
+                for vLLM.
+
+        Keyword Args:
+            kwargs: Additional keyword arguments reserved for future
+                extensions or custom platforms.
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args) -> Any:
+        """Executes the compiled graph for given input args.
+
+        If this is the first invocation, executes the general compiled graph
+        and initiates the compilation process tracking. For subsequent calls,
+        dynamically dispatches execution to either a compiled graph or a static
+        graph based on the input shape.
+
+        Args:
+            *args: Variable length input arguments to be passed into the
+                graph. The symbolic shape is expected to be in position
+                `sym_shape_indices[0]`.
+
+        Returns:
+            Any: Output of the executed graph. This can be from the general
+            compiled graph, a specialized compiled version for the given shape,
+            or a replayed static graph.
+        """
+        raise NotImplementedError
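
The Protocol above is the contract that out-of-tree platforms implement. A minimal, hedged sketch of a conforming backend that simply runs the general-shape graph and skips per-shape compilation and static-graph capture; every name in it is hypothetical:

# Hedged sketch of a minimal out-of-tree backend that satisfies the
# AbstractPiecewiseBackend protocol. All names are hypothetical; a real
# platform backend would add per-shape compilation and static-graph
# capture instead of always using the general-shape graph.
from typing import Any, Callable

import torch.fx as fx


class NoopPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: Any,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: Any, **kwargs):
        # Keep only what is needed to run the graph compiled for the
        # general (dynamic) shape.
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape

    def __call__(self, *args) -> Any:
        # Always dispatch to the general-shape graph, regardless of the
        # runtime shape found at sym_shape_indices[0].
        return self.compiled_graph_for_general_shape(*args)


# A platform would then advertise it via a qualified name (hypothetical):
#     def get_piecewise_backend_cls(self) -> str:
#         return "my_plugin.compilation.NoopPiecewiseBackend"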
