Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ETRecordReservedFileNames,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from executorch.exir.program._program import to_edge_transform_and_lower
from torch.export import export


Expand Down Expand Up @@ -52,6 +53,21 @@ def assert_etrecord_is_empty(self, etrecord: ETRecord) -> None:
self.assert_etrecord_has_no_executorch_program(etrecord)
self.assertIsNone(etrecord.graph_map)

def assert_legal_etrecord_in_edge_program(self, etrecord: ETRecord) -> None:
"""Assert that ETRecord has all expected data after to_edge_transform_and_lower() or to_edge() stage"""
self.assertIsNotNone(etrecord.exported_program)
self.assertIsNotNone(etrecord.export_graph_id)
self.assertIsNotNone(etrecord.edge_dialect_program)
self.assert_etrecord_has_no_executorch_program(etrecord)

def assert_etrecord_saveable(self, etrecord: ETRecord) -> None:
"""Assert ETRecord contains all essential information for saving"""
self.assertIsNotNone(etrecord.exported_program)
self.assertIsNotNone(etrecord.export_graph_id)
self.assertIsNotNone(etrecord.edge_dialect_program)
self.assertIsNotNone(etrecord._debug_handle_map)
self.assertIsNotNone(etrecord._delegate_map)

def get_test_model(self):
f = models.BasicSinMax()
captured_output = exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
Expand Down Expand Up @@ -275,6 +291,157 @@ def test_etrecord_generation_with_exported_program(self):
# Validate that export_graph_id matches the expected value
self.assertEqual(etrecord.export_graph_id, expected_graph_id)

def test_to_edge_transform_and_lower_with_etrecord_generation(self):
"""Test that to_edge_transform_and_lower generates ETRecord correctly."""
f = models.BasicSinMax()
aten_program = export(f, f.get_random_inputs(), strict=True)

# Test with generate_etrecord=True
edge_manager = to_edge_transform_and_lower(
aten_program,
generate_etrecord=True,
)

# Verify that ETRecord was generated and attached
self.assertIsNotNone(edge_manager._etrecord)
etrecord = edge_manager._etrecord
self.assert_legal_etrecord_in_edge_program(etrecord)

# Verify the exported program matches the input
self.check_graph_closeness(
etrecord.exported_program,
aten_program.graph_module,
)
self.assertEqual(
etrecord.export_graph_id,
id(aten_program.graph),
)

# Verify the edge dialect program matches the edge manager
self.check_graph_closeness(
etrecord.edge_dialect_program,
edge_manager.exported_program().graph_module,
)

def test_to_edge_transform_and_lower_without_etrecord_generation(self):
"""Test that to_edge_transform_and_lower works correctly without ETRecord generation."""
f = models.BasicSinMax()
aten_program = export(f, f.get_random_inputs(), strict=True)

# Test with generate_etrecord=False (default)
edge_manager = to_edge_transform_and_lower(aten_program)

# Verify that no ETRecord was generated
self.assertIsNone(edge_manager._etrecord)

# Verify that the edge manager still works correctly
self.assertIsNotNone(edge_manager.exported_program())

def test_get_etrecord_from_executorch_program_manager(self):
"""Test getting ETRecord from ExecutorchProgramManager using get_etrecord() method."""
f = models.BasicSinMax()
aten_program = export(f, f.get_random_inputs(), strict=True)

# Generate edge manager with ETRecord
edge_manager = to_edge_transform_and_lower(
aten_program,
generate_etrecord=True,
)

# Convert to executorch
et_manager = edge_manager.to_executorch()

# Test get_etrecord method
etrecord = et_manager.get_etrecord()
self.assertIsNotNone(etrecord)
self.assert_etrecord_saveable(etrecord)

# Verify the data matches the original input
self.check_graph_closeness(
etrecord.exported_program,
aten_program.graph_module,
)
self.assertEqual(
etrecord.export_graph_id,
id(aten_program.graph),
)

# Verify the executorch program data matches
# ETRecord stores data directly (not JSON serialized), so compare with original data
self.assertEqual(etrecord._debug_handle_map, et_manager.debug_handle_map)
self.assertEqual(etrecord._delegate_map, et_manager.delegate_map)

def test_get_etrecord_from_executorch_program_manager_without_generation(self):
"""Test getting ETRecord from ExecutorchProgramManager when ETRecord was not generated."""
f = models.BasicSinMax()
aten_program = export(f, f.get_random_inputs(), strict=True)

# Generate edge manager without ETRecord
edge_manager = to_edge_transform_and_lower(aten_program)

# Verify no ETRecord on edge manager
self.assertIsNone(edge_manager._etrecord)

# Convert to executorch
et_manager = edge_manager.to_executorch()

# Verify no ETRecord on executorch manager
self.assertIsNone(et_manager._etrecord)

# Test get_etrecord method should raise RuntimeError
with self.assertRaises(RuntimeError) as context:
et_manager.get_etrecord()

self.assertIn("ETRecord was not generated", str(context.exception))

def test_to_edge_transform_and_lower_etrecord_save_and_parse(self):
"""Test that ETRecord generated by to_edge_transform_and_lower can be saved and parsed."""
f = models.BasicSinMax()
aten_program = export(f, f.get_random_inputs(), strict=True)

# Generate edge manager with ETRecord
edge_manager = to_edge_transform_and_lower(
aten_program,
generate_etrecord=True,
)

