Skip to content

Commit 8a33c41

Browse files
Juntian777facebook-github-bot
authored andcommitted
Updated the comparison logic to handle sequences separately
Summary: Previously, the numerical comparators were designed to compare two inputs regardless of whether they were sequences, which involved stacking a list of tensors into one for comparison. The updated logic now restricts comparators to only compare two tensors at a time, with sequence handling managed externally. Differential Revision: D77893628
1 parent d952326 commit 8a33c41

File tree

4 files changed

+65
-24
lines changed

4 files changed

+65
-24
lines changed

devtools/inspector/_inspector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from executorch.devtools.etrecord import ETRecord, parse_etrecord
4343
from executorch.devtools.inspector._inspector_utils import (
4444
calculate_time_scale_factor,
45+
compare_intermediate_outputs,
4546
create_debug_handle_to_op_node_mapping,
4647
DebugHandle,
4748
display_or_print_df,
@@ -1415,8 +1416,8 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
14151416
runtime_debug_handle, runtime_debug_handle_to_op_name
14161417
),
14171418
"runtime_intermediate_output": runtime_intermediate_output,
1418-
"gap": comparator.compare(
1419-
aot_intermediate_output, runtime_intermediate_output
1419+
"gap": compare_intermediate_outputs(
1420+
aot_intermediate_output, runtime_intermediate_output, comparator
14201421
),
14211422
}
14221423
)

devtools/inspector/_inspector_utils.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -762,32 +762,31 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
762762
This function handles the following types of input:
763763
- Scalar (int or float): Converts to a tensor with a single element.
764764
- Tensor: Converts to a float64 tensor on CPU.
765-
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
766765
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
767766
Parameters:
768-
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
769-
a tensor, or a list of tensors.
767+
input_data (Any): The input data to be converted to a tensor. It can be a scalar
768+
or a tensor.
770769
Returns:
771770
torch.Tensor: A tensor on CPU with dtype torch.float64.
772771
Raises:
773772
ValueError: If the input_data cannot be converted to a tensor.
773+
AssertionError: If the input_data is a Sequence.
774774
"""
775+
# Assert that the input is not a Sequence
776+
assert not isinstance(input_data, Sequence)
775777
try:
776-
# Check if the input is a Sequence of tensors
777-
if isinstance(input_data, Sequence):
778-
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
779778
# Try to convert the input to a tensor
780-
else:
781-
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
779+
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
782780
except Exception as e:
783781
raise ValueError(
784782
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
785783
)
786-
input_tensor = input_tensor.detach().cpu().double()
787784

785+
input_tensor = input_tensor.detach().cpu().double()
788786
# Convert NaN to 0.0
789787
if torch.isnan(input_tensor).any():
790788
input_tensor = torch.nan_to_num(input_tensor)
789+
791790
return input_tensor
792791

793792

@@ -837,3 +836,33 @@ def find_op_names(
837836
result.append(op_name)
838837

839838
return result
839+
840+
841+
def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
842+
"""
843+
Compare two outputs, handling both sequence and non-sequence cases,
844+
and return a list of comparison results.
845+
Parameters:
846+
a: The first intermediate output to compare.
847+
b: The second intermediate output to compare.
848+
comparator: A comparator object with a `compare` method.
849+
Returns:
850+
List[float]: A list of comparison results.
851+
Raises:
852+
ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
853+
"""
854+
is_a_sequence = isinstance(a, Sequence)
855+
is_b_sequence = isinstance(b, Sequence)
856+
if is_a_sequence and is_b_sequence:
857+
# Ensure both sequences have the same length
858+
if len(a) != len(b):
859+
raise ValueError("Sequences must have the same length for comparison.")
860+
861+
# Compare each element in the sequences and return the list of results
862+
return [comparator.compare(x, y) for x, y in zip(a, b)]
863+
elif not is_a_sequence and not is_b_sequence:
864+
# Compare non-sequence items and return the result in a list
865+
return [comparator.compare(a, b)]
866+
else:
867+
# Raise an error if one is a sequence and the other is not
868+
raise ValueError("Both inputs must be sequences or both must be non-sequences.")

devtools/inspector/tests/inspector_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def test_calculate_numeric_gap(self):
651651
)
652652
)
653653
# gap should equal 3.0
654-
self.assertEqual(row["gap"], 3.0)
654+
self.assertEqual(row["gap"][0], 3.0)
655655

656656
def _gen_random_float_list(self) -> List[float]:
657657
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
calculate_mse,
3030
calculate_snr,
3131
calculate_time_scale_factor,
32+
compare_intermediate_outputs,
3233
convert_to_float_tensor,
3334
create_debug_handle_to_op_node_mapping,
3435
EDGE_DIALECT_GRAPH_KEY,
@@ -42,6 +43,7 @@
4243
NodeFilter,
4344
TimeScale,
4445
)
46+
from executorch.devtools.inspector.numerical_comparator import L1Comparator
4547

4648

4749
class TestInspectorUtils(unittest.TestCase):
@@ -420,19 +422,10 @@ def test_convert_input_to_tensor_convertible_inputs(self):
420422
)
421423
self.assertEqual(actual_output2.device.type, "cpu")
422424

423-
# List of tensors -> stacked tensor float32 CPU
425+
# List of tensors -> AssertionError
424426
t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])]
425-
actual_output3 = convert_to_float_tensor(t_list)
426-
self.assertIsInstance(actual_output3, torch.Tensor)
427-
self.assertEqual(actual_output3.dtype, torch.float64)
428-
self.assertEqual(tuple(actual_output3.shape), (3, 2))
429-
self.assertTrue(
430-
torch.allclose(
431-
actual_output3,
432-
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64),
433-
)
434-
)
435-
self.assertEqual(actual_output3.device.type, "cpu")
427+
with self.assertRaises(AssertionError):
428+
convert_to_float_tensor(t_list)
436429

437430
def test_convert_input_to_tensor_non_convertible_raises(self):
438431
class X:
@@ -566,6 +559,24 @@ def test_find_op_names_matching_handles(self):
566559
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
567560
)
568561

562+
def test_compare_intermediate_outputs_sequences(self):
563+
a = [1.0, 2.0, 3.0]
564+
b = [1.0, 2.5, 3.5]
565+
result = compare_intermediate_outputs(a, b, L1Comparator())
566+
self.assertEqual(result, [0.0, 0.5, 0.5])
567+
568+
def test_compare_intermediate_outputs_diff_len_sequences(self):
569+
a = [1.0, 2.0]
570+
b = [1.0, 2.0, 3.0]
571+
with self.assertRaises(ValueError):
572+
compare_intermediate_outputs(a, b, L1Comparator())
573+
574+
def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
575+
a = [1.0, 2.0]
576+
b = 1.0
577+
with self.assertRaises(ValueError):
578+
compare_intermediate_outputs(a, b, L1Comparator())
579+
569580

570581
def gen_mock_operator_graph_with_expected_map() -> (
571582
Tuple[OperatorGraph, Dict[int, OperatorNode]]

0 commit comments

Comments
 (0)