Skip to content

Commit efcba76

Browse files
Conchylicultorcopybara-github
authored andcommitted
Update parser/serialiser to support nested sequences
PiperOrigin-RevId: 278754939
1 parent c95ad10 commit efcba76

File tree

3 files changed

+442
-9
lines changed

3 files changed

+442
-9
lines changed

tensorflow_datasets/core/example_parser.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,25 @@ def parse_example(self, serialized_example):
6363
example: A nested `dict` of `tf.Tensor` values. The structure and tensors
6464
shape/dtype match the `example_specs` provided at construction.
6565
"""
66+
nested_feature_specs = self._build_feature_specs()
67+
68+
# Because of RaggedTensor specs, feature_specs can be a 2-level nested dict,
69+
# so have to wrap `tf.io.parse_single_example` between
70+
# `flatten_nest_dict`/`pack_as_nest_dict`.
71+
# {
72+
# 'video/image': tf.io.FixedLenSequenceFeature(...),
73+
# 'video/object/bbox': {
74+
# 'ragged_flat_values': tf.io.FixedLenSequenceFeature(...),
75+
# 'ragged_row_lengths_0', tf.io.FixedLenSequenceFeature(...),
76+
# },
77+
# }
78+
flat_feature_specs = utils.flatten_nest_dict(nested_feature_specs)
6679
example = tf.io.parse_single_example(
6780
serialized=serialized_example,
68-
features=self._build_feature_specs(),
81+
features=flat_feature_specs,
6982
)
83+
example = utils.pack_as_nest_dict(example, nested_feature_specs)
84+
7085
example = {
7186
k: _deserialize_single_field(example_data, tensor_info)
7287
for k, (example_data, tensor_info)
@@ -79,9 +94,12 @@ def parse_example(self, serialized_example):
7994

8095
def _deserialize_single_field(example_data, tensor_info):
8196
"""Reconstruct the serialized field."""
97+
# Ragged tensor case:
98+
if tensor_info.sequence_rank > 1:
99+
example_data = _dict_to_ragged(example_data, tensor_info)
82100

83101
# Restore shape if possible. TF Example flattened it.
84-
if tensor_info.shape.count(None) < 2:
102+
elif tensor_info.shape.count(None) < 2:
85103
shape = [-1 if i is None else i for i in tensor_info.shape]
86104
example_data = tf.reshape(example_data, shape)
87105

@@ -91,6 +109,17 @@ def _deserialize_single_field(example_data, tensor_info):
91109
return example_data
92110

93111

112+
def _dict_to_ragged(example_data, tensor_info):
113+
"""Reconstruct the ragged tensor from the row ids."""
114+
return tf.RaggedTensor.from_nested_row_lengths(
115+
flat_values=example_data["ragged_flat_values"],
116+
nested_row_lengths=[
117+
example_data["ragged_row_lengths_{}".format(k)]
118+
for k in range(tensor_info.sequence_rank - 1)
119+
],
120+
)
121+
122+
94123
def _to_tf_example_spec(tensor_info):
95124
"""Convert a `TensorInfo` into a feature proto object."""
96125
# Convert the dtype
@@ -126,6 +155,24 @@ def _to_tf_example_spec(tensor_info):
126155
allow_missing=True,
127156
default_value=tensor_info.default_value,
128157
)
158+
elif tensor_info.sequence_rank > 1: # RaggedTensor
159+
# Decoding here should match encoding from `_add_ragged_fields` in
160+
# `example_serializer.py`
161+
tf_specs = {
162+
"ragged_row_lengths_{}".format(k): tf.io.FixedLenSequenceFeature( # pylint: disable=g-complex-comprehension
163+
shape=(),
164+
dtype=tf.int64,
165+
allow_missing=True,
166+
)
167+
for k in range(tensor_info.sequence_rank - 1)
168+
}
169+
tf_specs["ragged_flat_values"] = tf.io.FixedLenSequenceFeature(
170+
shape=tensor_info.shape[tensor_info.sequence_rank:],
171+
dtype=dtype,
172+
allow_missing=True,
173+
default_value=tensor_info.default_value,
174+
)
175+
return tf_specs
129176
else:
130177
raise NotImplementedError(
131178
"Tensor with a unknown dimension not at the first position not "

tensorflow_datasets/core/example_serializer.py

Lines changed: 187 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import collections
2223
import numpy as np
2324
import six
2425
import tensorflow as tf
2526

2627
from tensorflow_datasets.core import utils
28+
from tensorflow_datasets.core.features import feature as feature_lib
2729

2830

2931
class ExampleSerializer(object):
@@ -63,21 +65,39 @@ def _dict_to_tf_example(example_dict, tensor_info_dict=None):
6365
tensor_info_dict: `dict` of `tfds.feature.TensorInfo` If given, perform
6466
additional checks on the example dict (check dtype, shape, number of
6567
fields...)
68+
69+
Returns:
70+
example_proto: `tf.train.Example`, the encoded example proto.
6671
"""
67-
def serialize_single_field(k, example_data, tensor_info):
72+
def run_with_reraise(fn, k, example_data, tensor_info):
6873
with utils.try_reraise(
6974
"Error while serializing feature {} ({}): ".format(k, tensor_info)):
70-
return _item_to_tf_feature(example_data, tensor_info)
75+
return fn(example_data, tensor_info)
7176

