@@ -346,16 +346,19 @@ def find_producer(self, tensor_name):
346
346
return x
347
347
return None
348
348
349
- def find_upstream (self , tensor_name , finder_fxn ):
349
+ def find_upstream (self , tensor_name , finder_fxn , keep_if_not_found = False ):
350
350
"""Follow the producer chain upstream, calling finder_fxn on each upstream
351
351
node until it returns True or there are no nodes left. Returns the list
352
- of nodes visited, or None if finder_fxn did not return True."""
352
+ of nodes visited, or None if finder_fxn did not return True. If
353
+ keep_if_not_found is specified, returns the list of nodes visited, even
354
+ if finder_fxn never returned True, i.e., if the search terminated at an
355
+ input or initializer."""
353
356
visit_list = []
354
357
current_tensor = tensor_name
355
358
while True :
356
359
current_producer = self .find_producer (current_tensor )
357
360
if current_producer is None :
358
- return []
361
+ return visit_list if keep_if_not_found else []
359
362
else :
360
363
found = finder_fxn (current_producer )
361
364
visit_list .append (current_producer )
@@ -364,7 +367,7 @@ def find_upstream(self, tensor_name, finder_fxn):
364
367
elif len (current_producer .input ) > 0 :
365
368
current_tensor = current_producer .input [0 ]
366
369
else :
367
- return None
370
+ return visit_list if keep_if_not_found else None
368
371
369
372
def find_consumer (self , tensor_name ):
370
373
"""Finds and returns the node that consumes the tensor with given name.
0 commit comments