# SPDX-License-Identifier: Apache-2.0

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set
from unittest.mock import patch

import torch
import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors


logger = init_logger(__name__)


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_npugraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    npugraph: Optional[torch.npu.NPUGraph] = None
    output: Optional[Any] = None

    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class NPUPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and aclgraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile it for each of the shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture an aclgraph for each of the shapes
        specified in `compilation_config.cudagraph_capture_sizes`.

        If a shape needs both compilation and aclgraph capture, we will
        compile it first, and then capture the aclgraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.cudagraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for the different shapes that we need to either
        # compile or capture an aclgraph for
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and is updated during the compilation process, so we need a copy
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_npugraph=shape in self.cudagraph_capture_sizes,
            )
    def check_for_ending_compilation(self):
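        """Finish compilation bookkeeping on the last piecewise graph.

        When the last graph has no remaining sizes to compile, persist the
        compiler manager's cache file and stop torch.compile monitoring.
        """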
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
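        """Dispatch to the right runnable for the current runtime shape.

        The first call always runs the general-shape compiled graph. Later
        calls look up the runtime shape and, for tracked sizes, compile a
        shape-specialized graph and/or capture and replay an aclgraph.
        """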
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

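        # dispatch on the runtime shape, read from the first
        # symbolic-shape argument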
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

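        # fall back to the general-shape graph until a shape-specialized
        # one has been compiled for this size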
        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_npugraph:
            return entry.runnable(*args)

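        # lazily capture an aclgraph for this shape after the configured
        # number of warm-up runs, then replay it on subsequent calls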
        if entry.npugraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture aclgraphs for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in debug mode.
                logger.debug("Capturing an aclgraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            npugraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward pass, we capture
                    # many pieces of aclgraphs (roughly one per layer).
                    # running gc again and again across layers would
                    # make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(npugraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's npugraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to a weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last
                        # graph will not be used by any other aclgraph.
                        output = weak_ref_tensors(output)

            # here we always use a weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.npugraph = npugraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during aclgraph capture
            return output

        if self.is_debugging_mode:
            # check that the input addresses match those seen during capture
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for aclgraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

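        # replay the captured aclgraph; the inputs are expected to occupy
        # the same memory addresses as they did during capture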
        entry.npugraph.replay()
        return entry.output