Skip to content

Commit 1466826

Browse files
authored
Implemented Runtime Intermediate Output Extraction Based on Corresponding AOT Operators
Differential Revision: D77712318 Pull Request resolved: #12212
1 parent 6da7bde commit 1466826

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,19 @@ def map_runtime_aot_intermediate_outputs(
732732
# runtime follow the same format as aot, so it's safe to convert to tuple
733733
if isinstance(runtime_intermediate_output, list):
734734
runtime_intermediate_output = tuple(runtime_intermediate_output)
735+
736+
# Currently, runtime_intermediate_output logs all delegate call arguments.
737+
# Process here to extract only the outputs.
738+
if isinstance(aot_intermediate_output, tuple):
739+
# If both are sequences, slice runtime_intermediate_output to match the length of aot_intermediate_output
740+
if isinstance(runtime_intermediate_output, tuple):
741+
runtime_intermediate_output = runtime_intermediate_output[
742+
-len(aot_intermediate_output) :
743+
]
744+
# If aot_intermediate_output is not a sequence but runtime_intermediate_output is, get the last element
745+
elif isinstance(runtime_intermediate_output, tuple):
746+
runtime_intermediate_output = runtime_intermediate_output[-1]
747+
735748
# Create a mapping between runtime and aot
736749
aot_runtime_mapping[
737750
(aot_combined_debug_handle, aot_intermediate_output)

devtools/inspector/tests/inspector_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,15 +571,15 @@ def test_get_runtime_intermediate_outputs_and_op_names(self):
571571
self.assertIn((4,), runtime_outputs)
572572
self.assertIn((4,), op_names)
573573
self.assertTrue(
574-
torch.equal(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
574+
torch.allclose(runtime_outputs[(4,)][0], torch.tensor([4.0, 5.0, 6.0]))
575575
)
576576
self.assertEqual(op_names[(4,)], "op_3")
577577

578578
# Check that keys (5,) to (8,) are in the dictionary and have values of the correct size
579579
for key in range(5, 9):
580580
self.assertIn((key,), runtime_outputs)
581581
self.assertIn((key,), op_names)
582-
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
582+
self.assertEqual(runtime_outputs[(key,)][0].size(0), RAW_DATA_SIZE)
583583
self.assertEqual(op_names[(key,)], f"op_{key-1}")
584584

585585
def test_calculate_numeric_gap(self):
@@ -659,7 +659,7 @@ def _gen_random_float_list(self) -> List[float]:
659659
def _gen_random_runtime_output(
660660
self,
661661
) -> List[Union[None, List[torch.Tensor], bool, float, int, str, torch.Tensor]]:
662-
return list(torch.randn(RAW_DATA_SIZE))
662+
return [torch.randn(RAW_DATA_SIZE)]
663663

664664
def _gen_random_events(self) -> List[Event]:
665665
events = []

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,60 @@ def test_map_runtime_aot_intermediate_outputs_complex_chain(self):
343343
expected = {((1, 2, 3, 4, 5, 6), 300): ((2, 3, 4, 5, 6, 7), 350)}
344344
self.assertEqual(actual, expected)
345345

346+
def test_map_runtime_aot_intermediate_outputs_delegated(self):
347+
# Currently, runtime_intermediate_output logs all delegate call arguments
348+
# Test that the map function correctly extracted out the delegated outputs
349+
aot_intermediate_outputs = {
350+
(1, 2): torch.tensor([4, 5]),
351+
(3, 4): torch.tensor([10, 11, 12]),
352+
(5, 6): torch.tensor([13, 14, 15, 16, 17]),
353+
}
354+
runtime_intermediate_outputs = {
355+
(1, 2): [torch.tensor([1, 2, 3]), torch.tensor([4, 5])],
356+
(3, 4): [
357+
torch.tensor([6, 7, 8, 9]),
358+
torch.tensor(1),
359+
torch.tensor([10, 11, 12]),
360+
],
361+
(5, 6): [
362+
torch.tensor([1]),
363+
torch.tensor([2]),
364+
torch.tensor([13, 14, 15, 16, 17]),
365+
],
366+
}
367+
actual = map_runtime_aot_intermediate_outputs(
368+
aot_intermediate_outputs, runtime_intermediate_outputs
369+
)
370+
expected = {
371+
((1, 2), torch.tensor([4, 5])): ((1, 2), torch.tensor([4, 5])),
372+
((3, 4), torch.tensor([10, 11, 12])): ((3, 4), torch.tensor([10, 11, 12])),
373+
((5, 6), torch.tensor([13, 14, 15, 16, 17])): (
374+
(5, 6),
375+
torch.tensor([13, 14, 15, 16, 17]),
376+
),
377+
}
378+
self.assertEqual(len(actual), len(expected))
379+
380+
for (exp_aot_key, exp_aot_value), (
381+
exp_runtime_key,
382+
exp_runtime_value,
383+
) in expected.items():
384+
found = False
385+
for (act_aot_key, act_aot_value), (
386+
act_runtime_key,
387+
act_runtime_value,
388+
) in actual.items():
389+
if exp_aot_key == act_aot_key and torch.allclose(
390+
exp_aot_value, act_aot_value
391+
):
392+
found = True
393+
self.assertEqual(exp_runtime_key, act_runtime_key)
394+
self.assertTrue(
395+
torch.allclose(exp_runtime_value, act_runtime_value)
396+
)
397+
break
398+
self.assertTrue(found)
399+
346400
def test_convert_input_to_tensor_convertible_inputs(self):
347401
# Scalar -> tensor
348402
actual_output1 = convert_to_float_tensor(5)

0 commit comments

Comments
 (0)