6
6
7
7
# pyre-unsafe
8
8
9
- import unittest
10
-
11
9
from collections import Counter
12
- from typing import Dict , Tuple
10
+ from typing import Tuple
13
11
14
12
import torch
15
13
from executorch .backends .xnnpack .quantizer .xnnpack_quantizer import (
33
31
from torch .testing ._internal .common_utils import (
34
32
instantiate_parametrized_tests ,
35
33
TemporaryFileName ,
36
- TestCase ,
37
34
)
38
35
from torchao .quantization .pt2e import (
39
36
allow_exported_model_train_eval ,
40
37
compare_results ,
41
- CUSTOM_KEY ,
42
38
extract_results_from_loggers ,
43
- generate_numeric_debug_handle ,
44
- NUMERIC_DEBUG_HANDLE_KEY ,
39
+ FROM_NODE_KEY ,
45
40
prepare_for_propagation_comparison ,
46
41
)
47
42
48
- from torchao .quantization .pt2e .graph_utils import bfs_trace_with_node_process
49
43
from torchao .quantization .pt2e .quantize_pt2e import (
50
44
convert_pt2e ,
51
45
prepare_pt2e ,
52
46
prepare_qat_pt2e ,
53
47
)
54
48
from torchao .quantization .pt2e .quantizer import ComposableQuantizer , Quantizer
55
49
from torchao .quantization .pt2e .quantizer .embedding_quantizer import EmbeddingQuantizer
56
- from torchao .testing .pt2e .utils import PT2EQuantizationTestCase
50
+ from torchao .testing .pt2e .utils import (
51
+ PT2ENumericDebuggerTestCase ,
52
+ PT2EQuantizationTestCase ,
53
+ )
57
54
58
55
59
56
class TestQuantizePT2E (PT2EQuantizationTestCase ):
@@ -495,7 +492,8 @@ def forward(self, x):
495
492
for n in m .graph .nodes :
496
493
if n .op == "get_attr" and "frozen_param" in n .target :
497
494
for key in n .meta :
498
- self .assertEqual (n .meta [key ], weight_meta [key ])
495
+ if key != FROM_NODE_KEY :
496
+ self .assertEqual (n .meta [key ], weight_meta [key ])
499
497
500
498
def test_reentrant (self ) -> None :
501
499
"""Test we can safely call quantization apis multiple times"""
@@ -725,76 +723,59 @@ def test_save_load(self) -> None:
725
723
instantiate_parametrized_tests (TestQuantizePT2E )
726
724
727
725
728
- @unittest .skip ("TODO: Reenable it after debug infrature finish update" )
729
- class TestNumericDebugger (TestCase ):
730
- def _extract_debug_handles (self , model ) -> Dict [str , int ]:
731
- debug_handle_map : Dict [str , int ] = {}
732
-
733
- def _extract_debug_handles_from_node (node : torch .fx .Node ) -> None :
734
- nonlocal debug_handle_map
735
- if (
736
- CUSTOM_KEY in node .meta
737
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ]
738
- ):
739
- debug_handle_map [str (node )] = node .meta [CUSTOM_KEY ][
740
- NUMERIC_DEBUG_HANDLE_KEY
741
- ]
742
-
743
- bfs_trace_with_node_process (model , _extract_debug_handles_from_node )
744
- return debug_handle_map
745
-
746
- def _assert_each_node_has_debug_handle (self , model ) -> None :
747
- def _assert_node_has_debug_handle (node : torch .fx .Node ) -> None :
748
- self .assertTrue (
749
- CUSTOM_KEY in node .meta
750
- and NUMERIC_DEBUG_HANDLE_KEY in node .meta [CUSTOM_KEY ],
751
- f"Node { node } doesn't have debug handle" ,
752
- )
753
-
754
- bfs_trace_with_node_process (model , _assert_node_has_debug_handle )
726
+ class TestXNNPACKQuantizerNumericDebugger (PT2ENumericDebuggerTestCase ):
755
727
756
- def test_quantize_pt2e_preserve_handle (self ) -> None :
728
+ def test_quantize_pt2e_preserve_handle (self ):
757
729
m = TestHelperModules .Conv2dThenConv1d ()
758
730
example_inputs = m .example_inputs ()
759
731
ep = export_for_training (m , example_inputs , strict = True )
760
- generate_numeric_debug_handle (ep )
761
732
m = ep .module ()
762
733
763
734
quantizer = XNNPACKQuantizer ().set_global (
764
735
get_symmetric_quantization_config (is_per_channel = False )
765
736
)
766
- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
767
- debug_handle_map = self ._extract_debug_handles (m )
768
- res_counter = Counter (debug_handle_map .values ())
769
- repeated_debug_handle_ids = [1 , 2 , 3 ]
770
- # 3 ids were repeated because we copy over the id from node to its output observer
737
+ m = prepare_pt2e (m , quantizer )
738
+ from_node_source_map = self ._extract_from_node_source (m )
739
+ node_name_equip_with_output_observer = [
740
+ "conv2d" ,
741
+ "conv1d" ,
742
+ "squeeze" ,
743
+ ]
744
+ res_counter = Counter (from_node_source_map .values ())
745
+ repeated_from_node_source = [
746
+ from_node_source_map [n_name ]
747
+ for n_name in node_name_equip_with_output_observer
748
+ ]
749
+ # 3 infos were repeated because we copy over the info from node to its output observer
771
750
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
772
- for dh_id in repeated_debug_handle_ids :
773
- self .assertEqual (res_counter [dh_id ], 2 )
751
+ for from_node_source in repeated_from_node_source :
752
+ self .assertEqual (res_counter [from_node_source ], 2 )
774
753
775
754
m (* example_inputs )
776
755
m = convert_pt2e (m )
777
- self ._assert_each_node_has_debug_handle (ep )
778
- debug_handle_map = self ._extract_debug_handles (m )
779
- res_counter = Counter (debug_handle_map .values ())
780
- # same set of ids where repeated, because we copy over the id from observer/fake_quant to
781
- # dequantize node
782
- repeated_debug_handle_ids = [1 , 2 , 3 ]
783
- for dh_id in repeated_debug_handle_ids :
784
- self .assertEqual (res_counter [dh_id ], 2 )
785
-
786
- def test_extract_results_from_loggers (self ) -> None :
756
+ self ._assert_each_node_has_from_node_source (m )
757
+ from_node_source_map = self ._extract_from_node_source (m )
758
+ res_counter = Counter (from_node_source_map .values ())
759
+ # same set of infos were repeated, because we copy over the info from observer/fake_quant to
760
+ # quantize/dequantize node
761
+ repeated_from_node_source = [
762
+ from_node_source_map [n_name ]
763
+ for n_name in node_name_equip_with_output_observer
764
+ ]
765
+ for from_node_source in repeated_from_node_source :
766
+ self .assertEqual (res_counter [from_node_source ], 3 )
767
+
768
+ def test_extract_results_from_loggers (self ):
787
769
m = TestHelperModules .Conv2dThenConv1d ()
788
770
example_inputs = m .example_inputs ()
789
771
ep = export_for_training (m , example_inputs , strict = True )
790
- generate_numeric_debug_handle (ep )
791
772
m = ep .module ()
792
- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
773
+ m_ref_logger = prepare_for_propagation_comparison (m )
793
774
794
775
quantizer = XNNPACKQuantizer ().set_global (
795
776
get_symmetric_quantization_config (is_per_channel = False )
796
777
)
797
- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
778
+ m = prepare_pt2e (m , quantizer )
798
779
m (* example_inputs )
799
780
m = convert_pt2e (m )
800
781
m_quant_logger = prepare_for_propagation_comparison (m )
@@ -803,29 +784,22 @@ def test_extract_results_from_loggers(self) -> None:
803
784
m_quant_logger (* example_inputs )
804
785
ref_results = extract_results_from_loggers (m_ref_logger )
805
786
quant_results = extract_results_from_loggers (m_quant_logger )
806
- comparison_results = compare_results (
807
- ref_results ,
808
- quant_results , # pyre-ignore[6]
809
- )
787
+ comparison_results = compare_results (ref_results , quant_results )
810
788
for node_summary in comparison_results .values ():
811
789
if len (node_summary .results ) > 0 :
812
- self .assertGreaterEqual (
813
- node_summary .results [0 ].sqnr ,
814
- 35 , # pyre-ignore[6]
815
- )
790
+ self .assertGreaterEqual (node_summary .results [0 ].sqnr , 35 )
816
791
817
- def test_extract_results_from_loggers_list_output (self ) -> None :
792
+ def test_extract_results_from_loggers_list_output (self ):
818
793
m = TestHelperModules .Conv2dWithSplit ()
819
794
example_inputs = m .example_inputs ()
820
795
ep = export_for_training (m , example_inputs , strict = True )
821
- generate_numeric_debug_handle (ep )
822
796
m = ep .module ()
823
- m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
797
+ m_ref_logger = prepare_for_propagation_comparison (m )
824
798
825
799
quantizer = XNNPACKQuantizer ().set_global (
826
800
get_symmetric_quantization_config (is_per_channel = False )
827
801
)
828
- m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
802
+ m = prepare_pt2e (m , quantizer )
829
803
m (* example_inputs )
830
804
m = convert_pt2e (m )
831
805
m_quant_logger = prepare_for_propagation_comparison (m )
@@ -834,15 +808,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
834
808
m_quant_logger (* example_inputs )
835
809
ref_results = extract_results_from_loggers (m_ref_logger )
836
810
quant_results = extract_results_from_loggers (m_quant_logger )
837
- comparison_results = compare_results (
838
- ref_results ,
839
- quant_results , # pyre-ignore[6]
840
- )
811
+ comparison_results = compare_results (ref_results , quant_results )
841
812
for node_summary in comparison_results .values ():
842
813
if len (node_summary .results ) > 0 :
843
814
sqnr = node_summary .results [0 ].sqnr
844
815
if isinstance (sqnr , list ):
845
816
for sqnr_i in sqnr :
846
817
self .assertGreaterEqual (sqnr_i , 35 )
847
818
else :
848
- self .assertGreaterEqual (sqnr , 35 ) # pyre-ignore[6]
819
+ self .assertGreaterEqual (sqnr , 35 )
0 commit comments