diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index d8d6c20fb20..0712bdf1f9a 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -19,6 +19,7 @@ python_library( "//executorch/devtools/etrecord:etrecord", "//executorch/exir:lib", "//executorch/devtools/inspector:intermediate_output_capturer", + "//executorch/devtools/inspector/numerical_comparator:lib", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 199a740737a..dfff3d0818e 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -55,6 +55,7 @@ inflate_runtime_output, is_debug_output, is_inference_output_equal, + map_runtime_aot_intermediate_outputs, ProgramOutput, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, @@ -63,6 +64,10 @@ from executorch.devtools.inspector._intermediate_output_capturer import ( IntermediateOutputCapturer, ) +from executorch.devtools.inspector.numerical_comparator import ( + L1Comparator, + MSEComparator, +) from executorch.exir import ExportedProgram @@ -1337,3 +1342,50 @@ def get_exported_program( if graph is None else self._etrecord.graph_map.get(graph) ) + + def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame: + """ + Compares logged intermediate outputs from the exported graph (in ETRecord) + with runtime outputs (in ETDump) using a user-specific numerical comparator. + + Args: + distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + + Returns: + pd.DataFrame: A DataFrame listing corresponding operator outputs from + both stages and their computed numerical gaps. + """ + if self._aot_intermediate_outputs is None: + raise ValueError( + "The aot intermediate outputs is required but not populated." + ) + mapping = map_runtime_aot_intermediate_outputs( + self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs() + ) + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator() + elif metric == "L1": + comparator = L1Comparator() + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + rows = [] + for (aot_debug_handle, aot_intermediate_output), ( + runtime_debug_handle, + runtime_intermediate_output, + ) in mapping.items(): + if aot_intermediate_output is None or runtime_intermediate_output is None: + continue + rows.append( + { + "aot_debug_handle": aot_debug_handle, + "aot_intermediate_output": aot_intermediate_output, + "runtime_debug_handle": runtime_debug_handle, + "runtime_intermediate_output": runtime_intermediate_output, + "gap": comparator.compare( + aot_intermediate_output, runtime_intermediate_output + ), + } + ) + return pd.DataFrame(rows) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 4d468adbccb..21d627d4eba 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -8,6 +8,7 @@ import math import sys +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union @@ -676,17 +677,25 @@ def map_runtime_aot_intermediate_outputs( # Map only if both AOT and runtime data are present. if len(aot_list) != 0 and len(runtime_list) != 0: # Combine aot debug handles into a single key - aot_combined_debug_handle, aot_output = ( + aot_combined_debug_handle, aot_intermediate_output = ( _combine_overlapped_intermediate_outputs(aot_list) ) # Combine runtime debug handles into a single key - runtime_combined_debug_handle, runtime_output = ( + runtime_combined_debug_handle, runtime_intermediate_output = ( _combine_overlapped_intermediate_outputs(runtime_list) ) + # List can't be used as a key, so convert to tuple + if isinstance(aot_intermediate_output, list): + aot_intermediate_output = tuple(aot_intermediate_output) + # runtime follow the same format as aot, so it's safe to convert to tuple + if isinstance(runtime_intermediate_output, list): + runtime_intermediate_output = tuple(runtime_intermediate_output) # Create a mapping between runtime and aot - aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = ( + aot_runtime_mapping[ + (aot_combined_debug_handle, aot_intermediate_output) + ] = ( runtime_combined_debug_handle, - runtime_output, + runtime_intermediate_output, ) return aot_runtime_mapping @@ -698,7 +707,7 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: This function handles the following types of input: - Scalar (int or float): Converts to a tensor with a single element. - Tensor: Converts to a float64 tensor on CPU. - - List of Tensors: Stacks the tensors into a single float64 tensor on CPU. + - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU. The resulting tensor is detached, moved to CPU, and cast to torch.float64. Parameters: input_data (Any): The input data to be converted to a tensor. It can be a scalar, @@ -709,8 +718,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: ValueError: If the input_data cannot be converted to a tensor. """ try: - # Check if the input is a list of tensors - if isinstance(input_data, list): + # Check if the input is a Sequence of tensors + if isinstance(input_data, Sequence): input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data]) # Try to convert the input to a tensor else: diff --git a/devtools/inspector/numerical_comparator/TARGETS b/devtools/inspector/numerical_comparator/TARGETS index 65fa3a53853..1c0fc8abb85 100644 --- a/devtools/inspector/numerical_comparator/TARGETS +++ b/devtools/inspector/numerical_comparator/TARGETS @@ -14,7 +14,7 @@ python_library( srcs = ["l1_numerical_comparator.py"], deps = [ "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", - "//executorch/devtools/inspector:lib", + "//executorch/devtools/inspector:inspector_utils", ], ) @@ -23,7 +23,7 @@ python_library( srcs = ["mse_numerical_comparator.py"], deps = [ "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", - "//executorch/devtools/inspector:lib", + "//executorch/devtools/inspector:inspector_utils", ], ) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index b96a694b581..1460dbd46a2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,6 +17,8 @@ from unittest.mock import patch +import pandas as pd + import torch import torch.fx @@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self): self.assertIn((key,), runtime_outputs) self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) + def test_calculate_numeric_gap(self): + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + runtime_intermediate_outputs = { + (0,): torch.tensor([2.0, 1.0, 4.0]), + (1,): torch.tensor([3.0, 6.0, 5.0]), + } + + inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs + inspector_instance._get_runtime_intermediate_outputs = ( + lambda: runtime_intermediate_outputs + ) + + df = inspector_instance.calculate_numeric_gap(distance="L1") + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_debug_handle", + "aot_intermediate_output", + "runtime_debug_handle", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + founded_aot_debug_handle = set(df["aot_debug_handle"]) + self.assertEqual( + founded_aot_debug_handle, set(aot_intermediate_outputs.keys()) + ) + for _, row in df.iterrows(): + aot_debuh_handle = row["aot_debug_handle"] + # aot_intermediate_output should equal aot_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["aot_intermediate_output"], + aot_intermediate_outputs[aot_debuh_handle], + ) + ) + # runtime_debug_hanlde equals aot_debug_handle at this case + self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle) + # runtime_intermediate_output should equal runtime_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["runtime_intermediate_output"], + runtime_intermediate_outputs[aot_debuh_handle], + ) + ) + # gap should equal 3.0 + self.assertEqual(row["gap"], 3.0) + def _gen_random_float_list(self) -> List[float]: return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]