diff --git a/vllm_ascend/compilation/piecewise_backend.py b/vllm_ascend/compilation/piecewise_backend.py
new file mode 100644
index 000000000..5fdf6a908
--- /dev/null
+++ b/vllm_ascend/compilation/piecewise_backend.py
@@ -0,0 +1,213 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import dataclasses
+from contextlib import ExitStack
+from typing import Any, Callable, Dict, List, Optional, Set
+from unittest.mock import patch
+
+import torch
+import torch.fx as fx
+
+import vllm.envs as envs
+from vllm.compilation.backends import VllmBackend
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.monitor import end_monitoring_torch_compile
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.utils import weak_ref_tensors
+
+
+logger = init_logger(__name__)
+
+@dataclasses.dataclass
+class ConcreteSizeEntry:
+    runtime_shape: int
+    need_to_compile: bool  # the size is in compile_sizes
+    use_npugraph: bool  # the size is in cudagraph_capture_sizes
+
+    compiled: bool = False
+    runnable: Callable = None  # type: ignore
+    num_finished_warmup: int = 0
+    npugraph: Optional[torch.npu.NPUGraph] = None
+    output: Optional[Any] = None
+
+    # for aclgraph debugging, track the input addresses
+    # during capture, and check that they are the same during replay
+    input_addresses: Optional[List[int]] = None
+
+
+class NPUPiecewiseBackend:
+
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, piecewise_compile_index: int,
+                 total_piecewise_compiles: int, sym_shape_indices: List[int],
+                 compiled_graph_for_general_shape: Callable,
+                 vllm_backend: VllmBackend):
+        """
+        The backend for piecewise compilation.
+        It mainly handles compilation and aclgraph capturing.
+
+        We compile `self.graph` once for the general shape, and then
+        compile again for each specific shape listed in
+        `compilation_config.compile_sizes`.
+
+        Independently, we capture an aclgraph for each shape in
+        `compilation_config.cudagraph_capture_sizes`.
+
+        If a shape needs both compilation and aclgraph capture, we
+        compile it first, and then capture the aclgraph.
+ """ + self.graph = graph + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config + self.graph_pool = graph_pool + self.piecewise_compile_index = piecewise_compile_index + self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend + + self.is_first_graph = piecewise_compile_index == 0 + self.is_last_graph = ( + piecewise_compile_index == total_piecewise_compiles - 1) + + self.compile_sizes: Set[int] = set( + self.compilation_config.compile_sizes) + self.cudagraph_capture_sizes: Set[int] = set( + self.compilation_config.cudagraph_capture_sizes + ) if self.compilation_config.use_cudagraph else set() + + self.first_run_finished = False + + self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa + + self.sym_shape_indices = sym_shape_indices + + self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" + + # the entries for different shapes that we need to either + # compile or capture aclgraph + self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} + + # to_be_compiled_sizes tracks the remaining sizes to compile, + # and updates during the compilation process, so we need to copy it + self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): + self.concrete_size_entries[shape] = ConcreteSizeEntry( + runtime_shape=shape, + need_to_compile=shape in self.compile_sizes, + use_npugraph=shape in self.cudagraph_capture_sizes, + ) + + def check_for_ending_compilation(self): + if self.is_last_graph and not self.to_be_compiled_sizes: + # no specific sizes to compile + # save the hash of the inductor graph for the next run + self.vllm_backend.compiler_manager.save_to_file() + end_monitoring_torch_compile(self.vllm_config) + + def __call__(self, *args) -> Any: + if not self.first_run_finished: + self.first_run_finished = True + self.check_for_ending_compilation() + return self.compiled_graph_for_general_shape(*args) + + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.need_to_compile and not entry.compiled: + entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) + # args are real arguments + entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args, + self.compilation_config.inductor_compile_config, + self.compilation_config, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + runtime_shape=runtime_shape) + + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + self.check_for_ending_compilation() + + if not entry.use_npugraph: + return entry.runnable(*args) + + if entry.npugraph is None: + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa + entry.num_finished_warmup += 1 + if self.is_first_graph: + logger.debug( + "Warming up %s/%s for shape %s", + entry.num_finished_warmup, + self.compilation_config.cudagraph_num_of_warmups, + runtime_shape) + return entry.runnable(*args) + + if self.is_first_graph: + # Since we capture aclgraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. 
+                # We only log it in the debug mode.
+                logger.debug("Capturing an aclgraph for shape %s",
+                             runtime_shape)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            npugraph = torch.npu.NPUGraph()
+
+            with ExitStack() as stack:
+                if not self.is_first_graph:
+                    # During every model forward pass we capture many
+                    # pieces of aclgraph (roughly one per layer). Running
+                    # gc again and again across layers makes the capture
+                    # very slow, so we only run gc for the first graph
+                    # and disable it for the rest of the graphs.
+                    stack.enter_context(patch("gc.collect", lambda: None))
+                    stack.enter_context(
+                        patch("torch.npu.empty_cache", lambda: None))
+
+                # note: carefully manage the references and memory here.
+                with torch.npu.graph(npugraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's npugraph pool
+                    output = entry.runnable(*args)
+                    if self.is_last_graph:
+                        # By converting it to a weak ref, the original
+                        # `output` is released immediately to save memory.
+                        # This is only safe for the last graph, because the
+                        # output of the last graph is not used by any other
+                        # aclgraph.
+                        output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.npugraph = npugraph
+
+            compilation_counter.num_cudagraph_caputured += 1
+
+            # important: we must return `output` rather than its weak ref,
+            # so that pytorch can correctly manage the memory during
+            # aclgraph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check that the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for aclgraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.npugraph.replay()
+        return entry.output
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index fbc1dc688..a48047706 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -121,8 +121,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         enforce_eager = getattr(vllm_config.model_config, "enforce_eager",
                                 False)
 
-        # TODO(Yizhou): Override the value of enforce_eager to True before
-        # the CANN and torch_npu support NPU compilation.
+        # TODO: revert this once the aclgraph fallback is done.
         enforce_eager = True
         logger.warning(
             "NPU compilation support pending. Will be available in future CANN and "
@@ -237,3 +236,10 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
         model configuration.
         """
         return True
+
+    @classmethod
+    def get_piecewise_backend_cls(cls) -> str:
+        """
+        Get the piecewise backend class for piecewise graph execution.
+        """
+        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
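For readers unfamiliar with vLLM's piecewise graph flow, the sketch below distills the per-shape state machine that `NPUPiecewiseBackend.__call__` implements: warm a shape up a configured number of times, capture it once, then replay the capture on every later call, falling back to the general compiled graph for shapes outside the capture list. This is a hardware-free illustration only; `FakeGraph`, `Entry`, and `SketchBackend` are hypothetical stand-ins for `torch.npu.NPUGraph`, `ConcreteSizeEntry`, and the backend above, not part of this PR.

```python
import dataclasses
from typing import Any, Callable, Dict, Optional, Set


class FakeGraph:
    """Hypothetical stand-in for torch.npu.NPUGraph: record once, replay later."""

    def __init__(self) -> None:
        self.recorded: Optional[Any] = None

    def capture(self, fn: Callable[[int], Any], arg: int) -> Any:
        # Run `fn` once and remember its result, mimicking graph capture.
        self.recorded = fn(arg)
        return self.recorded

    def replay(self) -> Any:
        # Replaying returns the captured result without re-running `fn`.
        return self.recorded


@dataclasses.dataclass
class Entry:
    # Mirrors ConcreteSizeEntry: per-shape warmup count, graph, and output.
    num_finished_warmup: int = 0
    graph: Optional[FakeGraph] = None
    output: Optional[Any] = None


class SketchBackend:
    """Hypothetical miniature of NPUPiecewiseBackend's dispatch logic."""

    def __init__(self, fn: Callable[[int], Any], capture_sizes: Set[int],
                 num_warmups: int = 1) -> None:
        self.fn = fn  # plays the role of compiled_graph_for_general_shape
        self.num_warmups = num_warmups
        self.entries: Dict[int, Entry] = {s: Entry() for s in capture_sizes}

    def __call__(self, shape: int) -> Any:
        entry = self.entries.get(shape)
        if entry is None:
            # Shape is not in the capture list: always run eagerly.
            return self.fn(shape)
        if entry.graph is None:
            if entry.num_finished_warmup < self.num_warmups:
                # Eager warmup runs precede capture, as in the real backend.
                entry.num_finished_warmup += 1
                return self.fn(shape)
            # Capture exactly once per shape, then cache graph and output.
            entry.graph = FakeGraph()
            entry.output = entry.graph.capture(self.fn, shape)
            return entry.output
        # Every later call for this shape replays the captured graph.
        entry.graph.replay()
        return entry.output


backend = SketchBackend(lambda n: n * n, capture_sizes={4})
# First call warms up, second captures, third replays; 16 falls back to eager.
print([backend(s) for s in (4, 4, 4, 16)])  # -> [16, 16, 16, 256]
```

The real backend additionally compiles per-shape kernels through `compiler_manager.compile` before capture, and the `get_piecewise_backend_cls` hook added to `platform.py` is how vLLM locates `NPUPiecewiseBackend` by its dotted path.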