# Convert to executorch to get complete ETRecord
et_manager = edge_manager.to_executorch()
etrecord = et_manager.get_etrecord()

with tempfile.TemporaryDirectory() as tmpdirname:
etrecord_path = tmpdirname + "/etrecord_flow2.bin"

etrecord.save(etrecord_path)

# Parse ETRecord back and verify
parsed_etrecord = parse_etrecord(etrecord_path)

# Validate that all components are preserved
# Note: Skip graph structure comparison due to transformation differences
self.check_graph_closeness(
etrecord.exported_program, parsed_etrecord.exported_program
)
self.check_graph_closeness(
etrecord.edge_dialect_program, parsed_etrecord.edge_dialect_program
)

# Validate executorch program data
self.assertEqual(
parsed_etrecord._debug_handle_map,
json.loads(json.dumps(et_manager.debug_handle_map)),
)
self.assertEqual(
parsed_etrecord._delegate_map,
json.loads(json.dumps(et_manager.delegate_map)),
)

# Validate export graph id
self.assertEqual(
parsed_etrecord.export_graph_id,
id(aten_program.graph),
)

def test_add_extra_export_modules(self):
"""Test add_extra_export_modules when ETRecord already has a graph_map."""
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
48 changes: 46 additions & 2 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ def _copy_module(new_prog, new_gm):
setattr(new_prog, node.target, t)


def _create_empty_etrecord():
# Import etrecord at runtime to resolve cyclic dependencies (program -> etrecord -> program).
# This also ensures that etrecord-related packages do not affect the export flow.
# @manual
from executorch.devtools.etrecord import ETRecord

return ETRecord()


def lift_constant_tensor_pass(ep):
"""
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
Expand Down Expand Up @@ -1103,6 +1112,7 @@ def _gen_edge_manager_for_partitioners(
aten_programs: Dict[str, ExportedProgram],
config: EdgeCompileConfig,
constant_methods: Optional[Dict[str, Any]],
generate_etrecord: Optional[bool] = False,
) -> "EdgeProgramManager":
"""
Generates EdgeProgramManager for subsequent lowering to the
Expand Down Expand Up @@ -1179,6 +1189,13 @@ def _gen_edge_manager_for_partitioners(
config,
list(set().union(*ops_set_to_not_decompose_by_program.values())),
)

if generate_etrecord:
etrecord = _create_empty_etrecord()
etrecord.add_exported_program(aten_programs)
etrecord.add_edge_dialect_program(copy.deepcopy(edge_manager))
edge_manager._etrecord = etrecord

return edge_manager


Expand Down Expand Up @@ -1220,6 +1237,7 @@ def to_edge_transform_and_lower( # noqa: C901
] = None,
constant_methods: Optional[Dict[str, Any]] = None,
compile_config: Optional[EdgeCompileConfig] = None,
generate_etrecord: bool = False,
) -> "EdgeProgramManager":
"""
:func:`to_edge_transform_and_lower` constructs an EdgeProgramManager from a set of
Expand Down Expand Up @@ -1260,6 +1278,8 @@ def to_edge_transform_and_lower( # noqa: C901
compile_config: An optional argument used to provide greater control over the
transformation to edge dialect process.
generate_etrecord: An optional argument used to generate an etrecord for debugging purposes.
Returns:
EdgeProgramManager
"""
Expand All @@ -1279,7 +1299,7 @@ def to_edge_transform_and_lower( # noqa: C901
partitioner, aten_programs
)
edge_manager = _gen_edge_manager_for_partitioners(
partitioner, aten_programs, config, constant_methods
partitioner, aten_programs, config, constant_methods, generate_etrecord
)

if transform_passes is not None:
Expand Down Expand Up @@ -1447,6 +1467,8 @@ def __init__(
program, self._named_data_store
)

self._etrecord = None

@property
def methods(self) -> Set[str]:
"""
Expand Down Expand Up @@ -1643,13 +1665,19 @@ def to_executorch(
_copy_module(program.graph_module, new_gm)
execution_programs[name] = program

return ExecutorchProgramManager(
et_pm = ExecutorchProgramManager(
execution_programs,
self._config_methods,
config,
self._named_data_store.get_named_data_store_output(),
)

if self._etrecord is not None:
self._etrecord.add_executorch_program(et_pm)
et_pm._etrecord = self._etrecord

return et_pm


class ExecutorchProgramManager:
"""
Expand Down Expand Up @@ -1713,6 +1741,7 @@ def __init__(
self._named_data,
)
self._buffer: Optional[bytes] = None
self._etrecord = None

@property
def methods(self) -> Set[str]:
Expand Down Expand Up @@ -1785,6 +1814,21 @@ def buffer(self) -> bytes:
self._buffer = bytes(self._pte_data)
return self._buffer

def get_etrecord(self):
"""
Get the generated ETRecord if etrecord generation was enabled.
Returns:
ETRecord object if generation was enabled, None otherwise
Raises:
RuntimeError: if ETRecord object was not generated.
"""

if self._etrecord is None:
raise RuntimeError("ETRecord was not generated")
return self._etrecord

def write_to_file(self, open_file: io.BufferedIOBase) -> None:
"""
Writes the serialized ExecuTorch binary to the file at `open_file`. Prefer to use this over
Expand Down
Loading