Skip to content

Commit a91699a

Browse files
Conchylicultor and copybara-github
authored and committed
Minor DatasetInfo print formatting change
PiperOrigin-RevId: 258087427
1 parent 501ce7c commit a91699a

File tree

3 files changed

+54
-20
lines changed

3 files changed

+54
-20
lines changed

tensorflow_datasets/core/dataset_info.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import json
4040
import os
4141
import posixpath
42-
import pprint
4342
import tempfile
4443

4544
from absl import logging
@@ -72,7 +71,7 @@
7271
total_num_examples={total_num_examples},
7372
splits={splits},
7473
supervised_keys={supervised_keys},
75-
citation='{citation}',
74+
citation={citation},
7675
redistribution_info={redistribution_info},
7776
)
7877
"""
@@ -408,19 +407,13 @@ def initialize_from_bucket(self):
408407
gcs_utils.download_gcs_file(fname, out_fname)
409408
self.read_from_directory(tmp_dir)
410409

411-
def __str__(self):
412-
splits_pprint = "{\n %s\n }" % (
413-
pprint.pformat(
414-
{k: self.splits[k] for k in sorted(list(self.splits.keys()))},
415-
indent=8, width=1)[1:-1])
416-
features_dict = self.features
417-
features_pprint = "%s({\n %s\n }" % (
418-
type(features_dict).__name__,
419-
pprint.pformat({
420-
k: features_dict[k] for k in sorted(list(features_dict.keys()))
421-
}, indent=8, width=1)[1:-1])
422-
citation_pprint = '"""\n%s\n """' % "\n".join(
423-
[u" " * 8 + line for line in self.citation.split(u"\n")])
410+
def __repr__(self):
411+
splits_pprint = _indent("\n".join(["{"] + [
412+
" '{}': {},".format(k, split.num_examples)
413+
for k, split in sorted(self.splits.items())
414+
] + ["}"]))
415+
features_pprint = _indent(repr(self.features))
416+
citation_pprint = _indent('"""{}"""'.format(self.citation.strip()))
424417
return INFO_STR.format(
425418
name=self.name,
426419
version=self.version,
@@ -431,7 +424,14 @@ def __str__(self):
431424
citation=citation_pprint,
432425
urls=self.urls,
433426
supervised_keys=self.supervised_keys,
434-
redistribution_info=self.redistribution_info)
427+
# Proto add a \n that we strip.
428+
redistribution_info=str(self.redistribution_info).strip())
429+
430+
431+
def _indent(content):
432+
"""Add indentation to all lines except the first."""
433+
lines = content.split("\n")
434+
return "\n".join([lines[0]] + [" " + l for l in lines[1:]])
435435

436436
#
437437
#

tensorflow_datasets/core/dataset_info_test.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import os
2424
import tempfile
2525
import numpy as np
26+
import six
2627
import tensorflow as tf
2728
from tensorflow_datasets import testing
2829
from tensorflow_datasets.core import dataset_info
@@ -76,6 +77,7 @@ def setUpClass(cls):
7677

7778
@classmethod
7879
def tearDownClass(cls):
80+
super(DatasetInfoTest, cls).tearDownClass()
7981
testing.rm_tmp_dir(cls._tfds_tmp_dir)
8082

8183
def test_undefined_dir(self):
@@ -106,8 +108,8 @@ def test_reading(self):
106108
self.assertTrue(len(split_dict), 2)
107109

108110
# Assert on what they are
109-
self.assertTrue("train" in split_dict)
110-
self.assertTrue("test" in split_dict)
111+
self.assertIn("train", split_dict)
112+
self.assertIn("test", split_dict)
111113

112114
# Assert that this is computed correctly.
113115
self.assertEqual(40, info.splits.total_num_examples)
@@ -127,7 +129,8 @@ def test_writing(self):
127129
mnist_builder = mnist.MNIST(
128130
data_dir=tempfile.mkdtemp(dir=self.get_temp_dir()))
129131

130-
info = dataset_info.DatasetInfo(builder=mnist_builder)
132+
info = dataset_info.DatasetInfo(
133+
builder=mnist_builder, features=mnist_builder.info.features)
131134
info.read_from_directory(_INFO_DIR)
132135

133136
# Read the json file into a string.
@@ -152,6 +155,10 @@ def test_writing(self):
152155
# Assert correct license was written.
153156
self.assertEqual(existing_json["redistributionInfo"]["license"], license_)
154157

158+
if six.PY3:
159+
# Only test on Python 3 to avoid u'' formatting issues
160+
self.assertEqual(repr(info), INFO_STR)
161+
155162
def test_restore_after_modification(self):
156163
# Create a DatasetInfo
157164
info = dataset_info.DatasetInfo(
@@ -296,5 +303,32 @@ def test_updates_on_bucket_info(self):
296303
self.assertEqual(2, len(info.as_proto.schema.feature))
297304

298305

306+
INFO_STR = """tfds.core.DatasetInfo(
307+
name='mnist',
308+
version=1.0.0,
309+
description='The MNIST database of handwritten digits.',
310+
urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
311+
features=FeaturesDict({
312+
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
313+
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
314+
}),
315+
total_num_examples=40,
316+
splits={
317+
'test': 20,
318+
'train': 20,
319+
},
320+
supervised_keys=('image', 'label'),
321+
citation=\"\"\"@article{lecun2010mnist,
322+
title={MNIST handwritten digit database},
323+
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
324+
journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
325+
volume={2},
326+
year={2010}
327+
}\"\"\",
328+
redistribution_info=license: "test license",
329+
)
330+
"""
331+
332+
299333
if __name__ == "__main__":
300334
testing.test_main()

tensorflow_datasets/core/splits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class SplitInfo(object):
3636

3737
@property
3838
def num_examples(self):
39-
return self.statistics.num_examples
39+
return int(self.statistics.num_examples)
4040

4141
def __repr__(self):
4242
num_examples = self.num_examples or "unknown"

0 commit comments

Comments
 (0)