Skip to content

Commit da6993c

Browse files
Conchylicultorcopybara-github
authored andcommitted
Add support for nested tfds.features.Sequences
PiperOrigin-RevId: 278916413
1 parent c8439a3 commit da6993c

File tree

7 files changed

+384
-16
lines changed

7 files changed

+384
-16
lines changed

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,5 +489,91 @@ def test_supervised_keys(self):
489489

490490

491491

492+
493+
class NestedSequenceBuilder(dataset_builder.GeneratorBasedBuilder):
494+
"""Dataset containing nested sequences."""
495+
496+
VERSION = utils.Version("0.0.1")
497+
498+
def _info(self):
499+
return dataset_info.DatasetInfo(
500+
builder=self,
501+
features=features.FeaturesDict({
502+
"frames": features.Sequence({
503+
"coordinates": features.Sequence(
504+
features.Tensor(shape=(2,), dtype=tf.int32)
505+
),
506+
}),
507+
}),
508+
)
509+
510+
def _split_generators(self, dl_manager):
511+
# Split the 30 examples from the generator into 2 train shards and 1 test
512+
# shard.
513+
del dl_manager
514+
return [
515+
splits_lib.SplitGenerator(
516+
name=splits_lib.Split.TRAIN,
517+
gen_kwargs={},
518+
),
519+
]
520+
521+
def _generate_examples(self):
522+
ex0 = [
523+
[[0, 1], [2, 3], [4, 5]],
524+
[],
525+
[[6, 7]]
526+
]
527+
ex1 = []
528+
ex2 = [
529+
[[10, 11]],
530+
[[12, 13], [14, 15]],
531+
]
532+
for i, ex in enumerate([ex0, ex1, ex2]):
533+
yield i, {"frames": {"coordinates": ex}}
534+
535+
536+
class NestedSequenceBuilderTest(testing.TestCase):
537+
"""Test of the NestedSequenceBuilder."""
538+
539+
@testing.run_in_graph_and_eager_modes()
540+
def test_nested_sequence(self):
541+
with testing.tmp_dir(self.get_temp_dir()) as tmp_dir:
542+
ds_train, ds_info = registered.load(
543+
name="nested_sequence_builder",
544+
data_dir=tmp_dir,
545+
split="train",
546+
with_info=True,
547+
shuffle_files=False)
548+
ex0, ex1, ex2 = [
549+
ex["frames"]["coordinates"]
550+
for ex in dataset_utils.as_numpy(ds_train)
551+
]
552+
self.assertAllEqual(ex0, tf.ragged.constant([
553+
[[0, 1], [2, 3], [4, 5]],
554+
[],
555+
[[6, 7]],
556+
], inner_shape=(2,)))
557+
self.assertAllEqual(ex1, tf.ragged.constant([], ragged_rank=1))
558+
self.assertAllEqual(ex2, tf.ragged.constant([
559+
[[10, 11]],
560+
[[12, 13], [14, 15]],
561+
], inner_shape=(2,)))
562+
563+
self.assertEqual(
564+
ds_info.features.dtype,
565+
{"frames": {"coordinates": tf.int32}},
566+
)
567+
self.assertEqual(
568+
ds_info.features.shape,
569+
{"frames": {"coordinates": (None, None, 2)}},
570+
)
571+
nested_tensor_info = ds_info.features.get_tensor_info()
572+
self.assertEqual(
573+
nested_tensor_info["frames"]["coordinates"].sequence_rank,
574+
2,
575+
)
576+
577+
492578
if __name__ == "__main__":
493579
testing.test_main()

tensorflow_datasets/core/features/feature.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,25 @@ def decode_batch_example(self, tfexample_data):
331331
name='sequence_decode',
332332
)
333333

334+
def decode_ragged_example(self, tfexample_data):
335+
"""Decode nested features from a tf.RaggedTensor.
336+
337+
This function is used to decode features wrapped in nested
338+
`tfds.features.Sequence()`.
339+
By default, this function apply `decode_batch_example` on the flat values
340+
of the ragged tensor. For optimization, features can
341+
overwrite this method to apply a custom batch decoding.
342+
343+
Args:
344+
tfexample_data: `tf.RaggedTensor` inputs containing the nested encoded
345+
examples.
346+
347+
Returns:
348+
tensor_data: The decoded `tf.RaggedTensor` or dictionary of tensor,
349+
output of the tf.data.Dataset object
350+
"""
351+
return tf.ragged.map_flat_values(self.decode_batch_example, tfexample_data)
352+
334353
def _flatten(self, x):
335354
"""Flatten the input dict into a list of values.
336355
@@ -509,6 +528,11 @@ def decode_batch_example(self, example_data):
509528
# Overwrite the `tf.map_fn`, decoding is a no-op
510529
return self.decode_example(example_data)
511530

531+
def decode_ragged_example(self, example_data):
532+
"""See base class for details."""
533+
# Overwrite the `tf.map_fn`, decoding is a no-op
534+
return self.decode_example(example_data)
535+
512536
def encode_example(self, example_data):
513537
"""See base class for details."""
514538
np_dtype = np.dtype(self.dtype.as_numpy_dtype)

tensorflow_datasets/core/features/sequence_feature.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -144,15 +144,20 @@ def _build_empty_np(serialized_info):
144144
for sequence_elem in sequence_elements
145145
]
146146

147-
# Then merge the elements back together
147+
# Then convert back list[nested dict] => nested dict[list]
148148
def _stack_nested(sequence_elements):
149+
"""Recursivelly stack the tensors from the same dict field."""
149150
if isinstance(sequence_elements[0], dict):
150151
return {
151152
# Stack along the first dimension
152153
k: _stack_nested(sub_sequence)
153154
for k, sub_sequence in utils.zip_dict(*sequence_elements)
154155
}
155-
return stack_arrays(*sequence_elements)
156+
# Note: As each field can be a nested ragged list, we don't check here
157+
# that all elements from the list have matching dtype/shape.
158+
# Checking is done in `example_serializer` when elements
159+
# are converted to numpy array and stacked togethers.
160+
return list(sequence_elements)
156161

157162
return _stack_nested(sequence_elements)
158163

@@ -203,14 +208,7 @@ def __repr__(self):
203208
return '{}({})'.format(type(self).__name__, inner_feature_repr)
204209

205210

206-
def stack_arrays(*elems):
207-
if isinstance(elems[0], np.ndarray):
208-
return np.stack(elems)
209-
else:
210-
return [e for e in elems]
211-
212-
213-
def np_to_list(elem):
211+
def _np_to_list(elem):
214212
"""Returns list from list, tuple or ndarray."""
215213
if isinstance(elem, list):
216214
return elem
@@ -227,7 +225,7 @@ def np_to_list(elem):
227225
def _transpose_dict_list(dict_list):
228226
"""Transpose a nested dict[list] into a list[nested dict]."""
229227
# 1. Unstack numpy arrays into list
230-
dict_list = utils.map_nested(np_to_list, dict_list, dict_only=True)
228+
dict_list = utils.map_nested(_np_to_list, dict_list, dict_only=True)
231229

232230
# 2. Extract the sequence length (and ensure the length is constant for all
233231
# elements)

0 commit comments

Comments
 (0)