Skip to content

Add inspector numeric gap calculation between AOT and runtime intermediate outputs #11855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devtools/inspector/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ python_library(
"//executorch/devtools/etrecord:etrecord",
"//executorch/exir:lib",
"//executorch/devtools/inspector:intermediate_output_capturer",
"//executorch/devtools/inspector/numerical_comparator:lib",
],
)

Expand Down
52 changes: 52 additions & 0 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
inflate_runtime_output,
is_debug_output,
is_inference_output_equal,
map_runtime_aot_intermediate_outputs,
ProgramOutput,
RESERVED_FRAMEWORK_EVENT_NAMES,
TimeScale,
Expand All @@ -63,6 +64,10 @@
from executorch.devtools.inspector._intermediate_output_capturer import (
IntermediateOutputCapturer,
)
from executorch.devtools.inspector.numerical_comparator import (
L1Comparator,
MSEComparator,
)
from executorch.exir import ExportedProgram


Expand Down Expand Up @@ -1337,3 +1342,50 @@ def get_exported_program(
if graph is None
else self._etrecord.graph_map.get(graph)
)

def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
"""
Compares logged intermediate outputs from the exported graph (in ETRecord)
with runtime outputs (in ETDump) using a user-specific numerical comparator.

Args:
distance: the metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".

Returns:
pd.DataFrame: A DataFrame listing corresponding operator outputs from
both stages and their computed numerical gaps.
"""
if self._aot_intermediate_outputs is None:
raise ValueError(
"The aot intermediate outputs is required but not populated."
)
mapping = map_runtime_aot_intermediate_outputs(
self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs()
)
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator()
elif metric == "L1":
comparator = L1Comparator()
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

rows = []
for (aot_debug_handle, aot_intermediate_output), (
runtime_debug_handle,
runtime_intermediate_output,
) in mapping.items():
if aot_intermediate_output is None or runtime_intermediate_output is None:
continue
rows.append(
{
"aot_debug_handle": aot_debug_handle,
"aot_intermediate_output": aot_intermediate_output,
"runtime_debug_handle": runtime_debug_handle,
"runtime_intermediate_output": runtime_intermediate_output,
"gap": comparator.compare(
aot_intermediate_output, runtime_intermediate_output
),
}
)
return pd.DataFrame(rows)
23 changes: 16 additions & 7 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import math
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
Expand Down Expand Up @@ -676,17 +677,25 @@ def map_runtime_aot_intermediate_outputs(
# Map only if both AOT and runtime data are present.
if len(aot_list) != 0 and len(runtime_list) != 0:
# Combine aot debug handles into a single key
aot_combined_debug_handle, aot_output = (
aot_combined_debug_handle, aot_intermediate_output = (
_combine_overlapped_intermediate_outputs(aot_list)
)
# Combine runtime debug handles into a single key
runtime_combined_debug_handle, runtime_output = (
runtime_combined_debug_handle, runtime_intermediate_output = (
_combine_overlapped_intermediate_outputs(runtime_list)
)
# List can't be used as a key, so convert to tuple
if isinstance(aot_intermediate_output, list):
aot_intermediate_output = tuple(aot_intermediate_output)
# runtime follow the same format as aot, so it's safe to convert to tuple
if isinstance(runtime_intermediate_output, list):
runtime_intermediate_output = tuple(runtime_intermediate_output)
# Create a mapping between runtime and aot
aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
aot_runtime_mapping[
(aot_combined_debug_handle, aot_intermediate_output)
] = (
runtime_combined_debug_handle,
runtime_output,
runtime_intermediate_output,
)

return aot_runtime_mapping
Expand All @@ -698,7 +707,7 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
This function handles the following types of input:
- Scalar (int or float): Converts to a tensor with a single element.
- Tensor: Converts to a float64 tensor on CPU.
- List of Tensors: Stacks the tensors into a single float64 tensor on CPU.
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
Parameters:
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
Expand All @@ -709,8 +718,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
ValueError: If the input_data cannot be converted to a tensor.
"""
try:
# Check if the input is a list of tensors
if isinstance(input_data, list):
# Check if the input is a Sequence of tensors
if isinstance(input_data, Sequence):
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
# Try to convert the input to a tensor
else:
Expand Down
4 changes: 2 additions & 2 deletions devtools/inspector/numerical_comparator/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ python_library(
srcs = ["l1_numerical_comparator.py"],
deps = [
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
"//executorch/devtools/inspector:lib",
"//executorch/devtools/inspector:inspector_utils",
],
)

Expand All @@ -23,7 +23,7 @@ python_library(
srcs = ["mse_numerical_comparator.py"],
deps = [
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
"//executorch/devtools/inspector:lib",
"//executorch/devtools/inspector:inspector_utils",
],
)

Expand Down
71 changes: 71 additions & 0 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from unittest.mock import patch

import pandas as pd

import torch
import torch.fx

Expand Down Expand Up @@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self):
self.assertIn((key,), runtime_outputs)
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)

def test_calculate_numeric_gap(self):
# Create a context manager to patch functions called by Inspector.__init__
with patch.object(
_inspector, "parse_etrecord", return_value=None
), patch.object(
_inspector, "gen_etdump_object", return_value=None
), patch.object(
EventBlock, "_gen_from_etdump"
), patch.object(
_inspector, "gen_graphs_from_etrecord"
):
# Call the constructor of Inspector
inspector_instance = Inspector(
etdump_path=ETDUMP_PATH,
etrecord=ETRECORD_PATH,
)

aot_intermediate_outputs = {
(0,): torch.tensor([1.0, 2.0, 3.0]),
(1,): torch.tensor([4.0, 5.0, 6.0]),
}

runtime_intermediate_outputs = {
(0,): torch.tensor([2.0, 1.0, 4.0]),
(1,): torch.tensor([3.0, 6.0, 5.0]),
}

inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs
inspector_instance._get_runtime_intermediate_outputs = (
lambda: runtime_intermediate_outputs
)

df = inspector_instance.calculate_numeric_gap(distance="L1")
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(len(df), 2)
cols = set(df.columns)
expected_cols = {
"aot_debug_handle",
"aot_intermediate_output",
"runtime_debug_handle",
"runtime_intermediate_output",
"gap",
}
self.assertEqual(cols, expected_cols)
founded_aot_debug_handle = set(df["aot_debug_handle"])
self.assertEqual(
founded_aot_debug_handle, set(aot_intermediate_outputs.keys())
)
for _, row in df.iterrows():
aot_debuh_handle = row["aot_debug_handle"]
# aot_intermediate_output should equal aot_intermediate_outputs[h]
self.assertTrue(
torch.allclose(
row["aot_intermediate_output"],
aot_intermediate_outputs[aot_debuh_handle],
)
)
# runtime_debug_hanlde equals aot_debug_handle at this case
self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle)
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
self.assertTrue(
torch.allclose(
row["runtime_intermediate_output"],
runtime_intermediate_outputs[aot_debuh_handle],
)
)
# gap should equal 3.0
self.assertEqual(row["gap"], 3.0)

def _gen_random_float_list(self) -> List[float]:
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

Expand Down
Loading