Skip to content

Commit 0791433

Browse files
Fix tuning failure after stripping equivalent nodes (#1290)
1 parent 3827826 commit 0791433

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

neural_compressor/adaptor/tf_utils/graph_rewriter/generic/insert_print_node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def do_transformation(self):
151151
if post_node_names:
152152
for post_node_name in post_node_names:
153153
post_node = graph_info[post_node_name].node
154+
if each_node_name not in post_node.input:
155+
continue
154156
if post_node.op == 'FusedBatchNormV3':
155157
if "_print_identity" in \
156158
graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]:

neural_compressor/adaptor/tf_utils/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def is_equivalent_input(input_tensor_list_1, input_tensor_list_2):
401401
for node_to_remove in nodes_to_remove:
402402
stripped_graph.remove_node(node_to_remove)
403403
return tf.compat.v1.graph_util.extract_sub_graph \
404-
(stripped_graph.dump_graph(), set(stripped_graph_node_names).intersection(output_node_names)), \
404+
(stripped_graph.dump_graph(), list(set(stripped_graph_node_names).intersection(output_node_names))), \
405405
replaced_nodes_type
406406

407407
# THIS API IS TO BE DEPRECATED!

0 commit comments

Comments
 (0)