7277
if tensor_info_dict:
73-
example_dict = {
74-
k: serialize_single_field(k, example_data, tensor_info)
78+
# Add the RaggedTensor fields for the nested sequences
79+
# Nested sequences are encoded as {'flat_values':, 'row_lengths':}, so need
80+
# to flatten the example nested dict again.
81+
# Ex:
82+
# Input: {'objects/tokens': [[0, 1, 2], [], [3, 4]]}
83+
# Output: {
84+
# 'objects/tokens/flat_values': [0, 1, 2, 3, 4],
85+
# 'objects/tokens/row_lengths_0': [3, 0, 2],
86+
# }
87+
example_dict = utils.flatten_nest_dict({
88+
k: run_with_reraise(_add_ragged_fields, k, example_data, tensor_info)
7589
for k, (example_data, tensor_info)
7690
in utils.zip_dict(example_dict, tensor_info_dict)
91+
})
92+
example_dict = {
93+
k: run_with_reraise(_item_to_tf_feature, k, item, tensor_info)
94+
for k, (item, tensor_info) in example_dict.items()
7795
}
7896
else:
97+
# TODO(epot): The following code is only executed in tests and could be
98+
# cleanned-up, as TensorInfo is always passed to _item_to_tf_feature.
7999
example_dict = {
80-
k: serialize_single_field(k, example_data, None)
100+
k: run_with_reraise(_item_to_tf_feature, k, example_data, None)
81101
for k, example_data in example_dict.items()
82102
}
83103

@@ -88,18 +108,31 @@ def _is_string(item):
88108
"""Check if the object contains string or bytes."""
89109
if isinstance(item, (six.binary_type, six.string_types)):
90110
return True
91-
elif (isinstance(item, (tuple, list)) and
92-
all(isinstance(x, (six.binary_type, six.string_types)) for x in item)):
111+
elif (isinstance(item, (tuple, list)) and all(_is_string(x) for x in item)):
93112
return True
94113
elif (isinstance(item, np.ndarray) and # binary or unicode
95114
(item.dtype.kind in ("U", "S") or item.dtype == object)):
96115
return True
97116
return False
98117

99118

119+
def _item_to_np_array(item, dtype, shape):
120+
"""Single item to a np.array."""
121+
original_item = item
122+
item = np.array(item, dtype=dtype.as_numpy_dtype)
123+
utils.assert_shape_match(item.shape, shape)
124+
if dtype == tf.string and not _is_string(original_item):
125+
raise ValueError(
126+
"Unsuported value: {}\nCould not convert to bytes list.".format(item))
127+
return item
128+
129+
100130
def _item_to_tf_feature(item, tensor_info=None):
101131
"""Single item to a tf.train.Feature."""
102132
v = item
133+
# TODO(epot): tensor_info is only None for file_format_adapter tests.
134+
# tensor_info could be made required to cleanup some of the following code,
135+
# for instance by re-using _item_to_np_array.
103136
if not tensor_info and isinstance(v, (list, tuple)) and not v:
104137
raise ValueError(
105138
"Received an empty list value, so is unable to infer the "
@@ -146,3 +179,150 @@ def _item_to_tf_feature(item, tensor_info=None):
146179
"This may indicate that one of the FeatureConnectors received an "
147180
"unsupported value as input.".format(repr(v), repr(type(v)))
148181
)
182+
183+
184+
RaggedExtraction = collections.namedtuple("RaggedExtraction", [
185+
"nested_list",
186+
"flat_values",
187+
"nested_row_lengths",
188+
"curr_ragged_rank",
189+
"tensor_info",
190+
])
191+
192+
193+
def _add_ragged_fields(example_data, tensor_info):
194+
"""Optionally convert the ragged data into flat/row_lengths fields.
195+
196+
Example:
197+
198+
```
199+
example_data = [
200+
[1, 2, 3],
201+
[],
202+
[4, 5]
203+
]
204+
tensor_info = TensorInfo(shape=(None, None,), sequence_rank=2, ...)
205+
out = _add_ragged_fields(example_data, tensor_info)
206+
out == {
207+
'ragged_flat_values': ([0, 1, 2, 3, 4, 5], TensorInfo(shape=(), ...)),
208+
'ragged_row_length_0': ([3, 0, 2], TensorInfo(shape=(None,), ...))
209+
}
210+
```
211+
212+
If `example_data` isn't ragged, `example_data` and `tensor_info` are
213+
forwarded as-is.
214+
215+
Args:
216+
example_data: Data to optionally convert to ragged data.
217+
tensor_info: TensorInfo associated with the given data.
218+
219+
Returns:
220+
A tuple(example_data, tensor_info) if the tensor isn't ragged, or a dict of
221+
tuple(example_data, tensor_info) if the tensor is ragged.
222+
"""
223+
# Step 1: Extract the ragged tensor info
224+
if tensor_info.sequence_rank:
225+
# If the input is ragged, extract the nested values.
226+
# 1-level sequences are converted as numpy and stacked.
227+
# If the sequence is empty, a np.empty(shape=(0, ...)) array is returned.
228+
example_data, nested_row_lengths = _extract_ragged_attributes(
229+
example_data, tensor_info)
230+
231+
# Step 2: Format the ragged tensor data as dict
232+
# No sequence or 1-level sequence, forward the data.
233+
# Could eventually handle multi-level sequences with static lengths
234+
# in a smarter way.
235+
if tensor_info.sequence_rank < 2:
236+
return (example_data, tensor_info)
237+
# Multiple level sequence:
238+
else:
239+
tensor_info_length = feature_lib.TensorInfo(shape=(None,), dtype=tf.int64)
240+
ragged_attr_dict = {
241+
"ragged_row_lengths_{}".format(i): (length, tensor_info_length)
242+
for i, length in enumerate(nested_row_lengths)
243+
}
244+
tensor_info_flat = feature_lib.TensorInfo(
245+
shape=(None,) + tensor_info.shape[tensor_info.sequence_rank:],
246+
dtype=tensor_info.dtype,
247+
)
248+
ragged_attr_dict["ragged_flat_values"] = (example_data, tensor_info_flat)
249+
return ragged_attr_dict
250+
251+
252+
def _extract_ragged_attributes(nested_list, tensor_info):
253+
"""Extract the values for the tf.RaggedTensor __init__.
254+
255+
This extract the ragged tensor attributes which allow reconstruct the
256+
ragged tensor with `tf.RaggedTensor.from_nested_row_lengths`.
257+
258+
Args:
259+
nested_list: A nested list containing the ragged tensor values
260+
tensor_info: The specs of the ragged tensor
261+
262+
Returns:
263+
flat_values: The flatten values of the ragged tensor. All values from each
264+
list will be converted to np.array and stacked together.
265+
nested_row_lengths: The row lengths for each ragged dimensions.
266+
"""
267+
assert tensor_info.sequence_rank, "{} is not ragged.".format(tensor_info)
268+
269+
flat_values = []
270+
nested_row_lengths = [[] for _ in range(tensor_info.sequence_rank)]
271+
# Reccursivelly append to `flat_values`, `nested_row_lengths`
272+
_fill_ragged_attribute(RaggedExtraction(
273+
nested_list=nested_list,
274+
flat_values=flat_values,
275+
nested_row_lengths=nested_row_lengths,
276+
curr_ragged_rank=0,
277+
tensor_info=tensor_info,
278+
))
279+
if not flat_values: # The full sequence is empty
280+
flat_values = np.empty(
281+
shape=(0,) + tensor_info.shape[tensor_info.sequence_rank:],
282+
dtype=tensor_info.dtype.as_numpy_dtype,
283+
)
284+
else: # Otherwise, merge all flat values together, some might be empty
285+
flat_values = np.stack(flat_values)
286+
return flat_values, nested_row_lengths[1:]
287+
288+
289+
def _fill_ragged_attribute(ext):
290+
"""Recurse the nested_list from the given RaggedExtraction.
291+
292+
Args:
293+
ext: RaggedExtraction tuple containing the input/outputs
294+
295+
Returns:
296+
None, the function mutate instead `ext.nested_row_lengths` and
297+
`ext.flat_values` lists.
298+
"""
299+
# Register the current sequence length.
300+
# Could be 0 in case of empty list or an np.empty(shape=(0, ...)).
301+
curr_sequence_length = len(ext.nested_list)
302+
ext.nested_row_lengths[ext.curr_ragged_rank].append(curr_sequence_length)
303+
# Sanity check if sequence is static, but should have been catched before
304+
# by `Sequence.encode_example`
305+
expected_sequence_length = ext.tensor_info.shape[ext.curr_ragged_rank]
306+
if (expected_sequence_length is not None and
307+
expected_sequence_length != curr_sequence_length):
308+
raise ValueError(
309+
"Received length {} do not match the expected one {} from {}.".format(
310+
curr_sequence_length, expected_sequence_length, ext.tensor_info))
311+
312+
if ext.curr_ragged_rank < ext.tensor_info.sequence_rank - 1:
313+
# If there are additional Sequence dimension, recurse 1 level deeper.
314+
for sub_list in ext.nested_list:
315+
_fill_ragged_attribute(ext._replace(
316+
nested_list=sub_list,
317+
curr_ragged_rank=ext.curr_ragged_rank + 1,
318+
))
319+
else:
320+
# Otherwise, we reached the max level deep, so add the current items
321+
for item in ext.nested_list:
322+
item = _item_to_np_array( # Normalize the item
323+
item,
324+
dtype=ext.tensor_info.dtype,
325+
# We only check the non-ragged shape
326+
shape=ext.tensor_info.shape[ext.tensor_info.sequence_rank:],
327+
)
328+
ext.flat_values.append(item)

0 commit comments

Comments
 (0)