[aclgraph] implement NPUPiecewiseBackend to enable aclgraph #836

Merged
3 commits merged on May 29, 2025

102 changes: 102 additions & 0 deletions tests/compile/test_aclgraph.py
@@ -0,0 +1,102 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/compile/test_aclgraph.py`.
"""

import os

import pytest
import torch
from vllm import LLM, SamplingParams

from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from vllm_ascend.utils import vllm_version_is

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="aclgraph only support on v1")
@pytest.mark.skipif(
(vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
model: str,
max_tokens: int,
monkeypatch: pytest.MonkeyPatch,
) -> None:
with monkeypatch.context() as m:
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]

# aclgraph is only supported on v1
m.setenv("VLLM_USE_V1", "1")

sampling_params = SamplingParams(max_tokens=max_tokens,
temperature=0.0)
# TODO: switch to VllmRunner once the custom op registration issue
# when running under pytest is resolved
vllm_model = LLM(model)
vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()

vllm_model = LLM(model, enforce_eager=True)
vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()

vllm_aclgraph_outputs_list = []
for output in vllm_aclgraph_outputs:
vllm_aclgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))

check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_aclgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_aclgraph_outputs",
)


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="aclgraph only support on v1")
@pytest.mark.skipif(
(vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
with monkeypatch.context() as m:
m.setenv("VLLM_USE_MODELSCOPE", "True")
m.setenv("VLLM_USE_V1", "1")
with pytest.raises(NotImplementedError) as excinfo:
VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=1024,
enforce_eager=False)
assert "ACL Graph does not support deepseek" in str(excinfo.value)
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -77,7 +77,7 @@ def __init__(
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: Optional[bool] = False,
enforce_eager: Optional[bool] = True,
**kwargs,
) -> None:
self.model = LLM(
4 changes: 3 additions & 1 deletion tests/long_term/spec_decode/e2e/test_v1_spec_decode.py
@@ -72,7 +72,7 @@ def test_ngram_correctness(
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name, max_model_len=1024)
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm

@@ -85,6 +85,7 @@
"num_speculative_tokens": 3,
},
max_model_len=1024,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
@@ -135,6 +136,7 @@
"max_model_len": 2048,
},
max_model_len=2048,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
6 changes: 4 additions & 2 deletions tests/multicard/test_dynamic_npugraph_batchsize.py
@@ -18,8 +18,7 @@
import torch
from vllm import LLM, SamplingParams

# TODO: revert me when cuda hard code is fixed in 'VllmBackend'
torch.cuda.CUDAGraph = torch.npu.NPUGraph
from vllm_ascend.utils import vllm_version_is

MODELS = [
"Qwen/Qwen2.5-0.5B-Instruct",
@@ -33,6 +32,9 @@
]


@pytest.mark.skipif(
(vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("max_tokens", [64])
2 changes: 1 addition & 1 deletion tests/singlecard/test_offline_inference.py
@@ -52,7 +52,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=False,
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

226 changes: 226 additions & 0 deletions vllm_ascend/compilation/piecewise_backend.py
@@ -0,0 +1,226 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/compilation/cuda_piecewise_backend.py
#

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set
from unittest.mock import patch

import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.utils import weak_ref_tensors


@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
need_to_compile: bool # the size is in compile_sizes
use_aclgraph: bool # the size is in cudagraph_capture_sizes

compiled: bool = False
runnable: Callable = None # type: ignore
num_finished_warmup: int = 0
aclgraph: Optional[torch.npu.NPUGraph] = None
output: Optional[Any] = None

# for aclgraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: Optional[List[int]] = None


class NPUPiecewiseBackend:

def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
graph_pool: Any, piecewise_compile_index: int,
total_piecewise_compiles: int, sym_shape_indices: List[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend):
"""
The backend for piecewise compilation.
It mainly handles the compilation and aclgraph capturing.

We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.

Independently, we will capture aclgraph for different shapes.

If a shape needs both compilation and aclgraph, we will
compile it first, and then capture aclgraph.
"""
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend

self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = (
piecewise_compile_index == total_piecewise_compiles - 1)

self.compile_sizes: Set[int] = set(
self.compilation_config.compile_sizes)
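# aclgraph capture sizes are read from vLLM's CUDA-named config fields
# (use_cudagraph / cudagraph_capture_sizes)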
self.aclgraph_capture_sizes: Set[int] = set(
self.compilation_config.cudagraph_capture_sizes
) if self.compilation_config.use_cudagraph else set()

self.first_run_finished = False

self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa

self.sym_shape_indices = sym_shape_indices

self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

# the entries for different shapes that we need to either
# compile or capture aclgraph
self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

# to_be_compiled_sizes tracks the remaining sizes to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
for shape in self.compile_sizes.union(self.aclgraph_capture_sizes):
self.concrete_size_entries[shape] = ConcreteSizeEntry(
runtime_shape=shape,
need_to_compile=shape in self.compile_sizes,
use_aclgraph=shape in self.aclgraph_capture_sizes,
)

def check_for_ending_compilation(self):
if self.is_last_graph and not self.to_be_compiled_sizes:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
self.first_run_finished = True
self.check_for_ending_compilation()
return self.compiled_graph_for_general_shape(*args)

runtime_shape = args[self.sym_shape_indices[0]]
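# dispatch on the runtime shape taken from the first symbolic-shape argument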
if runtime_shape not in self.concrete_size_entries:
# we don't need to do anything for this shape
return self.compiled_graph_for_general_shape(*args)

entry = self.concrete_size_entries[runtime_shape]

if entry.runnable is None:
entry.runnable = self.compiled_graph_for_general_shape

if entry.need_to_compile and not entry.compiled:
entry.compiled = True
self.to_be_compiled_sizes.remove(runtime_shape)
# args are real arguments
entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args,
self.compilation_config.inductor_compile_config,
self.compilation_config,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
runtime_shape=runtime_shape)

# finished compilations for all required shapes
if self.is_last_graph and not self.to_be_compiled_sizes:
self.check_for_ending_compilation()

if not entry.use_aclgraph:
return entry.runnable(*args)

if entry.aclgraph is None:
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
entry.num_finished_warmup += 1
if self.is_first_graph:
logger.debug(
"Warming up %s/%s for shape %s",
entry.num_finished_warmup,
self.compilation_config.cudagraph_num_of_warmups,
runtime_shape)
return entry.runnable(*args)

if self.is_first_graph:
# Since we capture aclgraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a aclgraph for shape %s",
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
aclgraph = torch.npu.NPUGraph()

with ExitStack() as stack:
if not self.is_first_graph:
# during every model forward, we will capture
# many pieces of aclgraphs (roughly one per layer).
# running gc again and again across layers will
# make the aclgraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(
patch("torch.npu.empty_cache", lambda: None))

# mind-exploding: carefully manage the reference and memory.
with torch.npu.graph(aclgraph, pool=self.graph_pool):
# `output` is managed by pytorch's aclgraph pool
output = entry.runnable(*args)
if self.is_last_graph:
# by converting it to weak ref,
# the original `output` will immediately be released
# to save memory. It is only safe to do this for
# the last graph, because the output of the last graph
# will not be used by any other npu aclgraph.
output = weak_ref_tensors(output)

# here we always use weak ref for the output
# to save memory
entry.output = weak_ref_tensors(output)
entry.aclgraph = aclgraph

compilation_counter.num_cudagraph_caputured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
# manage the memory during npu aclgraph capture
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
"Input addresses for aclgraphs are different during replay."
f" Expected {entry.input_addresses}, got {new_input_addresses}"
)

entry.aclgraph.replay()
return entry.output
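
A note on the capture/replay flow implemented above: for each registered shape, __call__ first runs the entry eagerly for cudagraph_num_of_warmups iterations, then captures a torch.npu.NPUGraph while keeping only weak references to the outputs, and afterwards serves that shape by replaying the captured graph against the original input buffers (hence the debug-mode assertion that input addresses are unchanged). The snippet below is a minimal capture/replay sketch of that mechanism in isolation, assuming an Ascend environment where torch_npu provides the torch.npu.NPUGraph and torch.npu.graph APIs used in this file; it is an illustration only, not part of the backend, and it omits the per-shape warmup the backend performs before capture.

import torch
import torch_npu  # noqa: F401  # assumption: torch_npu registers the torch.npu namespace

# static input buffer: replay reuses the captured device pointers,
# so the input must be updated in place rather than replaced
static_x = torch.randn(8, 16, device="npu")

graph = torch.npu.NPUGraph()
with torch.npu.graph(graph):
    # the output tensor is allocated from the graph's memory pool
    static_y = static_x * 2 + 1

# refresh the input in place, then replay the captured kernels
static_x.copy_(torch.randn(8, 16, device="npu"))
graph.replay()
print(static_y)  # reflects the refreshed static_x contents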