File tree 2 files changed +10
-3
lines changed
neural_compressor/adaptor/tf_utils/graph_rewriter/generic
2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -172,9 +172,12 @@ def do_transformation(self):
172
172
# Workaround for model ava-person-vehicle-detection-stage2-2_0_0
173
173
# FusedBatchNorm requires a 4D Tensor for input data, but MatMul only support 2D output.
174
174
# Don't fuse the small ops to FusedBatchNorm when the upstream has MatMul.
175
- ancestor_input_data_op = node_from_map (input_node_map , input_data_op .input [0 ])
176
- if input_data_op .op == "MatMul" or ancestor_input_data_op .op == "MatMul" :
175
+ if input_data_op .op == 'MatMul' :
177
176
continue
177
+ if input_data_op .input :
178
+ ancestor_input_data_op = node_from_map (input_node_map , input_data_op .input [0 ])
179
+ if ancestor_input_data_op .op == "MatMul" :
180
+ continue
178
181
179
182
scale_op = node_from_map (input_node_map , data_scale_mul_op .input [1 ])
180
183
Original file line number Diff line number Diff line change @@ -111,12 +111,16 @@ def get_optimized_model(self, itex_mode=False):
111
111
# Put FuseDecomposedBNOptimizer before GraphFoldConstantOptimizer
112
112
# The 'Sub' op in the small decomposed ops of BN will be converted to const by GraphFoldConstantOptimizer.
113
113
# Then the FuseDecomposedBNOptimizer can't fuse the small decomposed ops to BN.
114
- self ._tmp_graph_def = FuseDecomposedBNOptimizer (self ._tmp_graph_def ).do_transformation ()
114
+ if self .new_api :
115
+ self ._tmp_graph_def = FuseDecomposedBNOptimizer (self ._tmp_graph_def ).do_transformation ()
115
116
116
117
# disable fold constant for itex qdq mode
117
118
if not itex_mode :
118
119
self ._tmp_graph_def = GraphFoldConstantOptimizer (self ._tmp_graph_def ).do_transformation ()
119
120
121
+ if not self .new_api :
122
+ self ._tmp_graph_def = FuseDecomposedBNOptimizer (self ._tmp_graph_def ).do_transformation ()
123
+
120
124
self ._tmp_graph_def = FuseColumnWiseMulOptimizer (self ._tmp_graph_def ).do_transformation ()
121
125
122
126
self ._tmp_graph_def = StripUnusedNodesOptimizer (self ._tmp_graph_def ,
You can’t perform that action at this time.
0 commit comments