Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions devtools/etrecord/_etrecord.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,38 @@ def _save_edge_dialect_program(
f"{base_name}_example_inputs", serialized_artifact.example_inputs
)

def add_extra_export_modules(
self,
extra_recorded_export_modules: Dict[
str,
Union[
ExportedProgram,
ExirExportedProgram,
EdgeProgramManager,
],
],
) -> None:
"""
Add extra export modules to the ETRecord after it has been created.

This method allows users to add more export modules they want to record
to an existing ETRecord instance. The modules will be added to the graph_map
and will be included when the ETRecord is saved.

Args:
extra_recorded_export_modules: A dictionary of graph modules with the key being
the user provided name and the value being the corresponding exported module.
The exported graph modules can be either the output of `torch.export()` or `exir.to_edge()`.
"""
if self.graph_map is None:
self.graph_map = {}

# Now self.graph_map is guaranteed to be non-None
graph_map = self.graph_map
for module_name, export_module in extra_recorded_export_modules.items():
_validate_module_name(module_name)
_add_module_to_graph_map(graph_map, module_name, export_module)


def _get_reference_outputs(
bundled_program: BundledProgram,
Expand Down
67 changes: 67 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,73 @@ def test_etrecord_generation_with_exported_program(self):
# Validate that export_graph_id matches the expected value
self.assertEqual(etrecord.export_graph_id, expected_graph_id)

def test_add_extra_export_modules(self):
"""Test add_extra_export_modules when ETRecord already has a graph_map."""
captured_output, edge_output, et_output = self.get_test_model()

# Create an ETRecord instance with existing graph_map
initial_graph_map = {
"existing_module/forward": captured_output.exported_program
}
etrecord = ETRecord(
exported_program=captured_output.exported_program,
export_graph_id=id(captured_output.exported_program.graph),
edge_dialect_program=edge_output.exported_program,
graph_map=initial_graph_map,
_debug_handle_map=et_output.debug_handle_map,
_delegate_map=et_output.delegate_map,
)

# Verify initial state
self.assertIsNotNone(etrecord.graph_map)
self.assertIn("existing_module/forward", etrecord.graph_map)

# Create additional module to add
f2 = models.BasicSinMax()
captured_output2 = exir.capture(
f2, f2.get_random_inputs(), exir.CaptureConfig()
)

extra_modules = {
"new_module": captured_output2.exported_program,
}

# Add extra export modules
etrecord.add_extra_export_modules(extra_modules)

# Verify both existing and new modules are present
self.assertIn("existing_module/forward", etrecord.graph_map)
self.assertIn("new_module/forward", etrecord.graph_map)

# Verify the modules are correctly stored
self.check_graph_closeness(
etrecord.graph_map["existing_module/forward"],
captured_output.exported_program.graph_module,
)
self.check_graph_closeness(
etrecord.graph_map["new_module/forward"],
captured_output2.exported_program.graph_module,
)

def test_add_extra_export_modules_reserved_name_validation(self):
"""Test that add_extra_export_modules validates reserved names."""
captured_output, edge_output, et_output = self.get_test_model()

etrecord = ETRecord(
exported_program=captured_output.exported_program,
export_graph_id=id(captured_output.exported_program.graph),
edge_dialect_program=edge_output.exported_program,
_debug_handle_map=et_output.debug_handle_map,
_delegate_map=et_output.delegate_map,
)

# Test that reserved names are rejected
for reserved_name in ETRecordReservedFileNames:
with self.assertRaises(RuntimeError):
etrecord.add_extra_export_modules(
{reserved_name: captured_output.exported_program}
)

def test_etrecord_class_constructor_and_save(self):
"""Test that ETRecord class constructor and save method work correctly."""
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
Loading