Skip to content

Commit a7091bf

Browse files
authored
executorch quantier numeric debugging update for recent torchao changes
Differential Revision: D76842934 Pull Request resolved: #12173
1 parent 59e0476 commit a7091bf

File tree

1 file changed

+47
-76
lines changed

1 file changed

+47
-76
lines changed

backends/xnnpack/test/quantizer/test_pt2e_quantization.py

Lines changed: 47 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66

77
# pyre-unsafe
88

9-
import unittest
10-
119
from collections import Counter
12-
from typing import Dict, Tuple
10+
from typing import Tuple
1311

1412
import torch
1513
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
@@ -33,27 +31,26 @@
3331
from torch.testing._internal.common_utils import (
3432
instantiate_parametrized_tests,
3533
TemporaryFileName,
36-
TestCase,
3734
)
3835
from torchao.quantization.pt2e import (
3936
allow_exported_model_train_eval,
4037
compare_results,
41-
CUSTOM_KEY,
4238
extract_results_from_loggers,
43-
generate_numeric_debug_handle,
44-
NUMERIC_DEBUG_HANDLE_KEY,
39+
FROM_NODE_KEY,
4540
prepare_for_propagation_comparison,
4641
)
4742

48-
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
4943
from torchao.quantization.pt2e.quantize_pt2e import (
5044
convert_pt2e,
5145
prepare_pt2e,
5246
prepare_qat_pt2e,
5347
)
5448
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
5549
from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer
56-
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
50+
from torchao.testing.pt2e.utils import (
51+
PT2ENumericDebuggerTestCase,
52+
PT2EQuantizationTestCase,
53+
)
5754

5855

5956
class TestQuantizePT2E(PT2EQuantizationTestCase):
@@ -495,7 +492,8 @@ def forward(self, x):
495492
for n in m.graph.nodes:
496493
if n.op == "get_attr" and "frozen_param" in n.target:
497494
for key in n.meta:
498-
self.assertEqual(n.meta[key], weight_meta[key])
495+
if key != FROM_NODE_KEY:
496+
self.assertEqual(n.meta[key], weight_meta[key])
499497

500498
def test_reentrant(self) -> None:
501499
"""Test we can safely call quantization apis multiple times"""
@@ -725,76 +723,59 @@ def test_save_load(self) -> None:
725723
instantiate_parametrized_tests(TestQuantizePT2E)
726724

727725

728-
@unittest.skip("TODO: Reenable it after debug infrature finish update")
729-
class TestNumericDebugger(TestCase):
730-
def _extract_debug_handles(self, model) -> Dict[str, int]:
731-
debug_handle_map: Dict[str, int] = {}
732-
733-
def _extract_debug_handles_from_node(node: torch.fx.Node) -> None:
734-
nonlocal debug_handle_map
735-
if (
736-
CUSTOM_KEY in node.meta
737-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
738-
):
739-
debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
740-
NUMERIC_DEBUG_HANDLE_KEY
741-
]
742-
743-
bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
744-
return debug_handle_map
745-
746-
def _assert_each_node_has_debug_handle(self, model) -> None:
747-
def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
748-
self.assertTrue(
749-
CUSTOM_KEY in node.meta
750-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
751-
f"Node {node} doesn't have debug handle",
752-
)
753-
754-
bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
726+
class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase):
755727

756-
def test_quantize_pt2e_preserve_handle(self) -> None:
728+
def test_quantize_pt2e_preserve_handle(self):
757729
m = TestHelperModules.Conv2dThenConv1d()
758730
example_inputs = m.example_inputs()
759731
ep = export_for_training(m, example_inputs, strict=True)
760-
generate_numeric_debug_handle(ep)
761732
m = ep.module()
762733

763734
quantizer = XNNPACKQuantizer().set_global(
764735
get_symmetric_quantization_config(is_per_channel=False)
765736
)
766-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
767-
debug_handle_map = self._extract_debug_handles(m)
768-
res_counter = Counter(debug_handle_map.values())
769-
repeated_debug_handle_ids = [1, 2, 3]
770-
# 3 ids were repeated because we copy over the id from node to its output observer
737+
m = prepare_pt2e(m, quantizer)
738+
from_node_source_map = self._extract_from_node_source(m)
739+
node_name_equip_with_output_observer = [
740+
"conv2d",
741+
"conv1d",
742+
"squeeze",
743+
]
744+
res_counter = Counter(from_node_source_map.values())
745+
repeated_from_node_source = [
746+
from_node_source_map[n_name]
747+
for n_name in node_name_equip_with_output_observer
748+
]
749+
# 3 infos were repeated because we copy over the info from node to its output observer
771750
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
772-
for dh_id in repeated_debug_handle_ids:
773-
self.assertEqual(res_counter[dh_id], 2)
751+
for from_node_source in repeated_from_node_source:
752+
self.assertEqual(res_counter[from_node_source], 2)
774753

