Skip to content

Commit 8902694

Browse files
committed
Fix RemoveIdentityOps not correctly handling ops following fork-nodes
1 parent cadd6b2 commit 8902694

File tree

1 file changed

+34
-11
lines changed

1 file changed

+34
-11
lines changed

src/qonnx/transformation/remove.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28-
29-
3028
import numpy as np
29+
import warnings
3130

3231
from qonnx.core.modelwrapper import ModelWrapper
3332
from qonnx.transformation.base import Transformation
@@ -58,21 +57,45 @@ def apply(self, model: ModelWrapper):
5857

5958

6059
def remove_node_and_rewire(model, node):
60+
# Currently cannot remove and rewire join-nodes, probably not necessary to
61+
# support this
62+
if model.is_join_node(node):
63+
# Log this as a warning, so the user is aware of this, there might be
64+
# somthing wrong or some checks missing at the caller site
65+
warnings.warn(
66+
"Tried to remove join-node operation: Currently not supported"
67+
)
68+
# Exit the function here without doing anything
69+
return
70+
# We already know that node is not a join-node, thus to rewire, we only need
71+
# to check the single producer
6172
producer = model.find_producer(node.input[0])
62-
if producer is not None:
63-
# wire output tensor to
64-
# output of producer node
73+
# If there is a producer which is not a fork-node, rewiring is simple
74+
if producer is not None and not model.is_fork_node(producer):
75+
# Rewire by skipping the node, letting the producer directly feed the
76+
# nodes output.
77+
# TODO: Check whether this already covers fork-node identities?
6578
producer.output[0] = node.output[0]
79+
# If there is no producer or the producer forks, rewiring is a bit more
80+
# complicated
6681
else:
67-
# node is first in graph
82+
# Now it depends on the successor nodes to rewire their inputs
6883
successors = model.find_direct_successors(node)
84+
# Singular node detached from the rest of the graph?
6985
assert successors is not None, "Whole graph is one node."
70-
for succ in successors:
71-
for i, s_inp in enumerate(succ.input):
86+
# We need to rewire the input of each successor to not detach parts of
87+
# the graph
88+
for successor in successors:
89+
# Find the inputs of the successor which are produced by the node to
90+
# be removed
91+
for i, s_inp in enumerate(successor.input):
92+
# Note: This might happen multiple times?
7293
if s_inp == node.output[0]:
73-
# rewire successor's input directly to graph input
74-
succ.input[i] = node.input[0]
75-
# remove node
94+
# Rewire successor's input directly to nodes input
95+
# Note: Node may not be a join-node, but there is probably
96+
# no such thing as join-node identity anyway
97+
successor.input[i] = node.input[0]
98+
# Remove node
7699
model.graph.node.remove(node)
77100

78101

0 commit comments

Comments
 (0)