Skip to content

Commit e61510d

Browse files
authored
Improve tf.RaggedTensor support in DataAdapters. (#21170)
Previously, only 2D TensorFlow ragged tensors were supported. This adds support for ragged tensors of any rank. It also adds tests for ragged tensors with `GeneratorDataAdapter`.
1 parent 44a655b commit e61510d

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

keras/src/trainers/data_adapters/data_adapter_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,12 @@ def get_single_tensor_spec(*tensors):
168168

169169
dtype = backend.standardize_dtype(x.dtype)
170170
if isinstance(x, tf.RaggedTensor):
171-
return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
171+
return tf.RaggedTensorSpec(
172+
shape=shape,
173+
dtype=dtype,
174+
ragged_rank=x.ragged_rank,
175+
row_splits_dtype=x.row_splits.dtype,
176+
)
172177
if (
173178
isinstance(x, tf.SparseTensor)
174179
or is_scipy_sparse(x)

keras/src/trainers/data_adapters/generator_data_adapter_test.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def generator():
166166
not backend.SUPPORTS_SPARSE_TENSORS,
167167
reason="Backend does not support sparse tensors",
168168
)
169-
def test_scipy_sparse_tensors(self, generator_type):
169+
def test_sparse_tensors(self, generator_type):
170170
if generator_type == "tf":
171171
x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4))
172172
y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2))
@@ -197,3 +197,33 @@ def generate():
197197
self.assertIsInstance(by, expected_class)
198198
self.assertEqual(bx.shape, (2, 4))
199199
self.assertEqual(by.shape, (2, 2))
200+
201+
@pytest.mark.skipif(
202+
not backend.SUPPORTS_RAGGED_TENSORS,
203+
reason="Backend does not support ragged tensors",
204+
)
205+
def test_ragged_tensors(self):
206+
x = tf.ragged.constant(
207+
[[[0.0, 1.0]], [[2.0, 3.0], [4.0, 5.0]]], ragged_rank=1
208+
)
209+
y = tf.ragged.constant(
210+
[[[0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], ragged_rank=1
211+
)
212+
213+
def generate():
214+
for _ in range(4):
215+
yield x, y
216+
217+
adapter = generator_data_adapter.GeneratorDataAdapter(generate())
218+
219+
if backend.backend() == "tensorflow":
220+
it = adapter.get_tf_dataset()
221+
expected_class = tf.RaggedTensor
222+
223+
for batch in it:
224+
self.assertEqual(len(batch), 2)
225+
bx, by = batch
226+
self.assertIsInstance(bx, expected_class)
227+
self.assertIsInstance(by, expected_class)
228+
self.assertEqual(bx.shape, (2, None, 2))
229+
self.assertEqual(by.shape, (2, None, 2))

0 commit comments

Comments
 (0)