|
| 1 | +from collections.abc import Sequence |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.fx as fx |
| 6 | +from functorch.compile import aot_function, make_boxed_func # type:ignore |
| 7 | +from torch.utils._pytree import tree_iter |
| 8 | +from torch.utils.checkpoint import CheckpointPolicy |
| 9 | +from torch_xla.experimental.stablehlo_custom_call import place_to_device, place_to_host |
| 10 | + |
| 11 | +from .remat_all import remat_all_partition_fn |
| 12 | + |
| 13 | + |
| 14 | +@torch.library.custom_op("xla::offload_name", mutates_args=()) |
| 15 | +def offload_name(t: torch.Tensor, name: str) -> torch.Tensor: |
| 16 | + """Given an input tensor, returns a named tensor for offloading selection. |
| 17 | +
|
| 18 | + `offload_name` is an identity function that associates the input |
| 19 | + tensor with `name`. It is primarily useful in conjunction with |
| 20 | + `remat_all_and_offload_these_inputs`, which will rematerialize |
| 21 | + intermediate activations and also offload inputs with the specified |
| 22 | + names to host memory, moving them back during the backward pass. |
| 23 | + """ |
| 24 | + if t is None: |
| 25 | + return None |
| 26 | + return t.clone() |
| 27 | + |
| 28 | + |
| 29 | +@offload_name.register_fake |
| 30 | +def _offload_name_fake(t: torch.Tensor, name: str) -> torch.Tensor: |
| 31 | + if t is None: |
| 32 | + return None |
| 33 | + return torch.empty_like(t) |
| 34 | + |
| 35 | + |
| 36 | +@offload_name.register_autograd |
| 37 | +def _offload_name_backward(ctx, grad): |
| 38 | + return grad, None |
| 39 | + |
| 40 | + |
| 41 | +def remat_all_and_offload_these_inputs( |
| 42 | + joint_module: fx.GraphModule, |
| 43 | + _joint_inputs, |
| 44 | + *, |
| 45 | + num_fwd_outputs, |
| 46 | + names_to_offload: Sequence[str], |
| 47 | +): |
| 48 | + """Partition the graph to rematerialize forward activations and offload |
| 49 | + forward inputs to host. |
| 50 | +
|
| 51 | + `remat_all_and_offload_these_inputs` will rematerialize (recompute) all |
| 52 | + intermediate activations in `joint_module`, and offload inputs with the |
| 53 | + specified names to host memory, moving them back during the backward pass. |
| 54 | + It transforms the joint graph into separate forward and backward graphs. |
| 55 | + """ |
| 56 | + input_device = next(iter(tree_iter(_joint_inputs))).device |
| 57 | + names_to_offload_set = set(names_to_offload) |
| 58 | + |
| 59 | + # Modify the module such that all `offload_name` tensors whose name match |
| 60 | + # `names_to_offload_set` must be saved during the forward pass. Then these |
| 61 | + # nodes will show up in the output of the `fwd` graph as additional |
| 62 | + # residuals. Later, we'll walk over the graph output to identify the nodes. |
| 63 | + for node in joint_module.graph.nodes: |
| 64 | + if ( |
| 65 | + tensor_name := _get_tensor_name_if_node_is_offload_name(node) |
| 66 | + ) and tensor_name in names_to_offload_set: |
| 67 | + # This trick is taken from https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/utils/checkpoint.py#L1290 |
| 68 | + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE |
| 69 | + |
| 70 | + fwd, bwd = remat_all_partition_fn( |
| 71 | + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs |
| 72 | + ) |
| 73 | + with torch.device(input_device): |
| 74 | + fw_example_args = _make_arguments(fwd) |
| 75 | + bw_example_args = _make_arguments(bwd) |
| 76 | + |
| 77 | + fw_name_in_output_indices = _get_offload_name_to_fw_output_indices(fwd) |
| 78 | + bw_name_in_input_names = _get_offload_name_to_bw_input_names( |
| 79 | + bwd, num_fwd_outputs, fw_name_in_output_indices |
| 80 | + ) |
| 81 | + |
| 82 | + def _debug(msg): |
| 83 | + return f""" |
| 84 | +In the forward graph: |
| 85 | +{fwd.print_readable()} |
| 86 | +
|
| 87 | +In the backward graph: |
| 88 | +{bwd.print_readable()} |
| 89 | +
|
| 90 | +{msg} |
| 91 | + """ |
| 92 | + |
| 93 | + for name in names_to_offload_set: |
| 94 | + if name not in fw_name_in_output_indices: |
| 95 | + raise ValueError( |
| 96 | + _debug( |
| 97 | + f"Did not find {name} in fw_name_in_output_indices: {fw_name_in_output_indices}." |
| 98 | + ) |
| 99 | + ) |
| 100 | + if name not in bw_name_in_input_names: |
| 101 | + raise ValueError( |
| 102 | + _debug( |
| 103 | + f"Did not find {name} in bw_name_in_input_names: {bw_name_in_input_names}." |
| 104 | + ) |
| 105 | + ) |
| 106 | + |
| 107 | + with torch.no_grad(): |
| 108 | + |
| 109 | + def forward(**kwargs): |
| 110 | + out = fwd(**kwargs) |
| 111 | + indices_to_offload = set( |
| 112 | + [fw_name_in_output_indices[name] for name in names_to_offload_set] |
| 113 | + ) |
| 114 | + return tuple( |
| 115 | + place_to_host(v) if i in indices_to_offload else v for i, v in enumerate(out) |
| 116 | + ) |
| 117 | + |
| 118 | + def backward(**kwargs): |
| 119 | + arguments_to_move_back = set( |
| 120 | + [bw_name_in_input_names[name] for name in names_to_offload_set] |
| 121 | + ) |
| 122 | + kwargs = { |
| 123 | + k: place_to_device(v) if k in arguments_to_move_back else v |
| 124 | + for k, v in kwargs.items() |
| 125 | + } |
| 126 | + return bwd(**kwargs) |
| 127 | + |
| 128 | + # Use AOTAutograd to retrace forward and backward, thus incorporating |
| 129 | + # the offloading ops. |
| 130 | + graph = [None] |
| 131 | + |
| 132 | + def get_graph(g, _): |
| 133 | + graph[0] = g |
| 134 | + return make_boxed_func(g) |
| 135 | + |
| 136 | + _ = aot_function(forward, fw_compiler=get_graph)(**fw_example_args) |
| 137 | + aot_forward = graph[0] |
| 138 | + |
| 139 | + _ = aot_function(backward, fw_compiler=get_graph)(**bw_example_args) |
| 140 | + aot_backward = graph[0] |
| 141 | + |
| 142 | + return aot_forward, aot_backward |
| 143 | + |
| 144 | + |
| 145 | +def _make_arguments(gm: fx.GraphModule): |
| 146 | + """ |
| 147 | + Given a graph module, `make_arguments` returns a dictionary of example inputs |
| 148 | + that can be used as keyword arguments to call the graph module. |
| 149 | + """ |
| 150 | + example_args = {} |
| 151 | + for node in gm.graph.nodes: |
| 152 | + if node.op != "placeholder": |
| 153 | + continue |
| 154 | + if "tensor_meta" in node.meta: |
| 155 | + tensor_meta = node.meta["tensor_meta"] |
| 156 | + tensor = torch.zeros( |
| 157 | + tensor_meta.shape, |
| 158 | + dtype=tensor_meta.dtype, |
| 159 | + requires_grad=tensor_meta.requires_grad, |
| 160 | + ) |
| 161 | + example_args[node.name] = tensor |
| 162 | + return example_args |
| 163 | + |
| 164 | + |
| 165 | +def _get_offload_name_nodes(gm: torch.fx.GraphModule): |
| 166 | + """Build a dict from `offload_name` function call nodes to their names.""" |
| 167 | + named_nodes: dict[Any, str] = {} |
| 168 | + |
| 169 | + for node in gm.graph.nodes: |
| 170 | + if tensor_name := _get_tensor_name_if_node_is_offload_name(node): |
| 171 | + named_nodes[node] = tensor_name |
| 172 | + |
| 173 | + return named_nodes |
| 174 | + |
| 175 | + |
| 176 | +def _get_offload_name_to_fw_output_indices(gm: torch.fx.GraphModule): |
| 177 | + """Given a forward graph `gm`, build a dict from tensor names to their |
| 178 | + position in the forward graph outputs.""" |
| 179 | + |
| 180 | + named_nodes = _get_offload_name_nodes(gm) |
| 181 | + res: dict[str, int] = {} |
| 182 | + |
| 183 | + for node in gm.graph.nodes: |
| 184 | + if node.op == "output": |
| 185 | + assert len(node.args) <= 1 |
| 186 | + if len(node.args) == 0: |
| 187 | + continue |
| 188 | + for i, arg in enumerate(next(iter(node.args))): # type: ignore |
| 189 | + if arg in named_nodes: |
| 190 | + res[named_nodes[arg]] = i |
| 191 | + |
| 192 | + return res |
| 193 | + |
| 194 | + |
| 195 | +def _get_offload_name_to_bw_input_names( |
| 196 | + gm: torch.fx.GraphModule, |
| 197 | + num_fwd_outputs: int, |
| 198 | + offload_name_to_output_indices: dict[str, int], |
| 199 | +): |
| 200 | + """Given a backward graph `gm`, build a dict from tensor names to their |
| 201 | + corresponding keyword argument names in the backward graph inputs.""" |
| 202 | + |
| 203 | + res = {} |
| 204 | + placeholder_idx = 0 |
| 205 | + bw_input_idx_to_name = {} |
| 206 | + for k, v in offload_name_to_output_indices.items(): |
| 207 | + bw_input_idx_to_name[v - num_fwd_outputs] = k |
| 208 | + |
| 209 | + for node in gm.graph.nodes: |
| 210 | + if node.op == "placeholder": |
| 211 | + if placeholder_idx in bw_input_idx_to_name: |
| 212 | + res[bw_input_idx_to_name[placeholder_idx]] = node.target |
| 213 | + placeholder_idx += 1 |
| 214 | + |
| 215 | + return res |
| 216 | + |
| 217 | + |
| 218 | +def _get_tensor_name_if_node_is_offload_name(node: torch.fx.Node) -> str | None: |
| 219 | + """If the node is a call to the `offload_name` function, return the `name` string argument |
| 220 | + that was used to call the function. Otherwise, return None. |
| 221 | + """ |
| 222 | + if ( |
| 223 | + node.op == "call_function" |
| 224 | + and hasattr(node.target, "name") |
| 225 | + and node.target.name() == offload_name._qualname # type: ignore |
| 226 | + ): |
| 227 | + assert isinstance(node.args[1], str) |
| 228 | + return node.args[1] |
| 229 | + |
| 230 | + return None |
0 commit comments