Skip to content

Commit c95ad10

Browse files
Conchylicultorcopybara-github
authored andcommitted
Add tf.RaggedTensor support for tfds.as_numpy
PiperOrigin-RevId: 278750600
1 parent ec93f31 commit c95ad10

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

tensorflow_datasets/core/dataset_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _build_ds_from_instruction(instruction, ds_from_file_fn):
159159
def _eager_dataset_iterator(dataset):
160160
for item in dataset:
161161
flat = tf.nest.flatten(item)
162-
flat = [el.numpy() for el in flat]
162+
flat = [t if isinstance(t, tf.RaggedTensor) else t.numpy() for t in flat]
163163
yield tf.nest.pack_sequence_as(item, flat)
164164

165165

@@ -184,6 +184,13 @@ def as_numpy(dataset, graph=None):
184184
`as_numpy` converts a possibly nested structure of `tf.data.Dataset`s
185185
and `tf.Tensor`s to iterables of NumPy arrays and NumPy arrays, respectively.
186186
187+
Note that because TensorFlow has support for ragged tensors and NumPy has
188+
no equivalent representation,
189+
[`tf.RaggedTensor`s](https://www.tensorflow.org/api_docs/python/tf/RaggedTensor)
190+
are left as-is for the user to deal with them (e.g. using `to_list()`).
191+
In TF 1 (i.e. graph mode), `tf.RaggedTensor`s are returned as
192+
`tf.ragged.RaggedTensorValue`s.
193+
187194
Args:
188195
dataset: a possibly nested structure of `tf.data.Dataset`s and/or
189196
`tf.Tensor`s.
@@ -204,7 +211,9 @@ def as_numpy(dataset, graph=None):
204211
for ds_el in flat_ds:
205212
types = [type(el) for el in flat_ds]
206213
types = tf.nest.pack_sequence_as(nested_ds, types)
207-
if not (isinstance(ds_el, tf.Tensor) or tf_compat.is_dataset(ds_el)):
214+
if not (
215+
isinstance(ds_el, (tf.Tensor, tf.RaggedTensor)) or
216+
tf_compat.is_dataset(ds_el)):
208217
raise ValueError("Arguments to as_numpy must be tf.Tensors or "
209218
"tf.data.Datasets. Got: %s" % types)
210219

@@ -213,6 +222,8 @@ def as_numpy(dataset, graph=None):
213222
for ds_el in flat_ds:
214223
if isinstance(ds_el, tf.Tensor):
215224
np_el = ds_el.numpy()
225+
elif isinstance(ds_el, tf.RaggedTensor):
226+
np_el = ds_el
216227
elif tf_compat.is_dataset(ds_el):
217228
np_el = _eager_dataset_iterator(ds_el)
218229
else:

tensorflow_datasets/core/dataset_utils_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,50 @@ def test_tensors_match(self):
135135
# and target may not match
136136
self.assertAllEqual(ds["a"], ds["b"])
137137

138+
@testing.run_in_graph_and_eager_modes()
139+
def test_ragged_tensors(self):
140+
rt = tf.ragged.constant([
141+
[1, 2, 3],
142+
[],
143+
[4, 5],
144+
])
145+
rt = dataset_utils.as_numpy(rt)
146+
147+
if not tf.executing_eagerly():
148+
# Output of `sess.run(rt)` is a `RaggedTensorValue` object
149+
self.assertIsInstance(rt, tf.compat.v1.ragged.RaggedTensorValue)
150+
else:
151+
self.assertIsInstance(rt, tf.RaggedTensor)
152+
153+
self.assertAllEqual(rt, tf.ragged.constant([
154+
[1, 2, 3],
155+
[],
156+
[4, 5],
157+
]))
158+
159+
@testing.run_in_graph_and_eager_modes()
160+
def test_ragged_tensors_ds(self):
161+
def _gen_ragged_tensors():
162+
# Yield the (flat_values, rowids)
163+
yield ([0, 1, 2, 3], [0, 0, 0, 2]) # ex0
164+
yield ([], []) # ex1
165+
yield ([4, 5, 6], [0, 1, 1]) # ex2
166+
ds = tf.data.Dataset.from_generator(
167+
_gen_ragged_tensors,
168+
output_types=(tf.int64, tf.int64),
169+
output_shapes=((None,), (None,))
170+
)
171+
ds = ds.map(tf.RaggedTensor.from_value_rowids)
172+
173+
rt0, rt1, rt2 = list(dataset_utils.as_numpy(ds))
174+
self.assertAllEqual(rt0, [
175+
[0, 1, 2],
176+
[],
177+
[3,],
178+
])
179+
self.assertAllEqual(rt1, [])
180+
self.assertAllEqual(rt2, [[4], [5, 6]])
181+
138182

139183
class DatasetOffsetTest(testing.TestCase):
140184
"""Test that the offset functions are working properly."""

0 commit comments

Comments
 (0)