diff --git a/keras2onnx/proto/tfcompat.py b/keras2onnx/proto/tfcompat.py index f82e5582..4d052f9e 100644 --- a/keras2onnx/proto/tfcompat.py +++ b/keras2onnx/proto/tfcompat.py @@ -10,9 +10,9 @@ def normalize_tensor_shape(tensor_shape): if is_tf2: - return [d for d in tensor_shape] - else: return [d.value for d in tensor_shape] + else: + return [d for d in tensor_shape] def dump_graph_into_tensorboard(tf_graph):