Skip to content
28 changes: 27 additions & 1 deletion devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,22 @@ def __init__(
_reference_outputs: Optional[Dict[str, List[ProgramOutput]]] = None,
_representative_inputs: Optional[List[ProgramInput]] = None,
):
"""
Please do not construct an ETRecord object directly.
If you want to create an ETRecord for logging AOT information to further analysis, please mark `generate_etrecord`
as True in your export api, and get the ETRecord object from the `ExecutorchProgramManager`.
For exmaple:
```python
exported_program = torch.export.export(model, inputs)
edge_program = to_edge_transform_and_lower(exported_program, generate_etrecord=True)
executorch_program = edge_program.to_executorch()
etrecord = executorch_program.get_etrecord()
```
If user need to create an ETRecord manually, please use the `create_etrecord` function.
"""

self.exported_program = exported_program
self.export_graph_id = export_graph_id
self.edge_dialect_program = edge_dialect_program
Expand All @@ -81,15 +97,25 @@ def __init__(

def save(self, path: Union[str, os.PathLike, BinaryIO, IO[bytes]]) -> None:
"""
Serialize and save the ETRecord to the specified path.
Serialize and save the ETRecord to the specified path for use in Inspector. The ETRecord
should contains at least edge dialect program and executorch program information for further
analysis, otherwise it will raise an exception.
Args:
path: Path where the ETRecord file will be saved to.
Raises:
RuntimeError: If the ETRecord does not contain essential information for Inpector.
"""
if isinstance(path, (str, os.PathLike)):
# pyre-ignore[6]: In call `os.fspath`, for 1st positional argument, expected `str` but got `Union[PathLike[typing.Any], str]`
path = os.fspath(path)

if not (self.edge_dialect_program and self._debug_handle_map):
raise RuntimeError(
"ETRecord must contain edge dialect program and executorch program to be saved"
)

etrecord_zip = ZipFile(path, "w")

try:
Expand Down
29 changes: 29 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,3 +1462,32 @@ def test_update_apis_and_save_parse(self):
custom_outputs["forward"], parsed_etrecord._reference_outputs["forward"]
):
self.assertTrue(torch.equal(expected[0], actual[0]))

def test_save_missing_essential_info(self):
def expected_runtime_error(etrecord, etrecord_path):
with self.assertRaises(RuntimeError) as context:
etrecord.save(etrecord_path)

self.assertIn(
"ETRecord must contain edge dialect program and executorch program to be saved",
str(context.exception),
)

"""Test that save raises RuntimeError when essential info is missing."""
_, edge_output, et_output = self.get_test_model()

etrecord = ETRecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_no_edge.bin"

expected_runtime_error(etrecord, etrecord_path)
etrecord.add_edge_dialect_program(edge_output)

# Should raise runtime error due to missing executorch program related info
expected_runtime_error(etrecord, etrecord_path)

etrecord.add_executorch_program(et_output)

# All essential components are now present, so save should succeed
etrecord.save(etrecord_path)
Loading