Skip to content

Commit 93acc3a

Browse files
committed
Merge branch 'fix/remove_identity_ops' of https://github.com/iksnagreb/qonnx into iksnagreb-fix/remove_identity_ops
2 parents e02f701 + c7b3590 commit 93acc3a

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

src/qonnx/transformation/remove.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28-
29-
3028
import numpy as np
29+
import warnings
3130

3231
from qonnx.core.modelwrapper import ModelWrapper
3332
from qonnx.transformation.base import Transformation
@@ -58,21 +57,43 @@ def apply(self, model: ModelWrapper):
5857

5958

6059
def remove_node_and_rewire(model, node):
60+
# Currently cannot remove and rewire join-nodes, probably not necessary to
61+
# support this
62+
if model.is_join_node(node):
63+
# Log this as a warning, so the user is aware of this, there might be
64+
# somthing wrong or some checks missing at the caller site
65+
warnings.warn("Removing join-node operation is currently not supported")
66+
# Exit the function here without doing anything
67+
return
68+
# We already know that node is not a join-node, thus to rewire, we only need
69+
# to check the single producer
6170
producer = model.find_producer(node.input[0])
62-
if producer is not None:
63-
# wire output tensor to
64-
# output of producer node
71+
# If there is a producer which is not a fork-node, rewiring is simple
72+
if producer is not None and not model.is_fork_node(producer):
73+
# Rewire by skipping the node, letting the producer directly feed the
74+
# nodes output.
75+
# TODO: Check whether this already covers fork-node identities?
6576
producer.output[0] = node.output[0]
77+
# If there is no producer or the producer forks, rewiring is a bit more
78+
# complicated
6679
else:
67-
# node is first in graph
80+
# Now it depends on the successor nodes to rewire their inputs
6881
successors = model.find_direct_successors(node)
82+
# Singular node detached from the rest of the graph?
6983
assert successors is not None, "Whole graph is one node."
70-
for succ in successors:
71-
for i, s_inp in enumerate(succ.input):
84+
# We need to rewire the input of each successor to not detach parts of
85+
# the graph
86+
for successor in successors:
87+
# Find the inputs of the successor which are produced by the node to
88+
# be removed
89+
for i, s_inp in enumerate(successor.input):
90+
# Note: This might happen multiple times?
7291
if s_inp == node.output[0]:
73-
# rewire successor's input directly to graph input
74-
succ.input[i] = node.input[0]
75-
# remove node
92+
# Rewire successor's input directly to nodes input
93+
# Note: Node may not be a join-node, but there is probably
94+
# no such thing as join-node identity anyway
95+
successor.input[i] = node.input[0]
96+
# Remove node
7697
model.graph.node.remove(node)
7798

7899

0 commit comments

Comments
 (0)