@@ -556,11 +556,13 @@ def _save_layer_input_and_outputs(self):
556556 if self ._is_collection_being_saved_for_step (CollectionKeys .LAYERS )
557557 else set ()
558558 )
559- if hasattr (tensor , "numpy" ):
560- self ._save_tensor_to_file (export_name , tensor .numpy (), input_collection )
561- else :
559+ t = tensor [0 ] if isinstance (tensor , list ) and len (tensor ) else tensor
560+ if hasattr (t , "numpy" ) is False :
562561 self .logger .warning ("cannot save layer values during forward pass with tf.function" )
563562 continue
563+ else :
564+ self ._save_tensor_to_file (export_name , tensor , input_collection )
565+
564566 # Save Output
565567 tensor = self .saved_layers [layer_name ].layer_output
566568 export_name = get_export_name_for_keras (layer_name , tensor_type = "output" , tensor = tensor )
@@ -570,8 +572,11 @@ def _save_layer_input_and_outputs(self):
570572 if self ._is_collection_being_saved_for_step (CollectionKeys .LAYERS )
571573 else set ()
572574 )
573- if hasattr (tensor , "numpy" ):
574- self ._save_tensor_to_file (export_name , tensor .numpy (), output_collection )
575+ t = tensor [0 ] if isinstance (tensor , list ) and len (tensor ) else tensor
576+ if hasattr (t , "numpy" ) is False :
577+ self .logger .warning ("cannot save layer values during forward pass with tf.function" )
578+ else :
579+ self ._save_tensor_to_file (export_name , tensor , output_collection )
575580
576581 def _save_tensors_post_step (self , batch , logs ):
577582 # some tensors available as value from within hook are saved here
0 commit comments