File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -429,14 +429,24 @@ def is_fork_node(self, node):
429
429
"""Checks if the given node is a fork, that is, the node has multiple
430
430
direct successors"""
431
431
direct_successors = self .find_direct_successors (node )
432
- is_fork = False if direct_successors is None else (len (direct_successors ) > 1 )
432
+ # if the node output is also wired to a top-level output, it is still
433
+ # a fork with only 1 direct successor
434
+ if node .output [0 ] in [x .name for x in self .graph .output ]:
435
+ is_fork = False if direct_successors is None else (len (direct_successors ) > 0 )
436
+ else :
437
+ is_fork = False if direct_successors is None else (len (direct_successors ) > 1 )
433
438
return is_fork
434
439
435
440
def is_join_node (self , node ):
436
441
"""Checks if the given node is a join, that is, the node has multiple
437
442
direct predecessors"""
438
443
direct_predecessors = self .find_direct_predecessors (node )
439
- is_join = False if direct_predecessors is None else (len (direct_predecessors ) > 1 )
444
+ # if the node input is also wired to a top-level input, it is still
445
+ # a fork with only 1 direct predecessor
446
+ if node .input [0 ] in [x .name for x in self .graph .input ]:
447
+ is_join = False if direct_predecessors is None else (len (direct_predecessors ) > 0 )
448
+ else :
449
+ is_join = False if direct_predecessors is None else (len (direct_predecessors ) > 1 )
440
450
return is_join
441
451
442
452
def get_all_tensor_names (self ):
You can’t perform that action at this time.
0 commit comments