Commit 64a6e1a
Introduce two FX graph partitioners: remat_all and remat_all_and_offload_these_inputs (#151)
* Introduce two FX graph partitioners: remat_all and remat_all_and_offload_these_inputs. The main purpose is to be able to add activation checkpointing and host offloading of layer inputs to scan, which will be added later. Also add unit tests verifying that they work as intended.
* Add test for offloading
* Address some comments
* Address comments more
* Address comments x3
1 parent 3c59038 commit 64a6e1a

File tree

4 files changed: +664 −0 lines changed
Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
from collections.abc import Sequence
from typing import Any

import torch
import torch.fx as fx
from functorch.compile import aot_function, make_boxed_func  # type:ignore
from torch.utils._pytree import tree_iter
from torch.utils.checkpoint import CheckpointPolicy
from torch_xla.experimental.stablehlo_custom_call import place_to_device, place_to_host

from .remat_all import remat_all_partition_fn


@torch.library.custom_op("xla::offload_name", mutates_args=())
def offload_name(t: torch.Tensor, name: str) -> torch.Tensor:
  """Given an input tensor, returns a named tensor for offloading selection.

  `offload_name` is an identity function that associates the input
  tensor with `name`. It is primarily useful in conjunction with
  `remat_all_and_offload_these_inputs`, which will rematerialize
  intermediate activations and also offload inputs with the specified
  names to host memory, moving them back during the backward pass.
  """
  if t is None:
    return None
  return t.clone()


@offload_name.register_fake
def _offload_name_fake(t: torch.Tensor, name: str) -> torch.Tensor:
  if t is None:
    return None
  return torch.empty_like(t)


@offload_name.register_autograd
def _offload_name_backward(ctx, grad):
  return grad, None
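
As a usage sketch (hypothetical, not part of this commit: the `OffloadedLayer` module and the name "layer_input" are illustrative), a model opts a tensor into offloading by routing it through `offload_name`; since the op is an identity, the model's numerics are unchanged:

import torch
import torch.nn as nn


class OffloadedLayer(nn.Module):
  """Hypothetical layer whose input should be offloaded to host memory."""

  def __init__(self, dim: int):
    super().__init__()
    self.ff = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Tag the input. The name is only consumed later by
    # `remat_all_and_offload_these_inputs(..., names_to_offload=["layer_input"])`.
    x = offload_name(x, "layer_input")
    return self.ff(x)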
def remat_all_and_offload_these_inputs(
    joint_module: fx.GraphModule,
    _joint_inputs,
    *,
    num_fwd_outputs,
    names_to_offload: Sequence[str],
):
  """Partition the graph to rematerialize forward activations and offload
  forward inputs to host.

  `remat_all_and_offload_these_inputs` will rematerialize (recompute) all
  intermediate activations in `joint_module`, and offload inputs with the
  specified names to host memory, moving them back during the backward pass.
  It transforms the joint graph into separate forward and backward graphs.
  """
  input_device = next(iter(tree_iter(_joint_inputs))).device
  names_to_offload_set = set(names_to_offload)

  # Modify the module such that all `offload_name` tensors whose name matches
  # `names_to_offload_set` must be saved during the forward pass. Then these
  # nodes will show up in the output of the `fwd` graph as additional
  # residuals. Later, we'll walk over the graph output to identify the nodes.
  for node in joint_module.graph.nodes:
    if (
        tensor_name := _get_tensor_name_if_node_is_offload_name(node)
    ) and tensor_name in names_to_offload_set:
      # This trick is taken from https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/utils/checkpoint.py#L1290
      node.meta["recompute"] = CheckpointPolicy.MUST_SAVE

  fwd, bwd = remat_all_partition_fn(
      joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
  )
  with torch.device(input_device):
    fw_example_args = _make_arguments(fwd)
    bw_example_args = _make_arguments(bwd)

  fw_name_in_output_indices = _get_offload_name_to_fw_output_indices(fwd)
  bw_name_in_input_names = _get_offload_name_to_bw_input_names(
      bwd, num_fwd_outputs, fw_name_in_output_indices
  )

  def _debug(msg):
    return f"""
In the forward graph:
{fwd.print_readable()}

In the backward graph:
{bwd.print_readable()}

{msg}
"""

  for name in names_to_offload_set:
    if name not in fw_name_in_output_indices:
      raise ValueError(
          _debug(
              f"Did not find {name} in fw_name_in_output_indices: {fw_name_in_output_indices}."
          )
      )
    if name not in bw_name_in_input_names:
      raise ValueError(
          _debug(
              f"Did not find {name} in bw_name_in_input_names: {bw_name_in_input_names}."
          )
      )

  with torch.no_grad():

    def forward(**kwargs):
      out = fwd(**kwargs)
      indices_to_offload = set(
          [fw_name_in_output_indices[name] for name in names_to_offload_set]
      )
      return tuple(
          place_to_host(v) if i in indices_to_offload else v for i, v in enumerate(out)
      )

    def backward(**kwargs):
      arguments_to_move_back = set(
          [bw_name_in_input_names[name] for name in names_to_offload_set]
      )
      kwargs = {
          k: place_to_device(v) if k in arguments_to_move_back else v
          for k, v in kwargs.items()
      }
      return bwd(**kwargs)

    # Use AOTAutograd to retrace forward and backward, thus incorporating
    # the offloading ops.
    graph = [None]

    def get_graph(g, _):
      graph[0] = g
      return make_boxed_func(g)

    _ = aot_function(forward, fw_compiler=get_graph)(**fw_example_args)
    aot_forward = graph[0]

    _ = aot_function(backward, fw_compiler=get_graph)(**bw_example_args)
    aot_backward = graph[0]

    return aot_forward, aot_backward
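
For context, a partitioner with this signature plugs into AOTAutograd through the `partition_fn` argument of `functorch.compile.aot_function`, with `names_to_offload` bound ahead of time. A minimal sketch, assuming an XLA device is available (`loss_fn`, the identity compiler, and the name "layer_input" are illustrative, not from this commit):

import functools

import torch
import torch_xla.core.xla_model as xm
from functorch.compile import aot_function, make_boxed_func


def loss_fn(x: torch.Tensor) -> torch.Tensor:
  # Tag `x` so the partitioner offloads it between forward and backward.
  x = offload_name(x, "layer_input")
  return torch.sin(x).sum()


def noop_compiler(gm, _example_inputs):
  # Identity "compiler": run the partitioned FX graphs unmodified.
  return make_boxed_func(gm)


compiled = aot_function(
    loss_fn,
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=functools.partial(
        remat_all_and_offload_these_inputs, names_to_offload=["layer_input"]
    ),
)

# The offloading ops are XLA custom calls, so inputs must live on an XLA device.
x = torch.randn(8, 8, requires_grad=True, device=xm.xla_device())
compiled(x).backward()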
def _make_arguments(gm: fx.GraphModule):
  """
  Given a graph module, `_make_arguments` returns a dictionary of example inputs
  that can be used as keyword arguments to call the graph module.
  """
  example_args = {}
  for node in gm.graph.nodes:
    if node.op != "placeholder":
      continue
    if "tensor_meta" in node.meta:
      tensor_meta = node.meta["tensor_meta"]
      tensor = torch.zeros(
          tensor_meta.shape,
          dtype=tensor_meta.dtype,
          requires_grad=tensor_meta.requires_grad,
      )
      example_args[node.name] = tensor
  return example_args


def _get_offload_name_nodes(gm: torch.fx.GraphModule):
  """Build a dict from `offload_name` function call nodes to their names."""
  named_nodes: dict[Any, str] = {}

  for node in gm.graph.nodes:
    if tensor_name := _get_tensor_name_if_node_is_offload_name(node):
      named_nodes[node] = tensor_name

  return named_nodes


def _get_offload_name_to_fw_output_indices(gm: torch.fx.GraphModule):
  """Given a forward graph `gm`, build a dict from tensor names to their
  position in the forward graph outputs."""

  named_nodes = _get_offload_name_nodes(gm)
  res: dict[str, int] = {}

  for node in gm.graph.nodes:
    if node.op == "output":
      assert len(node.args) <= 1
      if len(node.args) == 0:
        continue
      for i, arg in enumerate(next(iter(node.args))):  # type: ignore
        if arg in named_nodes:
          res[named_nodes[arg]] = i

  return res


def _get_offload_name_to_bw_input_names(
    gm: torch.fx.GraphModule,
    num_fwd_outputs: int,
    offload_name_to_output_indices: dict[str, int],
):
  """Given a backward graph `gm`, build a dict from tensor names to their
  corresponding keyword argument names in the backward graph inputs."""

  res = {}
  placeholder_idx = 0
  # The forward graph outputs `num_fwd_outputs` user outputs followed by the
  # saved residuals, and the backward graph takes those residuals as its first
  # placeholders. Hence a residual at forward output index `v` shows up as
  # backward input index `v - num_fwd_outputs`.
  bw_input_idx_to_name = {}
  for k, v in offload_name_to_output_indices.items():
    bw_input_idx_to_name[v - num_fwd_outputs] = k

  for node in gm.graph.nodes:
    if node.op == "placeholder":
      if placeholder_idx in bw_input_idx_to_name:
        res[bw_input_idx_to_name[placeholder_idx]] = node.target
      placeholder_idx += 1

  return res


def _get_tensor_name_if_node_is_offload_name(node: torch.fx.Node) -> str | None:
  """If the node is a call to the `offload_name` function, return the `name` string
  argument that was used to call the function. Otherwise, return None.
  """
  if (
      node.op == "call_function"
      and hasattr(node.target, "name")
      and node.target.name() == offload_name._qualname  # type: ignore
  ):
    assert isinstance(node.args[1], str)
    return node.args[1]

  return None
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import torch._functorch.config
import torch.fx
from functorch.compile import min_cut_rematerialization_partition
from torch.utils.checkpoint import CheckpointPolicy


def remat_all_partition_fn(
    joint_module: torch.fx.GraphModule,
    _joint_inputs,
    *,
    num_fwd_outputs,
):
  """
  remat_all_partition_fn is a graph partition function that closely matches the
  default behavior of `torch.utils.checkpoint`, which is to discard all intermediate
  activations and recompute all of them during the backward pass.
  """
  # Mark anything that does not have a policy as MUST_RECOMPUTE
  for node in joint_module.graph.nodes:
    if _is_call(node) and "recompute" not in node.meta:
      # This trick is taken from https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/utils/checkpoint.py#L1290
      node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE

      # min_cut_rematerialization_partition checks the graph ID to handle multiple
      # graphs at once. We only have one graph so this can simply be 0.
      node.meta["ac_graph_id"] = 0

  return min_cut_rematerialization_partition(
      joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
  )
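
A usage sketch: since `remat_all_partition_fn` follows the standard AOTAutograd partitioner signature, it can be handed to `functorch.compile.aot_function` directly to get checkpoint-everything behavior without wrapping the model in `torch.utils.checkpoint` (the function and identity compiler below are illustrative):

import torch
from functorch.compile import aot_function, make_boxed_func


def noop_compiler(gm, _example_inputs):
  # Identity "compiler": run the partitioned FX graphs unmodified.
  return make_boxed_func(gm)


def f(x):
  return torch.sin(torch.cos(x)).sum()


# With remat_all_partition_fn, intermediates such as cos(x) are discarded
# after the forward pass and recomputed during the backward pass.
f_remat = aot_function(
    f,
    fw_compiler=noop_compiler,
    bw_compiler=noop_compiler,
    partition_fn=remat_all_partition_fn,
)
f_remat(torch.randn(4, requires_grad=True)).backward()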
def _is_call(node: torch.fx.Node):
  # See documentation here: https://pytorch.org/docs/stable/fx.html
  match node.op:
    case "call_function" | "call_method" | "call_module":
      return True
    case _:
      return False
