Skip to content

Commit 94321b7

Browse files
committed
Add option to find_upstream to keep nodes visited even if not found
1 parent cadd6b2 commit 94321b7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,19 @@ def find_producer(self, tensor_name):
346346
return x
347347
return None
348348

349-
def find_upstream(self, tensor_name, finder_fxn):
349+
def find_upstream(self, tensor_name, finder_fxn, keep_if_not_found=False):
350350
"""Follow the producer chain upstream, calling finder_fxn on each upstream
351351
node until it returns True or there are no nodes left. Returns the list
352-
of nodes visited, or None if finder_fxn did not return True."""
352+
of nodes visited, or None if finder_fxn did not return True. If
353+
keep_if_not_found is specified, returns the list of nodes visited, even
354+
if finder_fxn never returned True, i.e., if the search terminated at an
355+
input or initializer."""
353356
visit_list = []
354357
current_tensor = tensor_name
355358
while True:
356359
current_producer = self.find_producer(current_tensor)
357360
if current_producer is None:
358-
return []
361+
return visit_list if keep_if_not_found else []
359362
else:
360363
found = finder_fxn(current_producer)
361364
visit_list.append(current_producer)
@@ -364,7 +367,7 @@ def find_upstream(self, tensor_name, finder_fxn):
364367
elif len(current_producer.input) > 0:
365368
current_tensor = current_producer.input[0]
366369
else:
367-
return None
370+
return visit_list if keep_if_not_found else None
368371

369372
def find_consumer(self, tensor_name):
370373
"""Finds and returns the node that consumes the tensor with given name.

0 commit comments

Comments
 (0)