Skip to content

Commit 3fcf2f7

Browse files
Fix ava-person-vehicle-detection-stage2-2_0_0 tuning failed (#1135)
1 parent 6663f7b commit 3fcf2f7

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

neural_compressor/adaptor/tf_utils/graph_rewriter/generic/fuse_decomposed_bn.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ def do_transformation(self):
169169

170170
# Mul (input, Mul)
171171
input_data_op = node_from_map(input_node_map, data_scale_mul_op.input[0])
172+
# Workaround for model ava-person-vehicle-detection-stage2-2_0_0
173+
# FusedBatchNorm requires a 4D Tensor for input data, but MatMul only support 2D output.
174+
# Don't fuse the small ops to FusedBatchNorm when the upstream has MatMul.
175+
ancestor_input_data_op = node_from_map(input_node_map, input_data_op.input[0])
176+
if input_data_op.op == "MatMul" or ancestor_input_data_op.op == "MatMul":
177+
continue
178+
172179
scale_op = node_from_map(input_node_map, data_scale_mul_op.input[1])
173180

174181
if scale_op.op == "Rsqrt":
@@ -370,4 +377,4 @@ def get_const_dim_count(node_def):
370377
Number of dimensions for the Const node.
371378
"""
372379
const_value = values_from_const(node_def)
373-
return const_value.ndim
380+
return const_value.ndim

0 commit comments

Comments
 (0)