Skip to content

Commit 3f77980

Browse files
Conchylicultorcopybara-github
authored andcommitted
Better Tensor singleton repr
PiperOrigin-RevId: 297170026
1 parent af68c5d commit 3f77980

File tree

4 files changed

+70
-12
lines changed

4 files changed

+70
-12
lines changed

tensorflow_datasets/core/features/feature.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ class Tensor(FeatureConnector):
516516
@api_utils.disallow_positional_args
517517
def __init__(self, shape, dtype):
518518
"""Construct a Tensor feature."""
519-
self._shape = shape
519+
self._shape = tuple(shape)
520520
self._dtype = dtype
521521

522522
def get_tensor_info(self):
@@ -544,3 +544,25 @@ def encode_example(self, example_data):
544544
example_data.dtype, np_dtype))
545545
utils.assert_shape_match(example_data.shape, self._shape)
546546
return example_data
547+
548+
549+
def get_inner_feature_repr(feature):
550+
"""Utils which returns the object which should get printed in __repr__.
551+
552+
This is used in container features (Sequence, FeatureDict) to print scalar
553+
Tensor in a less verbose way `Sequence(tf.int32)` rather than
554+
`Sequence(Tensor(shape=(), dtype=tf.in32))`.
555+
556+
Args:
557+
feature: The feature to dispaly
558+
559+
Returns:
560+
Either the feature or it's inner value.
561+
"""
562+
# We only print `tf.int32` rather than `Tensor(shape=(), dtype=tf.int32)`
563+
# * For the base `Tensor` class (and not subclass).
564+
# * When shape is scalar (explicit check to avoid trigger when `shape=None`).
565+
if type(feature) == Tensor and feature.shape == (): # pylint: disable=unidiomatic-typecheck,g-explicit-bool-comparison
566+
return repr(feature.dtype)
567+
else:
568+
return repr(feature)

tensorflow_datasets/core/features/features_dict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def __repr__(self):
141141
lines = ['{}({{'.format(type(self).__name__)]
142142
# Add indentation
143143
for key, feature in sorted(list(self._feature_dict.items())):
144-
all_sub_lines = '\'{}\': {},'.format(key, feature)
144+
feature_repr = feature_lib.get_inner_feature_repr(feature)
145+
all_sub_lines = '\'{}\': {},'.format(key, feature_repr)
145146
lines.extend(' ' + l for l in all_sub_lines.split('\n'))
146147
lines.append('})')
147148
return '\n'.join(lines)

tensorflow_datasets/core/features/features_test.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from __future__ import division
2323
from __future__ import print_function
2424

25+
import textwrap
2526
import numpy as np
2627
import tensorflow.compat.v2 as tf
2728
from tensorflow_datasets import testing
@@ -206,7 +207,16 @@ def test_feature__repr__(self):
206207
'label': features_lib.Sequence(label),
207208
})
208209

209-
self.assertEqual(repr(feature_dict), FEATURE_STR)
210+
self.assertEqual(
211+
repr(feature_dict),
212+
textwrap.dedent("""\
213+
FeaturesDict({
214+
'label': Sequence(ClassLabel(shape=(), dtype=tf.int64, num_classes=2)),
215+
'metadata': Sequence({
216+
'frame': Image(shape=(32, 32, 3), dtype=tf.uint8),
217+
}),
218+
})"""),
219+
)
210220

211221
def test_feature_save_load_metadata_slashes(self):
212222
with testing.tmp_dir() as data_dir:
@@ -218,14 +228,6 @@ def test_feature_save_load_metadata_slashes(self):
218228
fd.load_metadata(data_dir)
219229

220230

221-
FEATURE_STR = """FeaturesDict({
222-
'label': Sequence(ClassLabel(shape=(), dtype=tf.int64, num_classes=2)),
223-
'metadata': Sequence({
224-
'frame': Image(shape=(32, 32, 3), dtype=tf.uint8),
225-
}),
226-
})"""
227-
228-
229231
class FeatureTensorTest(testing.FeatureExpectationsTestCase):
230232

231233
def test_shape_static(self):
@@ -395,6 +397,39 @@ def test_string(self):
395397
],
396398
)
397399

400+
def test_repr_tensor(self):
401+
402+
# Top level Tensor is printed expanded
403+
self.assertEqual(
404+
repr(features_lib.Tensor(shape=(), dtype=tf.int32)),
405+
'Tensor(shape=(), dtype=tf.int32)',
406+
)
407+
408+
# Sequences colapse tensor repr
409+
self.assertEqual(
410+
repr(features_lib.Sequence(tf.int32)),
411+
'Sequence(tf.int32)',
412+
)
413+
414+
class ChildTensor(features_lib.Tensor):
415+
pass
416+
417+
self.assertEqual(
418+
repr(features_lib.FeaturesDict({
419+
'colapsed': features_lib.Tensor(shape=(), dtype=tf.int32),
420+
# Tensor with defined shape are printed expanded
421+
'noncolapsed': features_lib.Tensor(shape=(1,), dtype=tf.int32),
422+
# Tensor inherited are expanded
423+
'child': ChildTensor(shape=(), dtype=tf.int32),
424+
})),
425+
textwrap.dedent("""\
426+
FeaturesDict({
427+
'child': ChildTensor(shape=(), dtype=tf.int32),
428+
'colapsed': tf.int32,
429+
'noncolapsed': Tensor(shape=(1,), dtype=tf.int32),
430+
})"""),
431+
)
432+
398433

399434
if __name__ == '__main__':
400435
testing.test_main()

tensorflow_datasets/core/features/sequence_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __setstate__(self, state):
201201

202202
def __repr__(self):
203203
"""Display the feature."""
204-
inner_feature_repr = repr(self._feature)
204+
inner_feature_repr = feature_lib.get_inner_feature_repr(self._feature)
205205
if inner_feature_repr.startswith('FeaturesDict('):
206206
# Minor formatting cleaning: 'Sequence(FeaturesDict({' => 'Sequence({'
207207
inner_feature_repr = inner_feature_repr[len('FeaturesDict('):-len(')')]

0 commit comments

Comments
 (0)