Commit 1263faf

[export] Add a helper that lets us export multiple entrypoints from an nn.Module (#511)
I really hope the `torch.export` folks consider this a real use case to solve for, but for now it seems like it is us and the ExecuTorch folks doing gross hacks like this to support exporting multiple functions from a module stack. I went the extra mile a bit and implemented save/load on these stacks, which gives us the ability to do FX tracing once and then iterate on the result to compile. That should be an easier dev workflow, considering that Dynamo is so slow for real models. This is all bleeding-edge stuff on Torch, but I need it for LLMs, so I will deal with the sharp edges.
1 parent 828a47c commit 1263faf
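For orientation, here is a minimal usage sketch assembled from the `FxProgramsBuilder` docstring and the save API in the diff below. The module, example inputs, entrypoint names, and the `/tmp` path are illustrative placeholders, not part of this commit, and this assumes a torch version with `torch.export`:

```python
import torch
import torch.nn as nn

from shark_turbine.aot import FxProgramsBuilder


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


example_args = (torch.randn(4, 16),)
fxb = FxProgramsBuilder(MyModule())


# Each decorated free function becomes a named ExportedProgram; the first
# argument is always the shared root module.
@fxb.export_program(args=example_args)
def forward(m, x):
    return m.forward(x)


@fxb.export_program(args=example_args)
def forward_twice(m, x):
    return m.forward(m.forward(x))


# Persist the whole set; duplicate state/constants across entrypoints are
# deduped on disk, and the return value is the dedupe count.
num_deduped = fxb.save("/tmp/my_module.json")
```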

3 files changed: +396 −0 lines changed

core/shark_turbine/aot/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -8,3 +8,4 @@
from .exporter import export

from .builtins import *
+from .fx_programs import FxPrograms, FxProgramsBuilder

core/shark_turbine/aot/fx_programs.py

Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Helper classes for assembling sets of FX modules that can be compiled.

This uses the `torch.export` machinery. However, it provides some extra
services for handling multiple modules, save/load, and state management.
"""

import json
import os
from pathlib import Path
from typing import Any, Optional, Union

import functools

import torch
import torch.nn as nn

# The dynamic_shapes support showed up in the Torch 2.3 timeframe.
_supports_dynamic_shapes = hasattr(torch.export, "Dim")

class FxPrograms:
    """Represents a named set of ExportedPrograms.

    This facility works around a design flaw in Torch where they conflated
    ExportedPrograms as representing a single entry-point while also having
    each instance persist its own state_dict and constants. How many times,
    in how many frameworks, do we have to fight this design flaw? Apparently
    once more.

    This base class represents the set of programs, either loaded from storage
    or built live. The tricky part is doing all of this while aliasing state
    and captured constants. Having those be physically shared is an essential
    optimization.

    In order to manage saving/loading of the set of things, we manually splice
    the state_dict and constants dict such that while saving, we only persist
    the first encountered instance of any reference. Any subsequent instances
    are replaced with a SharedStateTensor, which on load can be re-associated.

    As this is primarily targeted at being able to decouple FX tracing from
    further manipulation (which, for reasons unknown, is competing with the
    race of entropy to the heat death of the universe in terms of performance),
    we don't take a lot of pains to be optimized for distribution or storage of
    the resulting artifacts.

    In the future, this same technique could be employed to elide parameters
    that we know we are going to resolve symbolically later, keeping them from
    being loaded and consuming memory during model export and compilation.

    We have faith that, in the fullness of time, the design flaws in Torch that
    require this kind of thing to exist will be resolved, and we then won't
    need this hack.
    """

    def __init__(self):
        self.programs: dict[str, torch.export.ExportedProgram] = {}

    def save(self, path: Union[str, os.PathLike]) -> int:
        """Saves the set of exported programs to a descriptor file.

        Returns the number of tensors deduped (for debugging/testing).
        """
        path = Path(path).resolve()

        def permute_path(name):
            return path.parent / f"{path.stem}_{name}.pt2"

        # Assemble descriptor.
        program_files = {name: str(permute_path(name)) for name in self.programs.keys()}
        descriptor = {
            "load_order": list(program_files.keys()),
            "program_files": program_files,
        }

        # Accumulate shared state as we go.
        shared_state_dict: dict[str, Any] = {}
        shared_constants: dict[str, Any] = {}
        count_deduped = 0

        # Save each.
        for program_name, ep in self.programs.items():
            # First validate the ep with normal rules, which we will then
            # disable since we are violating the spec.
            ep._validate()
            orig_state_dict = dict(ep.state_dict)
            constants_dict = _get_optional_constants(ep)
            orig_constants = dict(constants_dict)

            try:
                # Now unmerge the state_dict and constants by knocking it up against
                # our running shared state dict.
                count_deduped += _sharify_state_dict(shared_state_dict, ep.state_dict)
                count_deduped += _sharify_state_dict(shared_constants, constants_dict)

                # And save our hacked program.
                save_path = program_files[program_name]
                torch.export.save(ep, save_path)
            finally:
                ep.state_dict.clear()
                ep.state_dict.update(orig_state_dict)
                constants_dict.clear()
                constants_dict.update(orig_constants)

        # Save the descriptor.
        with open(path, "wt") as f:
            json.dump(descriptor, f)
        return count_deduped

    @staticmethod
    def load(path: Union[str, os.PathLike]) -> "FxPrograms":
        instance = FxPrograms()
        path = Path(path).resolve()
        with open(path, "rb") as f:
            descriptor = json.load(f)

        shared_state_dict: dict[str, Any] = {}
        shared_constants: dict[str, Any] = {}

        for program_name in descriptor["load_order"]:
            program_file_name = descriptor["program_files"][program_name]
            ep = torch.export.load(path.parent / program_file_name)
            _unsharify_state_dict(shared_state_dict, ep.state_dict)
            _unsharify_state_dict(shared_constants, _get_optional_constants(ep))
            instance.programs[program_name] = ep
        return instance


class FxProgramsBuilder(FxPrograms):
    """Builds a new set of exported programs that are all variations of the
    same root nn.Module.

    This can be used to construct multi-entrypoint sets of ExportedPrograms
    in a way that preserves alias information for lifted tensors.

    Usage:

    ```
    class MyModule(nn.Module):
        ...

    fxb = FxProgramsBuilder(MyModule())

    @fxb.export_program(args=example_args)
    def entrypoint(m, x, y):
        return m.forward(x, y)

    fxb.save("/some/path.json")
    ```
    """

    def __init__(self, root_module: nn.Module):
        super().__init__()
        self.root_module = root_module

    def export_program(
        fx_builder,
        f=None,
        *,
        args=None,
        kwargs=None,
        dynamic_shapes=None,
        name: Optional[str] = None,
    ):
        if f is None:
            return functools.partial(
                fx_builder.export_program,
                args=args,
                kwargs=kwargs,
                dynamic_shapes=dynamic_shapes,
                name=name,
            )

        if name is None:
            name = f.__name__
        if name in fx_builder.programs:
            raise ValueError(f"Attempt to export program '{name}' multiple times")

        class LambdaModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.add_module("root", fx_builder.root_module)

        # Here we do a tricky thing: the free function that we take has the
        # signature:
        #   def free_function(root_module, arg1, *, kwarg1)
        # Since the export machinery expects to be able to inspect and query
        # based on user-specified argument names ("arg1", "kwarg1" above),
        # we use the usual @functools.wraps to copy metadata. Because we wrap
        # it before adding it to the class, the first arg of the free function
        # ("root_module" above) lines up with the usual "self" arg of a method
        # attached to a class. Once instantiated, this synthetic 'forward'
        # method will inspect as only taking the user-specified argument
        # names (i.e. "arg1", "kwarg1") because the class machinery swallowed
        # the first, which is exactly the one we wanted to elide from Dynamo's
        # view anyway.
        # If we weren't doing this, we would need to munge the signature
        # descriptors to line up because the export machinery needs to see
        # the user-specified function arguments, not our "pseudo-self" root
        # module argument that we always pass.
        # Note that to keep Dynamo happy, we are careful to only access
        # names and attributes in the module tree (vs from the surrounding
        # closure, which goes down less well-trodden paths).
        @functools.wraps(f)
        def new_forward(self, *forward_args, **forward_kwargs):
            return f(self.root, *forward_args, **forward_kwargs)

        setattr(LambdaModule, "forward", new_forward)
        lambda_module = LambdaModule()

        # Export our franken-module.
        extra_kwargs = {}
        if dynamic_shapes:
            if not _supports_dynamic_shapes:
                raise ValueError(
                    "torch.export with dynamic_shapes= is not supported for this version of torch"
                )
            extra_kwargs["dynamic_shapes"] = dynamic_shapes
        program = torch.export.export(
            lambda_module, args=args, kwargs=kwargs, **extra_kwargs
        )
        fx_builder.programs[name] = program
        return program


class SharedStateTensor(torch.Tensor):
    """A fake tensor that we shove into ExportedProgram state to share."""

    @staticmethod
    def __new__(
        cls,
        size,
        dtype,
        shared_state_dict_key: str,
        is_param: bool,
        requires_grad=False,
    ):
        # Using a meta tensor as the wrapped tensor gives us shape and dtype
        # propagation.
        return torch.Tensor._make_subclass(
            cls,
            torch.empty(size, dtype=dtype, device="meta"),
            require_grad=requires_grad,
        )

    def __init__(
        self,
        size,
        dtype,
        shared_state_dict_key: str,
        is_param: bool,
        requires_grad=False,
    ):
        self.shared_state_dict_key = shared_state_dict_key
        # Magic attribute that makes isinstance(t, Parameter) True.
        # See torch.nn.Parameter.
        self._is_param = is_param


def _create_shared_state_tensor(
    like: torch.Tensor, shared_state_dict_key: str
) -> SharedStateTensor:
    t = SharedStateTensor(
        like.size(),
        like.dtype,
        shared_state_dict_key=shared_state_dict_key,
        is_param=isinstance(like, torch.nn.Parameter),
        requires_grad=like.requires_grad,
    )
    return t


def _sharify_state_dict(shared_dict: dict, local_dict: dict) -> int:
    count_deduped = 0
    for key, local_value in local_dict.items():
        if not isinstance(local_value, torch.Tensor):
            continue
        if key in shared_dict:
            shared_value = shared_dict[key]
            assert (
                shared_value is local_value
            ), f"State dict key collision results in different instances ({key})!"
            local_dict[key] = _create_shared_state_tensor(local_value, key)
            count_deduped += 1
        else:
            # Remember the original for the next time.
            shared_dict[key] = local_value
    return count_deduped


def _unsharify_state_dict(shared_dict: dict, local_dict: dict):
    for key, local_value in local_dict.items():
        if not isinstance(local_value, torch.Tensor):
            continue
        if isinstance(local_value, SharedStateTensor):
            # Replace shared state tensor.
            shared_key = local_value.shared_state_dict_key
            try:
                shared_value = shared_dict[shared_key]
            except KeyError as e:
                raise KeyError(
                    "Shared tensor not found during deserialization. Corrupt metadata? "
                    f"{shared_key}"
                ) from e
            local_dict[key] = shared_value
        else:
            # Remember this one for later.
            shared_dict[key] = local_value


def _get_optional_constants(ep: torch.export.ExportedProgram) -> dict[str, Any]:
    """Constants showed up in the early 2.3 timeframe.

    Returns an empty dict if not supported.
    """
    try:
        return ep.constants  # type: ignore
    except AttributeError:
        assert torch.__version__ < "2.3.dev1", "Constants should be available"
        return dict()
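To round out the picture, a minimal sketch of the load side, based on `FxPrograms.load` and the `programs` dict it populates in the diff above; the path is the illustrative one from the earlier sketch, and how each `ExportedProgram` is consumed downstream is up to the caller:

```python
from shark_turbine.aot import FxPrograms

# Reload the previously saved set. Shared state tensors are re-associated
# across entrypoints during load via _unsharify_state_dict.
programs = FxPrograms.load("/tmp/my_module.json")

for name, ep in programs.programs.items():
    # Each entry is a plain torch.export.ExportedProgram, ready for further
    # manipulation or compilation without re-running FX tracing.
    print(name, ep.graph_signature)
```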
