Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 19f658b

Browse files
authored
Dead code elimination (#470)
1 parent c36bfef commit 19f658b

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

onnx_coreml/_transformers.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,44 @@ def __call__(self, graph): # type: (Graph) -> Graph
772772
if node not in nodes_to_be_removed:
773773
transformed_nodes.append(node)
774774
return graph.create_graph(nodes=transformed_nodes)
775+
776+
class DeadCodeElimination(object):
777+
'''
778+
Removes nodes with unused outputs
779+
'''
780+
def __call__(self, graph): # type: (Graph) -> Graph
781+
input_names = [str(input_[0]) for input_ in graph.inputs]
782+
output_names = set([str(output_[0]) for output_ in graph.outputs])
783+
784+
nodes_to_be_removed = []
785+
use_set = set()
786+
787+
for node in graph.nodes:
788+
for _input in node.inputs:
789+
use_set.add(_input)
790+
791+
for node in graph.nodes:
792+
output_used = False
793+
for _output in node.outputs:
794+
if _output in output_names or _output in use_set:
795+
output_used = True
796+
break
797+
if not output_used:
798+
# Remove current node
799+
nodes_to_be_removed.append(node.name)
800+
for parent in node.parents:
801+
parent.children.remove(node)
802+
803+
transformed_nodes = []
804+
for node in graph.nodes:
805+
if node.name not in nodes_to_be_removed:
806+
transformed_nodes.append(node)
807+
808+
for _input in input_names:
809+
if _input not in use_set:
810+
for i in range(len(graph.inputs)):
811+
if graph.inputs[i][0] is _input:
812+
graph.inputs.remove(graph.inputs[i])
813+
break
814+
815+
return graph.create_graph(nodes=transformed_nodes)

onnx_coreml/converter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ReshapeInitTensorFuser, BNBroadcastedMulFuser, BNBroadcastedAddFuser, \
2626
PixelShuffleFuser, OutputRenamer, AddModelInputsOutputs, \
2727
ConstantsToInitializers, ImageScalerRemover, ShapeOpRemover, ConstantRemover, \
28-
ConstantFillToInitializers, ReshapeTransposeReshape_pattern1, CastOpRemover
28+
ConstantFillToInitializers, ReshapeTransposeReshape_pattern1, CastOpRemover, DeadCodeElimination
2929

3030
from ._error_utils import ErrorHandling
3131
from .graph_viz import plot_graph # type: ignore
@@ -429,6 +429,7 @@ def __call__(self, graph):
429429
CastOpRemover(),
430430
ReshapeInitTensorFuser(),
431431
DropoutRemover(),
432+
DeadCodeElimination(),
432433
ConvAddFuser(),
433434
BNBroadcastedMulFuser(),
434435
BNBroadcastedAddFuser(),

0 commit comments

Comments
 (0)