
Commit 163124b

[aclgraph] implement NPUPiecewiseBackend to enable aclgraph
Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent e2a0c19 commit 163124b

File tree

7 files changed: 382 additions, 32 deletions

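As context for the change, here is a minimal sketch of how the new code path is exercised, mirroring the new test in tests/compile/test_aclgraph.py below. It assumes an Ascend NPU host with vllm-ascend and torch_npu installed and the V1 engine enabled; the model name and prompt are illustrative, not part of the commit.

import os

import torch
from vllm import LLM, SamplingParams

# aclgraph is only supported on the V1 engine
os.environ["VLLM_USE_V1"] = "1"

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)

# default path: NPUPiecewiseBackend compiles the model piecewise and
# captures/replays aclgraphs for the configured batch sizes
graph_llm = LLM("Qwen/Qwen2.5-0.5B-Instruct")
graph_outputs = graph_llm.generate(prompts, sampling_params)
del graph_llm
torch.npu.empty_cache()

# enforce_eager=True skips graph capture and runs the model eagerly
eager_llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enforce_eager=True)
eager_outputs = eager_llm.generate(prompts, sampling_params)

# greedy outputs from the two paths are expected to match
assert graph_outputs[0].outputs[0].text == eager_outputs[0].outputs[0].text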

tests/compile/test_aclgraph.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/compile/test_aclgraph.py`.
"""

import os

import pytest
import torch
from vllm import LLM, SamplingParams

from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from vllm_ascend.utils import vllm_version_is

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
    model: str,
    max_tokens: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    with monkeypatch.context() as m:
        prompts = [
            "Hello, my name is", "The president of the United States is",
            "The capital of France is", "The future of AI is"
        ]

        # aclgraph only support on v1
        m.setenv("VLLM_USE_V1", "1")

        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         temperature=0.0)
        # TODO: change to use vllmrunner when the registry of custom op is solved
        # while running pytest
        vllm_model = LLM(model)
        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

        vllm_model = LLM(model, enforce_eager=True)
        vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

    vllm_aclgraph_outputs_list = []
    for output in vllm_aclgraph_outputs:
        vllm_aclgraph_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    vllm_eager_outputs_list = []
    for output in vllm_eager_outputs:
        vllm_eager_outputs_list.append(
            (output.outputs[0].index, output.outputs[0].text))

    check_outputs_equal(
        outputs_0_lst=vllm_eager_outputs_list,
        outputs_1_lst=vllm_aclgraph_outputs_list,
        name_0="vllm_eager_outputs",
        name_1="vllm_aclgraph_outputs",
    )


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        m.setenv("VLLM_USE_V1", "1")
        with pytest.raises(NotImplementedError) as excinfo:
            VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
                       max_model_len=1024,
                       enforce_eager=False)
        assert "ACL Graph does not support deepseek" in str(excinfo.value)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def __init__(
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
         swap_space: int = 4,
-        enforce_eager: Optional[bool] = False,
+        enforce_eager: Optional[bool] = True,
         **kwargs,
     ) -> None:
         self.model = LLM(
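With the default flipped to enforce_eager=True, tests using VllmRunner run eagerly unless they opt back into graph mode. A hedged sketch of opting in, reusing only parameters that appear elsewhere in this commit (the model name and max_model_len here are illustrative values):

from tests.conftest import VllmRunner

# enforce_eager=False re-enables aclgraph capture for a single test
with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
                max_model_len=1024,
                enforce_eager=False) as vllm_model:
    vllm_model.generate_greedy(["Hello, my name is"], 32)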

tests/long_term/spec_decode/e2e/test_v1_spec_decode.py

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,7 @@ def test_ngram_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm

@@ -85,6 +85,7 @@ def test_ngram_correctness(
                 "num_speculative_tokens": 3,
             },
             max_model_len=1024,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
@@ -135,6 +136,7 @@ def test_eagle_correctness(
                 "max_model_len": 2048,
             },
             max_model_len=2048,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0

tests/singlecard/test_offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
-                    enforce_eager=False,
+                    enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/compilation/cuda_piecewise_backend.py
#

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set
from unittest.mock import patch

import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.utils import weak_ref_tensors


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_aclgraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    aclgraph: Optional[torch.npu.NPUGraph] = None
    output: Optional[Any] = None

    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class NPUPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and aclgraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture aclgraph for different shapes.

        If a shape needs both compilation and aclgraph, we will
        compile it first, and then capture aclgraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.aclgraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture aclgraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.aclgraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_aclgraph=shape in self.aclgraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_aclgraph:
            return entry.runnable(*args)

        if entry.aclgraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing a aclgraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            aclgraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of aclgraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(aclgraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's aclgraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other npu aclgraph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.aclgraph = aclgraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during npu aclgraph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for aclgraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.aclgraph.replay()
        return entry.output
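To make the dispatch policy above easier to follow, here is a conceptual, framework-free sketch of the same flow: shapes not selected for capture fall back to the general-shape callable, selected shapes are warmed up, then "captured" once and replayed on later calls. All names below are illustrative stand-ins; the real backend additionally performs per-shape compilation through vLLM's compiler manager and captures real torch.npu.NPUGraph objects, which this sketch deliberately omits.

import dataclasses
from typing import Any, Callable, Dict, Optional, Set


@dataclasses.dataclass
class _Entry:
    # simplified stand-in for ConcreteSizeEntry
    runnable: Callable
    num_finished_warmup: int = 0
    captured: Optional[Callable] = None  # stands in for a captured graph


class PiecewiseDispatcher:
    """Conceptual model of NPUPiecewiseBackend.__call__ (no compilation)."""

    def __init__(self, general: Callable, capture_sizes: Set[int],
                 num_warmups: int = 1):
        self.general = general
        self.num_warmups = num_warmups
        self.entries: Dict[int, _Entry] = {
            size: _Entry(runnable=general) for size in capture_sizes
        }

    def __call__(self, batch_size: int, *args: Any) -> Any:
        entry = self.entries.get(batch_size)
        if entry is None:
            # shape not selected for capture: always take the general path
            return self.general(batch_size, *args)
        if entry.captured is not None:
            # "replay" the work recorded for this exact shape
            return entry.captured(batch_size, *args)
        if entry.num_finished_warmup < self.num_warmups:
            # warm up before capturing, as the real backend does
            entry.num_finished_warmup += 1
            return entry.runnable(batch_size, *args)
        # "capture" once; subsequent calls for this shape replay it
        entry.captured = entry.runnable
        return entry.captured(batch_size, *args)


if __name__ == "__main__":
    dispatcher = PiecewiseDispatcher(
        general=lambda bs, xs: [x * 2 for x in xs], capture_sizes={1, 2, 4})
    print(dispatcher(2, [1, 2]))     # warmup run for shape 2
    print(dispatcher(2, [3, 4]))     # captured on this call
    print(dispatcher(2, [5, 6]))     # replayed from now on
    print(dispatcher(3, [1, 2, 3]))  # unknown shape: general path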

vllm_ascend/ops/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ def __init__(self, name=""):


 def register_dummy_fusion_op() -> None:
-    torch.cuda.CUDAGraph = torch.npu.NPUGraph
     torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
     torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
     torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
