Skip to content

Commit 5e7c62e

Browse files
pierrot0copybara-github
authored andcommitted
tfrecord writers: clearer error message when no examples are yielded or empty pcollection as input.
PiperOrigin-RevId: 289162141
1 parent 43ef317 commit 5e7c62e

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

tensorflow_datasets/core/tfrecords_writer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def _get_shard_specs(num_examples, total_size, bucket_lengths, path):
103103

104104

105105
def _get_shard_boundaries(num_examples, number_of_shards):
106+
if num_examples == 0:
107+
raise AssertionError("No examples were yielded.")
106108
if num_examples < number_of_shards:
107109
raise AssertionError("num_examples ({}) < number_of_shards ({})".format(
108110
num_examples, number_of_shards))
@@ -306,6 +308,8 @@ def _get_boundaries_per_bucket_shard(self, shard_len_sizes):
306308
examples in the bucket and size is the total size in bytes of the
307309
elements in that bucket. Buckets with no elements are not mentioned.
308310
"""
311+
if not shard_len_sizes:
312+
raise AssertionError("Not a single example present in the PCollection!")
309313
total_num_examples = 0
310314
total_size = 0
311315
bucket2length = {}

tensorflow_datasets/core/tfrecords_writer_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def _read_records(path):
131131

132132
class WriterTest(testing.TestCase):
133133

134+
EMPTY_SPLIT_ERROR = 'No examples were yielded.'
135+
TOO_SMALL_SPLIT_ERROR = 'num_examples (1) < number_of_shards (2)'
136+
134137
@absltest.mock.patch.object(
135138
example_serializer, 'ExampleSerializer', testing.DummySerializer)
136139
def _write(self, to_write, path, salt=''):
@@ -174,9 +177,29 @@ def test_write_duplicated_keys(self):
174177
AssertionError, 'Two records share the same hashed key!'):
175178
self._write(to_write, path)
176179

180+
def test_empty_split(self):
181+
path = os.path.join(self.tmp_dir, 'foo.tfrecord')
182+
to_write = []
183+
with absltest.mock.patch.object(tfrecords_writer, '_get_number_shards',
184+
return_value=1):
185+
with self.assertRaisesWithPredicateMatch(
186+
AssertionError, self.EMPTY_SPLIT_ERROR):
187+
self._write(to_write, path)
188+
189+
def test_too_small_split(self):
190+
path = os.path.join(self.tmp_dir, 'foo.tfrecord')
191+
to_write = [(1, b'a')]
192+
with absltest.mock.patch.object(tfrecords_writer, '_get_number_shards',
193+
return_value=2):
194+
with self.assertRaisesWithPredicateMatch(
195+
AssertionError, self.TOO_SMALL_SPLIT_ERROR):
196+
self._write(to_write, path)
197+
177198

178199
class TfrecordsWriterBeamTest(WriterTest):
179200

201+
EMPTY_SPLIT_ERROR = 'Not a single example present in the PCollection!'
202+
180203
@absltest.mock.patch.object(
181204
example_serializer, 'ExampleSerializer', testing.DummySerializer)
182205
def _write(self, to_write, path, salt=''):

0 commit comments

Comments
 (0)