Skip to content

Commit 321bc98

Browse files
Juntian777 authored and facebook-github-bot committed
Updated the comparison logic to handle sequences separately (#12251)
Summary: Pull Request resolved: #12251

Previously, the numerical comparators were designed to compare two inputs regardless of whether they were sequences, which involved stacking a list of tensors into a single tensor for comparison. The updated logic restricts each comparator to comparing two tensors at a time; sequence handling is now managed externally by the caller.

Reviewed By: Gasoonjia

Differential Revision: D77893628
1 parent 264ac90 commit 321bc98

File tree

7 files changed

+67
-52
lines changed

7 files changed

+67
-52
lines changed

devtools/inspector/_inspector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from executorch.devtools.etrecord import ETRecord, parse_etrecord
4343
from executorch.devtools.inspector._inspector_utils import (
4444
calculate_time_scale_factor,
45+
compare_intermediate_outputs,
4546
create_debug_handle_to_op_node_mapping,
4647
DebugHandle,
4748
display_or_print_df,
@@ -1415,8 +1416,8 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
14151416
runtime_debug_handle, runtime_debug_handle_to_op_name
14161417
),
14171418
"runtime_intermediate_output": runtime_intermediate_output,
1418-
"gap": comparator.compare(
1419-
aot_intermediate_output, runtime_intermediate_output
1419+
"gap": compare_intermediate_outputs(
1420+
aot_intermediate_output, runtime_intermediate_output, comparator
14201421
),
14211422
}
14221423
)

