Skip to content

Commit 0a4d5c5

Browse files
committed
[ModelWrapper] add top-level checks for fork/join checks
1 parent 71ee780 commit 0a4d5c5

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -429,14 +429,24 @@ def is_fork_node(self, node):
429429
"""Checks if the given node is a fork, that is, the node has multiple
430430
direct successors"""
431431
direct_successors = self.find_direct_successors(node)
432-
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
432+
# if the node output is also wired to a top-level output, it is still
433+
# a fork with only 1 direct successor
434+
if node.output[0] in [x.name for x in self.graph.output]:
435+
is_fork = False if direct_successors is None else (len(direct_successors) > 0)
436+
else:
437+
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
433438
return is_fork
434439

435440
def is_join_node(self, node):
436441
"""Checks if the given node is a join, that is, the node has multiple
437442
direct predecessors"""
438443
direct_predecessors = self.find_direct_predecessors(node)
439-
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
444+
# if the node input is also wired to a top-level input, it is still
445+
# a fork with only 1 direct predecessor
446+
if node.input[0] in [x.name for x in self.graph.input]:
447+
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0)
448+
else:
449+
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
440450
return is_join
441451

442452
def get_all_tensor_names(self):

0 commit comments

Comments
 (0)