Commit 77a7166

[aclgraph] implement NPUPiecewiseBackend to enable aclgraph
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent e564470 commit 77a7166

File tree

2 files changed: +221 −2 lines changed
vllm_ascend/compilation/piecewise_backend.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set
from unittest.mock import patch

import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors

logger = init_logger(__name__)


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_npugraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    npugraph: Optional[torch.npu.NPUGraph] = None
    output: Optional[Any] = None

    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class NPUPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and aclgraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture an aclgraph for every shape in
        `compilation_config.cudagraph_capture_sizes`.

        If a shape needs both compilation and aclgraph capture,
        we compile it first and then capture the aclgraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture aclgraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_npugraph=shape in self.cudagraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_npugraph:
            return entry.runnable(*args)

        if entry.npugraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing an aclgraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            npugraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of aclgraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(npugraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's npugraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last
                        # graph will not be used by any other aclgraph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.npugraph = npugraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during aclgraph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for aclgraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.npugraph.replay()
        return entry.output
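
The shape sets that drive the dispatch above come from vLLM's CompilationConfig. As a rough illustration (not part of this commit; the field names follow upstream vLLM, the values are made up), a configuration like the following would make the backend compile shapes 1 and 8 specifically and capture aclgraphs for shapes 1, 2, 4, and 8:

# Illustrative sketch only: values are hypothetical; field names follow
# upstream vLLM's CompilationConfig.
from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,      # split the model into piecewise subgraphs
    use_cudagraph=True,                    # also gates aclgraph capture on NPU
    cudagraph_capture_sizes=[1, 2, 4, 8],  # shapes to capture as aclgraphs
    compile_sizes=[1, 8],                  # shapes to compile for specifically
    cudagraph_num_of_warmups=1,            # warmup runs before each capture
)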

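The input_addresses bookkeeping above exists because replaying a captured graph re-runs the recorded kernels against the device addresses seen at capture time: callers must write new data into the same buffers rather than pass freshly allocated tensors. Below is a minimal standalone sketch of that contract, assuming torch.npu.NPUGraph mirrors torch.cuda.CUDAGraph capture/replay semantics (the same assumption the backend above relies on):

import torch
import torch_npu  # noqa: F401  # registers the torch.npu device backend

# Static buffer: its device address is baked into the captured graph.
static_x = torch.randn(8, 16, device="npu")

g = torch.npu.NPUGraph()
with torch.npu.graph(g):
    # Work inside this context is recorded, not just executed.
    static_y = static_x * 2 + 1

# Replay with new data: copy into the captured buffer, never reallocate it.
static_x.copy_(torch.randn(8, 16, device="npu"))
g.replay()  # recomputes static_y in place using the recorded kernels
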
vllm_ascend/platform.py

Lines changed: 8 additions & 2 deletions
@@ -121,8 +121,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         enforce_eager = getattr(vllm_config.model_config, "enforce_eager",
                                 False)

-        # TODO(Yizhou): Override the value of enforce_eager to True before
-        # the CANN and torch_npu support NPU compilation.
+        # TODO: revert me when the fallback of aclgraph is done.
         enforce_eager = True
         logger.warning(
             "NPU compilation support pending. Will be available in future CANN and "
@@ -237,3 +236,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         model configuration.
         """
         return True
+
+    @classmethod
+    def get_piecewise_backend_cls(cls) -> str:
+        """
+        Get piecewise backend class for piecewise graph.
+        """
+        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
