@@ -131,6 +131,9 @@ def _read_records(path):
131
131
132
132
class WriterTest (testing .TestCase ):
133
133
134
+ EMPTY_SPLIT_ERROR = 'No examples were yielded.'
135
+ TOO_SMALL_SPLIT_ERROR = 'num_examples (1) < number_of_shards (2)'
136
+
134
137
@absltest .mock .patch .object (
135
138
example_serializer , 'ExampleSerializer' , testing .DummySerializer )
136
139
def _write (self , to_write , path , salt = '' ):
@@ -174,9 +177,29 @@ def test_write_duplicated_keys(self):
174
177
AssertionError , 'Two records share the same hashed key!' ):
175
178
self ._write (to_write , path )
176
179
180
def test_empty_split(self):
  """Writing zero examples must fail with the empty-split error.

  `_get_number_shards` is patched to return 1 for the duration of the write,
  and the AssertionError raised by `self._write` is expected to match
  `self.EMPTY_SPLIT_ERROR` (a class attribute, so subclasses may substitute
  their own message).
  """
  no_examples = []
  target = os.path.join(self.tmp_dir, 'foo.tfrecord')
  # Build both context managers first; they are entered left to right below,
  # so the patch is active before the failure expectation starts.
  one_shard = absltest.mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=1)
  expect_failure = self.assertRaisesWithPredicateMatch(
      AssertionError, self.EMPTY_SPLIT_ERROR)
  with one_shard, expect_failure:
    self._write(no_examples, target)
188
+
189
def test_too_small_split(self):
  """Writing fewer examples than shards must fail with the too-small error.

  A single `(1, b'a')` record is written while `_get_number_shards` is
  patched to return 2; the AssertionError raised by `self._write` is
  expected to match `self.TOO_SMALL_SPLIT_ERROR`.
  """
  single_example = [(1, b'a')]
  target = os.path.join(self.tmp_dir, 'foo.tfrecord')
  # Build both context managers first; they are entered left to right below,
  # so the patch is active before the failure expectation starts.
  two_shards = absltest.mock.patch.object(
      tfrecords_writer, '_get_number_shards', return_value=2)
  expect_failure = self.assertRaisesWithPredicateMatch(
      AssertionError, self.TOO_SMALL_SPLIT_ERROR)
  with two_shards, expect_failure:
    self._write(single_example, target)
197
+
177
198
178
199
class TfrecordsWriterBeamTest (WriterTest ):
179
200
201
+ EMPTY_SPLIT_ERROR = 'Not a single example present in the PCollection!'
202
+
180
203
@absltest .mock .patch .object (
181
204
example_serializer , 'ExampleSerializer' , testing .DummySerializer )
182
205
def _write (self , to_write , path , salt = '' ):
0 commit comments