Skip to content

Commit 4939127

Browse files
lvliang-intelchensuyue
authored andcommitted
efficientnet_b0 tuning accuracy regression fix (#1140)
(cherry picked from commit c509883)
1 parent 507e3e2 commit 4939127

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

neural_compressor/adaptor/tf_utils/graph_rewriter/generic/fuse_decomposed_bn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,12 @@ def do_transformation(self):
172172
# Workaround for model ava-person-vehicle-detection-stage2-2_0_0
173173
# FusedBatchNorm requires a 4D Tensor for input data, but MatMul only support 2D output.
174174
# Don't fuse the small ops to FusedBatchNorm when the upstream has MatMul.
175-
ancestor_input_data_op = node_from_map(input_node_map, input_data_op.input[0])
176-
if input_data_op.op == "MatMul" or ancestor_input_data_op.op == "MatMul":
175+
if input_data_op.op == 'MatMul':
177176
continue
177+
if input_data_op.input:
178+
ancestor_input_data_op = node_from_map(input_node_map, input_data_op.input[0])
179+
if ancestor_input_data_op.op == "MatMul":
180+
continue
178181

179182
scale_op = node_from_map(input_node_map, data_scale_mul_op.input[1])
180183

neural_compressor/adaptor/tf_utils/graph_rewriter/generic/pre_optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,16 @@ def get_optimized_model(self, itex_mode=False):
111111
# Put FuseDecomposedBNOptimizer before GraphFoldConstantOptimizer
112112
# The 'Sub' op in the small decomposed ops of BN will be converted to const by GraphFoldConstantOptimizer.
113113
# Then the FuseDecomposedBNOptimizer can't fuse the small decomposed ops to BN.
114-
self._tmp_graph_def = FuseDecomposedBNOptimizer(self._tmp_graph_def).do_transformation()
114+
if self.new_api:
115+
self._tmp_graph_def = FuseDecomposedBNOptimizer(self._tmp_graph_def).do_transformation()
115116

116117
# disable fold constant for itex qdq mode
117118
if not itex_mode:
118119
self._tmp_graph_def = GraphFoldConstantOptimizer(self._tmp_graph_def).do_transformation()
119120

121+
if not self.new_api:
122+
self._tmp_graph_def = FuseDecomposedBNOptimizer(self._tmp_graph_def).do_transformation()
123+
120124
self._tmp_graph_def = FuseColumnWiseMulOptimizer(self._tmp_graph_def).do_transformation()
121125

122126
self._tmp_graph_def = StripUnusedNodesOptimizer(self._tmp_graph_def,

0 commit comments

Comments
 (0)