devtools/inspector/_inspector_utils.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -762,32 +762,29 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
762762
This function handles the following types of input:
763763
- Scalar (int or float): Converts to a tensor with a single element.
764764
- Tensor: Converts to a float64 tensor on CPU.
765-
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
766765
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
767766
Parameters:
768-
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
769-
a tensor, or a list of tensors.
767+
input_data (Any): The input data to be converted to a tensor. It can be a scalar
768+
or a tensor.
770769
Returns:
771770
torch.Tensor: A tensor on CPU with dtype torch.float64.
772-
Raises:
773-
ValueError: If the input_data cannot be converted to a tensor.
771+
Raises:
AssertionError: If input_data is a Sequence (e.g. a list or tuple of tensors).
ValueError: If input_data cannot be converted to a tensor.
774772
"""
773+
# Assert that the input is not a Sequence
774+
assert not isinstance(input_data, Sequence)
775775
try:
776-
# Check if the input is a Sequence of tensors
777-
if isinstance(input_data, Sequence):
778-
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
779776
# Try to convert the input to a tensor
780-
else:
781-
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
777+
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
782778
except Exception as e:
783779
raise ValueError(
784780
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
785781
)
786-
input_tensor = input_tensor.detach().cpu().double()
787782

783+
input_tensor = input_tensor.detach().cpu().double()
788784
# Convert NaN to 0.0
789785
if torch.isnan(input_tensor).any():
790786
input_tensor = torch.nan_to_num(input_tensor)
787+
791788
return input_tensor
792789

793790

@@ -837,3 +834,33 @@ def find_op_names(
837834
result.append(op_name)
838835

839836
return result
837+
838+
839+
def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
    """
    Compare two intermediate outputs and return the results as a list.

    Comparators only compare one pair of values at a time, so sequence
    handling lives here: two sequences are compared element by element,
    while two non-sequence values are compared directly and the single
    result is wrapped in a one-element list.

    Parameters:
        a: The first intermediate output to compare.
        b: The second intermediate output to compare.
        comparator: An object exposing a `compare(x, y)` method.

    Returns:
        List[float]: One comparison result per compared pair. The list has
        one entry for non-sequence inputs and `len(a)` entries for
        sequence inputs (empty when both sequences are empty).

    Raises:
        ValueError: If one input is a sequence and the other is not, or if
        the sequences have different lengths.
    """
    is_a_sequence = isinstance(a, Sequence)
    is_b_sequence = isinstance(b, Sequence)

    # Mixed sequence / non-sequence inputs cannot be paired up.
    if is_a_sequence != is_b_sequence:
        raise ValueError(
            "Both inputs must be sequences or both must be non-sequences. "
            f"Got {type(a).__name__} and {type(b).__name__}."
        )

    # Non-sequence case: a single comparison, still returned as a list so
    # callers always receive the same result shape.
    if not is_a_sequence:
        return [comparator.compare(a, b)]

    # zip() would silently truncate to the shorter input, so reject
    # mismatched lengths explicitly before pairing elements.
    if len(a) != len(b):
        raise ValueError(
            "Sequences must have the same length for comparison. "
            f"Got lengths {len(a)} and {len(b)}."
        )

    return [comparator.compare(x, y) for x, y in zip(a, b)]

devtools/inspector/tests/inspector_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -636,22 +636,22 @@ def test_calculate_numeric_gap(self):
636636
for i, row in df.iterrows():
637637
# Dummy key to get the expected aot/runtime intermediate outputs
638638
key = (i,)
639-
# aot_intermediate_output should equal aot_intermediate_outputs[h]
639+
# aot_intermediate_output should equal aot_intermediate_outputs[key]
640640
self.assertTrue(
641641
torch.allclose(
642642
row["aot_intermediate_output"],
643643
aot_intermediate_outputs[key],
644644
)
645645
)
646-
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
646+
# runtime_intermediate_output should equal runtime_intermediate_outputs[key]
647647
self.assertTrue(
648648
torch.allclose(
649649
row["runtime_intermediate_output"],
650650
runtime_intermediate_outputs[key],
651651
)
652652
)
653653
# gap should equal 3.0
654-
self.assertEqual(row["gap"], 3.0)
654+
self.assertEqual(row["gap"][0], 3.0)
655655

656656
def _gen_random_float_list(self) -> List[float]:
657657
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
calculate_mse,
3030
calculate_snr,
3131
calculate_time_scale_factor,
32+
compare_intermediate_outputs,
3233
convert_to_float_tensor,
3334
create_debug_handle_to_op_node_mapping,
3435
EDGE_DIALECT_GRAPH_KEY,
@@ -42,6 +43,7 @@
4243
NodeFilter,
4344
TimeScale,
4445
)
46+
from executorch.devtools.inspector.numerical_comparator import L1Comparator
4547

4648

4749
class TestInspectorUtils(unittest.TestCase):
@@ -420,19 +422,10 @@ def test_convert_input_to_tensor_convertible_inputs(self):
420422
)
421423
self.assertEqual(actual_output2.device.type, "cpu")
422424

423-
# List of tensors -> stacked tensor float32 CPU
425+
# List of tensors -> AssertionError
424426
t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])]
425-
actual_output3 = convert_to_float_tensor(t_list)
426-
self.assertIsInstance(actual_output3, torch.Tensor)
427-
self.assertEqual(actual_output3.dtype, torch.float64)
428-
self.assertEqual(tuple(actual_output3.shape), (3, 2))
429-
self.assertTrue(
430-
torch.allclose(
431-
actual_output3,
432-
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64),
433-
)
434-
)
435-
self.assertEqual(actual_output3.device.type, "cpu")
427+
with self.assertRaises(AssertionError):
428+
convert_to_float_tensor(t_list)
436429

437430
def test_convert_input_to_tensor_non_convertible_raises(self):
438431
class X:
@@ -566,6 +559,24 @@ def test_find_op_names_matching_handles(self):
566559
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
567560
)
568561

562+
def test_compare_intermediate_outputs_sequences(self):
    """Equal-length sequences are compared pairwise, one result per pair."""
    expected_gaps = [0.0, 0.5, 0.5]
    actual_gaps = compare_intermediate_outputs(
        [1.0, 2.0, 3.0],
        [1.0, 2.5, 3.5],
        L1Comparator(),
    )
    self.assertEqual(actual_gaps, expected_gaps)
567+
568+
def test_compare_intermediate_outputs_diff_len_sequences(self):
    """Sequences of different lengths are rejected with a ValueError."""
    shorter, longer = [1.0, 2.0], [1.0, 2.0, 3.0]
    with self.assertRaises(ValueError):
        compare_intermediate_outputs(shorter, longer, L1Comparator())
573+
574+
def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
    """Pairing a sequence with a scalar is rejected with a ValueError."""
    sequence_input, scalar_input = [1.0, 2.0], 1.0
    with self.assertRaises(ValueError):
        compare_intermediate_outputs(sequence_input, scalar_input, L1Comparator())
579+
569580

570581
def gen_mock_operator_graph_with_expected_map() -> (
571582
Tuple[OperatorGraph, Dict[int, OperatorNode]]

devtools/inspector/tests/l1_comparator_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,3 @@ def test_2D_tensors(self):
4747
expected = 14.0
4848
result = self.l1_comparator.compare(a, b)
4949
self.assertAlmostEqual(result, expected)
50-
51-
def test_list_of_tensors(self):
52-
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
53-
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
54-
expected = 8.0
55-
result = self.l1_comparator.compare(a, b)
56-
self.assertAlmostEqual(result, expected)

devtools/inspector/tests/mse_comparator_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,3 @@ def test_2D_tensors(self):
4747
expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0
4848
result = self.mse_comparator.compare(a, b)
4949
self.assertAlmostEqual(result, expected)
50-
51-
def test_list_of_tensors(self):
52-
a = [torch.tensor([2, 4]), torch.tensor([15, 2])]
53-
b = [torch.tensor([1, 2]), torch.tensor([9, 5])]
54-
expected = (1.0 + 4.0 + 36.0 + 9.0) / 4.0
55-
result = self.mse_comparator.compare(a, b)
56-
self.assertAlmostEqual(result, expected)

devtools/inspector/tests/snr_comparator_test.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,3 @@ def test_2D_tensors(self):
5050
expected = 10 * math.log10(37.25 / 17.0)
5151
result = self.snr_comparator.compare(a, b)
5252
self.assertAlmostEqual(result, expected)
53-
54-
def test_list_of_tensors(self):
55-
# original_power = mean(4, 16, 25, 4]) = 12.25
56-
# error = a - b = [1, 2, 2, -3] squared = [1, 4, 4, 9] mean = 18/4 = 4.5
57-
# SNR = 10 * log10(37.25/17.0)
58-
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
59-
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
60-
expected = 10 * math.log10(12.25 / 4.5)
61-
result = self.snr_comparator.compare(a, b)
62-
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)