19
19
from __future__ import division
20
20
from __future__ import print_function
21
21
22
+ import collections
22
23
import numpy as np
23
24
import six
24
25
import tensorflow as tf
25
26
26
27
from tensorflow_datasets .core import utils
28
+ from tensorflow_datasets .core .features import feature as feature_lib
27
29
28
30
29
31
class ExampleSerializer (object ):
@@ -63,21 +65,39 @@ def _dict_to_tf_example(example_dict, tensor_info_dict=None):
63
65
tensor_info_dict: `dict` of `tfds.feature.TensorInfo` If given, perform
64
66
additional checks on the example dict (check dtype, shape, number of
65
67
fields...)
68
+
69
+ Returns:
70
+ example_proto: `tf.train.Example`, the encoded example proto.
66
71
"""
67
- def serialize_single_field ( k , example_data , tensor_info ):
72
+ def run_with_reraise ( fn , k , example_data , tensor_info ):
68
73
with utils .try_reraise (
69
74
"Error while serializing feature {} ({}): " .format (k , tensor_info )):
70
- return _item_to_tf_feature (example_data , tensor_info )
75
+ return fn (example_data , tensor_info )
71
76
72
77
if tensor_info_dict :
73
- example_dict = {
74
- k : serialize_single_field (k , example_data , tensor_info )
78
+ # Add the RaggedTensor fields for the nested sequences
79
+ # Nested sequences are encoded as {'flat_values':, 'row_lengths':}, so need
80
+ # to flatten the example nested dict again.
81
+ # Ex:
82
+ # Input: {'objects/tokens': [[0, 1, 2], [], [3, 4]]}
83
+ # Output: {
84
+ # 'objects/tokens/flat_values': [0, 1, 2, 3, 4],
85
+ # 'objects/tokens/row_lengths_0': [3, 0, 2],
86
+ # }
87
+ example_dict = utils .flatten_nest_dict ({
88
+ k : run_with_reraise (_add_ragged_fields , k , example_data , tensor_info )
75
89
for k , (example_data , tensor_info )
76
90
in utils .zip_dict (example_dict , tensor_info_dict )
91
+ })
92
+ example_dict = {
93
+ k : run_with_reraise (_item_to_tf_feature , k , item , tensor_info )
94
+ for k , (item , tensor_info ) in example_dict .items ()
77
95
}
78
96
else :
97
+ # TODO(epot): The following code is only executed in tests and could be
98
+ # cleanned-up, as TensorInfo is always passed to _item_to_tf_feature.
79
99
example_dict = {
80
- k : serialize_single_field ( k , example_data , None )
100
+ k : run_with_reraise ( _item_to_tf_feature , k , example_data , None )
81
101
for k , example_data in example_dict .items ()
82
102
}
83
103
@@ -88,18 +108,31 @@ def _is_string(item):
88
108
"""Check if the object contains string or bytes."""
89
109
if isinstance (item , (six .binary_type , six .string_types )):
90
110
return True
91
- elif (isinstance (item , (tuple , list )) and
92
- all (isinstance (x , (six .binary_type , six .string_types )) for x in item )):
111
+ elif (isinstance (item , (tuple , list )) and all (_is_string (x ) for x in item )):
93
112
return True
94
113
elif (isinstance (item , np .ndarray ) and # binary or unicode
95
114
(item .dtype .kind in ("U" , "S" ) or item .dtype == object )):
96
115
return True
97
116
return False
98
117
99
118
119
def _item_to_np_array(item, dtype, shape):
  """Converts a single item to a `np.array` of the requested dtype and shape.

  Args:
    item: The value to convert (anything accepted by `np.array`).
    dtype: Target dtype. Its `as_numpy_dtype` attribute is used for the
      conversion and it is compared against `tf.string`.
    shape: Expected shape of the converted array, verified through
      `utils.assert_shape_match`.

  Returns:
    The converted `np.array`.

  Raises:
    ValueError: If `dtype` is `tf.string` but the input is not recognized as
      string/bytes by `_is_string`.
  """
  np_item = np.array(item, dtype=dtype.as_numpy_dtype)
  utils.assert_shape_match(np_item.shape, shape)
  # The string check is performed on the raw input rather than on the
  # converted array; the error message reports the converted array.
  is_invalid_str = dtype == tf.string and not _is_string(item)
  if is_invalid_str:
    raise ValueError(
        "Unsuported value: {}\n Could not convert to bytes list.".format(
            np_item))
  return np_item
128
+
129
+
100
130
def _item_to_tf_feature (item , tensor_info = None ):
101
131
"""Single item to a tf.train.Feature."""
102
132
v = item
133
+ # TODO(epot): tensor_info is only None for file_format_adapter tests.
134
+ # tensor_info could be made required to cleanup some of the following code,
135
+ # for instance by re-using _item_to_np_array.
103
136
if not tensor_info and isinstance (v , (list , tuple )) and not v :
104
137
raise ValueError (
105
138
"Received an empty list value, so is unable to infer the "
@@ -146,3 +179,150 @@ def _item_to_tf_feature(item, tensor_info=None):
146
179
"This may indicate that one of the FeatureConnectors received an "
147
180
"unsupported value as input." .format (repr (v ), repr (type (v )))
148
181
)
182
+
183
+
184
# State threaded through the recursion which extracts the ragged attributes
# (see `_extract_ragged_attributes` and `_fill_ragged_attribute`).
RaggedExtraction = collections.namedtuple(
    "RaggedExtraction",
    (
        "nested_list",  # (Sub-)list currently being processed.
        "flat_values",  # Output list, receives the normalized leaf items.
        "nested_row_lengths",  # Output, one list of lengths per level.
        "curr_ragged_rank",  # Depth of `nested_list` (0 = outermost).
        "tensor_info",  # Spec of the tensor (shape, dtype, sequence_rank).
    ),
)
191
+
192
+
193
+ def _add_ragged_fields (example_data , tensor_info ):
194
+ """Optionally convert the ragged data into flat/row_lengths fields.
195
+
196
+ Example:
197
+
198
+ ```
199
+ example_data = [
200
+ [1, 2, 3],
201
+ [],
202
+ [4, 5]
203
+ ]
204
+ tensor_info = TensorInfo(shape=(None, None,), sequence_rank=2, ...)
205
+ out = _add_ragged_fields(example_data, tensor_info)
206
+ out == {
207
+ 'ragged_flat_values': ([0, 1, 2, 3, 4, 5], TensorInfo(shape=(), ...)),
208
+ 'ragged_row_length_0': ([3, 0, 2], TensorInfo(shape=(None,), ...))
209
+ }
210
+ ```
211
+
212
+ If `example_data` isn't ragged, `example_data` and `tensor_info` are
213
+ forwarded as-is.
214
+
215
+ Args:
216
+ example_data: Data to optionally convert to ragged data.
217
+ tensor_info: TensorInfo associated with the given data.
218
+
219
+ Returns:
220
+ A tuple(example_data, tensor_info) if the tensor isn't ragged, or a dict of
221
+ tuple(example_data, tensor_info) if the tensor is ragged.
222
+ """
223
+ # Step 1: Extract the ragged tensor info
224
+ if tensor_info .sequence_rank :
225
+ # If the input is ragged, extract the nested values.
226
+ # 1-level sequences are converted as numpy and stacked.
227
+ # If the sequence is empty, a np.empty(shape=(0, ...)) array is returned.
228
+ example_data , nested_row_lengths = _extract_ragged_attributes (
229
+ example_data , tensor_info )
230
+
231
+ # Step 2: Format the ragged tensor data as dict
232
+ # No sequence or 1-level sequence, forward the data.
233
+ # Could eventually handle multi-level sequences with static lengths
234
+ # in a smarter way.
235
+ if tensor_info .sequence_rank < 2 :
236
+ return (example_data , tensor_info )
237
+ # Multiple level sequence:
238
+ else :
239
+ tensor_info_length = feature_lib .TensorInfo (shape = (None ,), dtype = tf .int64 )
240
+ ragged_attr_dict = {
241
+ "ragged_row_lengths_{}" .format (i ): (length , tensor_info_length )
242
+ for i , length in enumerate (nested_row_lengths )
243
+ }
244
+ tensor_info_flat = feature_lib .TensorInfo (
245
+ shape = (None ,) + tensor_info .shape [tensor_info .sequence_rank :],
246
+ dtype = tensor_info .dtype ,
247
+ )
248
+ ragged_attr_dict ["ragged_flat_values" ] = (example_data , tensor_info_flat )
249
+ return ragged_attr_dict
250
+
251
+
252
def _extract_ragged_attributes(nested_list, tensor_info):
  """Extracts the values needed to build a `tf.RaggedTensor`.

  The returned attributes allow to reconstruct the ragged tensor with
  `tf.RaggedTensor.from_nested_row_lengths`.

  Args:
    nested_list: A nested list containing the ragged tensor values.
    tensor_info: The specs of the ragged tensor.

  Returns:
    flat_values: The flattened values of the ragged tensor, as a single
      `np.array`. Every leaf item is converted to `np.array` and all of them
      are stacked together. When there is no leaf item at all, an empty array
      with a leading dimension of 0 is returned instead.
    nested_row_lengths: The row lengths of each ragged dimension, except the
      outermost one.
  """
  rank = tensor_info.sequence_rank
  assert rank, "{} is not ragged.".format(tensor_info)

  # Both lists are filled in place by the recursion below.
  leaves = []
  row_lengths = [[] for _ in range(rank)]
  extraction = RaggedExtraction(
      nested_list=nested_list,
      flat_values=leaves,
      nested_row_lengths=row_lengths,
      curr_ragged_rank=0,
      tensor_info=tensor_info,
  )
  _fill_ragged_attribute(extraction)

  if leaves:
    # Merge all the leaf items into a single array.
    stacked = np.stack(leaves)
  else:
    # The full sequence is empty: build an empty array which still carries
    # the non-ragged dimensions and the dtype.
    stacked = np.empty(
        shape=(0,) + tensor_info.shape[rank:],
        dtype=tensor_info.dtype.as_numpy_dtype,
    )
  # The length of the outermost level is not part of the returned lengths.
  return stacked, row_lengths[1:]
287
+
288
+
289
+ def _fill_ragged_attribute (ext ):
290
+ """Recurse the nested_list from the given RaggedExtraction.
291
+
292
+ Args:
293
+ ext: RaggedExtraction tuple containing the input/outputs
294
+
295
+ Returns:
296
+ None, the function mutate instead `ext.nested_row_lengths` and
297
+ `ext.flat_values` lists.
298
+ """
299
+ # Register the current sequence length.
300
+ # Could be 0 in case of empty list or an np.empty(shape=(0, ...)).
301
+ curr_sequence_length = len (ext .nested_list )
302
+ ext .nested_row_lengths [ext .curr_ragged_rank ].append (curr_sequence_length )
303
+ # Sanity check if sequence is static, but should have been catched before
304
+ # by `Sequence.encode_example`
305
+ expected_sequence_length = ext .tensor_info .shape [ext .curr_ragged_rank ]
306
+ if (expected_sequence_length is not None and
307
+ expected_sequence_length != curr_sequence_length ):
308
+ raise ValueError (
309
+ "Received length {} do not match the expected one {} from {}." .format (
310
+ curr_sequence_length , expected_sequence_length , ext .tensor_info ))
311
+
312
+ if ext .curr_ragged_rank < ext .tensor_info .sequence_rank - 1 :
313
+ # If there are additional Sequence dimension, recurse 1 level deeper.
314
+ for sub_list in ext .nested_list :
315
+ _fill_ragged_attribute (ext ._replace (
316
+ nested_list = sub_list ,
317
+ curr_ragged_rank = ext .curr_ragged_rank + 1 ,
318
+ ))
319
+ else :
320
+ # Otherwise, we reached the max level deep, so add the current items
321
+ for item in ext .nested_list :
322
+ item = _item_to_np_array ( # Normalize the item
323
+ item ,
324
+ dtype = ext .tensor_info .dtype ,
325
+ # We only check the non-ragged shape
326
+ shape = ext .tensor_info .shape [ext .tensor_info .sequence_rank :],
327
+ )
328
+ ext .flat_values .append (item )
0 commit comments