775754
m(*example_inputs)
776755
m = convert_pt2e(m)
777-
self._assert_each_node_has_debug_handle(ep)
778-
debug_handle_map = self._extract_debug_handles(m)
779-
res_counter = Counter(debug_handle_map.values())
780-
# same set of ids where repeated, because we copy over the id from observer/fake_quant to
781-
# dequantize node
782-
repeated_debug_handle_ids = [1, 2, 3]
783-
for dh_id in repeated_debug_handle_ids:
784-
self.assertEqual(res_counter[dh_id], 2)
785-
786-
def test_extract_results_from_loggers(self) -> None:
756+
self._assert_each_node_has_from_node_source(m)
757+
from_node_source_map = self._extract_from_node_source(m)
758+
res_counter = Counter(from_node_source_map.values())
759+
# same set of infos where repeated, because we copy over the info from observer/fake_quant to
760+
# quantize/dequantize node
761+
repeated_from_node_source = [
762+
from_node_source_map[n_name]
763+
for n_name in node_name_equip_with_output_observer
764+
]
765+
for from_node_source in repeated_from_node_source:
766+
self.assertEqual(res_counter[from_node_source], 3)
767+
768+
def test_extract_results_from_loggers(self):
787769
m = TestHelperModules.Conv2dThenConv1d()
788770
example_inputs = m.example_inputs()
789771
ep = export_for_training(m, example_inputs, strict=True)
790-
generate_numeric_debug_handle(ep)
791772
m = ep.module()
792-
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
773+
m_ref_logger = prepare_for_propagation_comparison(m)
793774

794775
quantizer = XNNPACKQuantizer().set_global(
795776
get_symmetric_quantization_config(is_per_channel=False)
796777
)
797-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
778+
m = prepare_pt2e(m, quantizer)
798779
m(*example_inputs)
799780
m = convert_pt2e(m)
800781
m_quant_logger = prepare_for_propagation_comparison(m)
@@ -803,29 +784,22 @@ def test_extract_results_from_loggers(self) -> None:
803784
m_quant_logger(*example_inputs)
804785
ref_results = extract_results_from_loggers(m_ref_logger)
805786
quant_results = extract_results_from_loggers(m_quant_logger)
806-
comparison_results = compare_results(
807-
ref_results,
808-
quant_results, # pyre-ignore[6]
809-
)
787+
comparison_results = compare_results(ref_results, quant_results)
810788
for node_summary in comparison_results.values():
811789
if len(node_summary.results) > 0:
812-
self.assertGreaterEqual(
813-
node_summary.results[0].sqnr,
814-
35, # pyre-ignore[6]
815-
)
790+
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
816791

817-
def test_extract_results_from_loggers_list_output(self) -> None:
792+
def test_extract_results_from_loggers_list_output(self):
818793
m = TestHelperModules.Conv2dWithSplit()
819794
example_inputs = m.example_inputs()
820795
ep = export_for_training(m, example_inputs, strict=True)
821-
generate_numeric_debug_handle(ep)
822796
m = ep.module()
823-
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
797+
m_ref_logger = prepare_for_propagation_comparison(m)
824798

825799
quantizer = XNNPACKQuantizer().set_global(
826800
get_symmetric_quantization_config(is_per_channel=False)
827801
)
828-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
802+
m = prepare_pt2e(m, quantizer)
829803
m(*example_inputs)
830804
m = convert_pt2e(m)
831805
m_quant_logger = prepare_for_propagation_comparison(m)
@@ -834,15 +808,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
834808
m_quant_logger(*example_inputs)
835809
ref_results = extract_results_from_loggers(m_ref_logger)
836810
quant_results = extract_results_from_loggers(m_quant_logger)
837-
comparison_results = compare_results(
838-
ref_results,
839-
quant_results, # pyre-ignore[6]
840-
)
811+
comparison_results = compare_results(ref_results, quant_results)
841812
for node_summary in comparison_results.values():
842813
if len(node_summary.results) > 0:
843814
sqnr = node_summary.results[0].sqnr
844815
if isinstance(sqnr, list):
845816
for sqnr_i in sqnr:
846817
self.assertGreaterEqual(sqnr_i, 35)
847818
else:
848-
self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6]
819+
self.assertGreaterEqual(sqnr, 35)

0 commit comments

Comments
 (0)