From 0735510a6ce1fb58834e72e0c1fd0fbe57b025eb Mon Sep 17 00:00:00 2001 From: Juntian Liu Date: Mon, 23 Jun 2025 10:23:57 -0700 Subject: [PATCH] Add inspector numeric gap calculation between AOT and runtime intermediate outputs Summary: This PR introduces a method to calculate the numeric gap between logged intermediate outputs from an exported graph and runtime outputs. The method currently supports MSE and L1 distance metrics for comparison. It maps corresponding intermediate outputs from both stages and computes the numerical gaps, returning the results in a pandas DataFrame. This enhancement aids in identifying discrepancies between AOT intermediate outputs and actual intermediate outputs during runtime. Reviewed By: Gasoonjia Differential Revision: D76831086 --- devtools/inspector/TARGETS | 1 + devtools/inspector/_inspector.py | 52 ++++++++++++++ devtools/inspector/_inspector_utils.py | 23 ++++-- .../inspector/numerical_comparator/TARGETS | 4 +- devtools/inspector/tests/inspector_test.py | 71 +++++++++++++++++++ 5 files changed, 142 insertions(+), 9 deletions(-) diff --git a/devtools/inspector/TARGETS b/devtools/inspector/TARGETS index d8d6c20fb20..0712bdf1f9a 100644 --- a/devtools/inspector/TARGETS +++ b/devtools/inspector/TARGETS @@ -19,6 +19,7 @@ python_library( "//executorch/devtools/etrecord:etrecord", "//executorch/exir:lib", "//executorch/devtools/inspector:intermediate_output_capturer", + "//executorch/devtools/inspector/numerical_comparator:lib", ], ) diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index 199a740737a..dfff3d0818e 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -55,6 +55,7 @@ inflate_runtime_output, is_debug_output, is_inference_output_equal, + map_runtime_aot_intermediate_outputs, ProgramOutput, RESERVED_FRAMEWORK_EVENT_NAMES, TimeScale, @@ -63,6 +64,10 @@ from executorch.devtools.inspector._intermediate_output_capturer import ( IntermediateOutputCapturer, ) +from executorch.devtools.inspector.numerical_comparator import ( + L1Comparator, + MSEComparator, +) from executorch.exir import ExportedProgram @@ -1337,3 +1342,50 @@ def get_exported_program( if graph is None else self._etrecord.graph_map.get(graph) ) + + def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame: + """ + Compares logged intermediate outputs from the exported graph (in ETRecord) + with runtime outputs (in ETDump) using a user-specific numerical comparator. + + Args: + distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR". + + Returns: + pd.DataFrame: A DataFrame listing corresponding operator outputs from + both stages and their computed numerical gaps. + """ + if self._aot_intermediate_outputs is None: + raise ValueError( + "The aot intermediate outputs is required but not populated." + ) + mapping = map_runtime_aot_intermediate_outputs( + self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs() + ) + metric = distance.strip().upper() + if metric == "MSE": + comparator = MSEComparator() + elif metric == "L1": + comparator = L1Comparator() + else: + raise ValueError(f"Unsupported distance metric {distance!r}") + + rows = [] + for (aot_debug_handle, aot_intermediate_output), ( + runtime_debug_handle, + runtime_intermediate_output, + ) in mapping.items(): + if aot_intermediate_output is None or runtime_intermediate_output is None: + continue + rows.append( + { + "aot_debug_handle": aot_debug_handle, + "aot_intermediate_output": aot_intermediate_output, + "runtime_debug_handle": runtime_debug_handle, + "runtime_intermediate_output": runtime_intermediate_output, + "gap": comparator.compare( + aot_intermediate_output, runtime_intermediate_output + ), + } + ) + return pd.DataFrame(rows) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 4d468adbccb..21d627d4eba 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -8,6 +8,7 @@ import math import sys +from collections.abc import Sequence from dataclasses import dataclass from enum import Enum from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union @@ -676,17 +677,25 @@ def map_runtime_aot_intermediate_outputs( # Map only if both AOT and runtime data are present. if len(aot_list) != 0 and len(runtime_list) != 0: # Combine aot debug handles into a single key - aot_combined_debug_handle, aot_output = ( + aot_combined_debug_handle, aot_intermediate_output = ( _combine_overlapped_intermediate_outputs(aot_list) ) # Combine runtime debug handles into a single key - runtime_combined_debug_handle, runtime_output = ( + runtime_combined_debug_handle, runtime_intermediate_output = ( _combine_overlapped_intermediate_outputs(runtime_list) ) + # List can't be used as a key, so convert to tuple + if isinstance(aot_intermediate_output, list): + aot_intermediate_output = tuple(aot_intermediate_output) + # runtime follow the same format as aot, so it's safe to convert to tuple + if isinstance(runtime_intermediate_output, list): + runtime_intermediate_output = tuple(runtime_intermediate_output) # Create a mapping between runtime and aot - aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = ( + aot_runtime_mapping[ + (aot_combined_debug_handle, aot_intermediate_output) + ] = ( runtime_combined_debug_handle, - runtime_output, + runtime_intermediate_output, ) return aot_runtime_mapping @@ -698,7 +707,7 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: This function handles the following types of input: - Scalar (int or float): Converts to a tensor with a single element. - Tensor: Converts to a float64 tensor on CPU. - - List of Tensors: Stacks the tensors into a single float64 tensor on CPU. + - Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU. The resulting tensor is detached, moved to CPU, and cast to torch.float64. Parameters: input_data (Any): The input data to be converted to a tensor. It can be a scalar, @@ -709,8 +718,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor: ValueError: If the input_data cannot be converted to a tensor. """ try: - # Check if the input is a list of tensors - if isinstance(input_data, list): + # Check if the input is a Sequence of tensors + if isinstance(input_data, Sequence): input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data]) # Try to convert the input to a tensor else: diff --git a/devtools/inspector/numerical_comparator/TARGETS b/devtools/inspector/numerical_comparator/TARGETS index 65fa3a53853..1c0fc8abb85 100644 --- a/devtools/inspector/numerical_comparator/TARGETS +++ b/devtools/inspector/numerical_comparator/TARGETS @@ -14,7 +14,7 @@ python_library( srcs = ["l1_numerical_comparator.py"], deps = [ "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", - "//executorch/devtools/inspector:lib", + "//executorch/devtools/inspector:inspector_utils", ], ) @@ -23,7 +23,7 @@ python_library( srcs = ["mse_numerical_comparator.py"], deps = [ "//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base", - "//executorch/devtools/inspector:lib", + "//executorch/devtools/inspector:inspector_utils", ], ) diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index b96a694b581..1460dbd46a2 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -17,6 +17,8 @@ from unittest.mock import patch +import pandas as pd + import torch import torch.fx @@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self): self.assertIn((key,), runtime_outputs) self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) + def test_calculate_numeric_gap(self): + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ), patch.object( + _inspector, "gen_graphs_from_etrecord" + ): + # Call the constructor of Inspector + inspector_instance = Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + ) + + aot_intermediate_outputs = { + (0,): torch.tensor([1.0, 2.0, 3.0]), + (1,): torch.tensor([4.0, 5.0, 6.0]), + } + + runtime_intermediate_outputs = { + (0,): torch.tensor([2.0, 1.0, 4.0]), + (1,): torch.tensor([3.0, 6.0, 5.0]), + } + + inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs + inspector_instance._get_runtime_intermediate_outputs = ( + lambda: runtime_intermediate_outputs + ) + + df = inspector_instance.calculate_numeric_gap(distance="L1") + self.assertIsInstance(df, pd.DataFrame) + self.assertEqual(len(df), 2) + cols = set(df.columns) + expected_cols = { + "aot_debug_handle", + "aot_intermediate_output", + "runtime_debug_handle", + "runtime_intermediate_output", + "gap", + } + self.assertEqual(cols, expected_cols) + founded_aot_debug_handle = set(df["aot_debug_handle"]) + self.assertEqual( + founded_aot_debug_handle, set(aot_intermediate_outputs.keys()) + ) + for _, row in df.iterrows(): + aot_debuh_handle = row["aot_debug_handle"] + # aot_intermediate_output should equal aot_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["aot_intermediate_output"], + aot_intermediate_outputs[aot_debuh_handle], + ) + ) + # runtime_debug_hanlde equals aot_debug_handle at this case + self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle) + # runtime_intermediate_output should equal runtime_intermediate_outputs[h] + self.assertTrue( + torch.allclose( + row["runtime_intermediate_output"], + runtime_intermediate_outputs[aot_debuh_handle], + ) + ) + # gap should equal 3.0 + self.assertEqual(row["gap"], 3.0) + def _gen_random_float_list(self) -> List[float]: return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]