Skip to content

Commit 5ab8036

Browse files
Conchylicultor authored and copybara-github committed
Add test support for ragged tensor
PiperOrigin-RevId: 257251800
1 parent 37eb94b commit 5ab8036

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# coding=utf-8
2+
# Copyright 2019 The TensorFlow Datasets Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Test utils for tensorflow RaggedTensors.
17+
18+
Copied from the tensorflow/python/ops/ragged/ragged_test_util.py
19+
20+
TODO(epot): Delete this with the next TF public release.
21+
"""
22+
23+
from __future__ import absolute_import
24+
from __future__ import division
25+
from __future__ import print_function
26+
27+
import numpy as np
28+
29+
import tensorflow as tf
30+
31+
32+
class RaggedTensorTestCase(tf.test.TestCase):
  """Base class for RaggedTensor test cases.

  Provides assertions that compare values which may be `tf.RaggedTensor`,
  `tf.ragged.RaggedTensorValue`, `tf.Tensor`, `np.ndarray` or plain Python
  values, by first converting both sides to nested Python lists.
  """

  def _GetPyList(self, a):
    """Converts a to a nested python list.

    Args:
      a: A `tf.RaggedTensor`, `tf.Tensor`, `np.ndarray`,
        `tf.ragged.RaggedTensorValue`, or any value accepted by `np.array`.

    Returns:
      The content of `a` converted with `to_list()` / `tolist()`. When
      evaluating a `tf.Tensor` does not yield an `np.ndarray`, the evaluated
      value is returned unchanged.
    """
    if isinstance(a, tf.RaggedTensor):
      return self.evaluate(a).to_list()
    elif isinstance(a, tf.Tensor):
      a = self.evaluate(a)
      return a.tolist() if isinstance(a, np.ndarray) else a
    elif isinstance(a, np.ndarray):
      return a.tolist()
    elif isinstance(a, tf.ragged.RaggedTensorValue):
      return a.to_list()
    else:
      return np.array(a).tolist()

  def _assert_ragged_rank_equal(self, a, b):
    """Asserts that `a` and `b` have the same ragged rank.

    The check is skipped when either argument is a Python list or tuple.
    Values for which `is_ragged` is false count as ragged rank 0.

    Args:
      a: First value that was compared.
      b: Second value that was compared.
    """
    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
      return
    a_ragged_rank = a.ragged_rank if is_ragged(a) else 0
    b_ragged_rank = b.ragged_rank if is_ragged(b) else 0
    self.assertEqual(a_ragged_rank, b_ragged_rank)

  def assertRaggedEqual(self, a, b):
    """Asserts that two potentially ragged tensors are equal.

    Both values are converted to nested Python lists and compared exactly.
    Unless one of them is a list or tuple, their ragged ranks must match too.

    Args:
      a: First value, in any form accepted by `_GetPyList`.
      b: Second value, in any form accepted by `_GetPyList`.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertEqual(a_list, b_list)
    self._assert_ragged_rank_equal(a, b)

  def assertRaggedAlmostEqual(self, a, b, places=7):
    """Asserts that two potentially ragged tensors are almost equal.

    Same as `assertRaggedEqual`, except that leaf values are compared with
    `assertAlmostEqual`.

    Args:
      a: First value, in any form accepted by `_GetPyList`.
      b: Second value, in any form accepted by `_GetPyList`.
      places: Number of places forwarded to `assertAlmostEqual`.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertNestedListAlmostEqual(a_list, b_list, places, context='value')
    self._assert_ragged_rank_equal(a, b)

  def assertNestedListAlmostEqual(self, a, b, places=7, context='value'):
    """Asserts that two nested lists have the same shape and close leaves.

    Args:
      a: First value; a (possibly nested) list/tuple or a leaf value.
      b: Second value; must have the same type as `a` at every level.
      places: Number of places forwarded to `assertAlmostEqual`.
      context: Path of the current element (e.g. `value[0][2]`), used in the
        failure messages.
    """
    self.assertEqual(type(a), type(b))
    if isinstance(a, (list, tuple)):
      self.assertLen(a, len(b), 'Length differs for %s' % context)
      # Lengths are equal past this point, so zip() drops nothing.
      for i, (a_item, b_item) in enumerate(zip(a, b)):
        self.assertNestedListAlmostEqual(
            a_item, b_item, places, '%s[%s]' % (context, i))
    else:
      self.assertAlmostEqual(
          a, b, places,
          '%s != %s within %s places at %s' % (a, b, places, context))

  def eval_to_list(self, tensor):
    """Evaluates `tensor` and converts the result to a Python list.

    Args:
      tensor: Value passed to `self.evaluate`.

    Returns:
      `to_list()` of the result if it is ragged, `tolist()` if it is an
      `np.ndarray`, otherwise the evaluated value unchanged.
    """
    value = self.evaluate(tensor)
    if is_ragged(value):
      return value.to_list()
    elif isinstance(value, np.ndarray):
      return value.tolist()
    else:
      return value

  def _eval_tensor(self, tensor):
    """Evaluates `tensor`, handling ragged values component-wise.

    For ragged input, `values` and `row_splits` are evaluated recursively and
    wrapped in a `tf.ragged.RaggedTensorValue`. Anything else is delegated to
    `tf.test.TestCase._eval_tensor`.
    """
    if is_ragged(tensor):
      return tf.ragged.RaggedTensorValue(
          self._eval_tensor(tensor.values),
          self._eval_tensor(tensor.row_splits))
    else:
      return tf.test.TestCase._eval_tensor(self, tensor)

  @staticmethod
  def _normalize_pylist(item):
    """Convert all (possibly nested) np.arrays contained in item to list.

    Args:
      item: A scalar, an `np.ndarray`, or a (possibly nested and ragged)
        list/tuple that may contain `np.ndarray`s.

    Returns:
      `item` unchanged if it is zero-dimensional; otherwise a list in which
      every contained `np.ndarray` has been replaced by its `tolist()`.
    """
    # Lists and tuples are always a nested level. Testing for them first
    # avoids `np.ndim` building an array out of ragged rows, which raises
    # ValueError on recent NumPy versions.
    if not isinstance(item, (list, tuple)) and np.ndim(item) == 0:
      return item
    # convert np.arrays in current level to list
    level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
    _normalize = RaggedTensorTestCase._normalize_pylist  # pylint: disable=invalid-name
    # Zero-dimensional elements are returned as-is by the recursive call.
    return [_normalize(el) for el in level]
108+
109+
110+
def is_ragged(value):
  """Tells whether `value` is a `RaggedTensor` or a `RaggedTensorValue`.

  Args:
    value: Any object.

  Returns:
    True if `value` is an instance of `tf.RaggedTensor` or
    `tf.ragged.RaggedTensorValue`, False otherwise.
  """
  ragged_types = (tf.RaggedTensor, tf.ragged.RaggedTensorValue)
  return isinstance(value, ragged_types)

tensorflow_datasets/testing/test_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from tensorflow_datasets.core import file_format_adapter
3838
from tensorflow_datasets.core import splits
3939
from tensorflow_datasets.core import utils
40+
from tensorflow_datasets.testing import ragged_test_util
4041
from tensorflow_datasets.testing import test_case
4142

4243

@@ -84,7 +85,7 @@ def __init__(
8485
self.raise_msg = raise_msg
8586

8687

87-
class SubTestCase(test_case.TestCase):
88+
class SubTestCase(ragged_test_util.RaggedTensorTestCase, test_case.TestCase):
8889
"""Adds subTest() context manager to the TestCase if supported.
8990
9091
Note: To use this feature, make sure you call super() in setUpClass to
@@ -108,6 +109,18 @@ def _subTest(self, test_str):
108109
yield
109110
self._sub_test_stack.pop()
110111

112+
def assertAllEqual(self, d1, d2, msg=None):
  """Same as assertAllEqual but with RaggedTensor support.

  Args:
    d1: First value to compare.
    d2: Second value to compare.
    msg: Optional failure message, forwarded to the parent `assertAllEqual`
      so that callers passing `msg` keep working with this override. It is
      not used when either input is a `tf.RaggedTensor`, because
      `assertRaggedEqual` takes no message.

  Returns:
    The result of the underlying assertion call.
  """
  # TODO(epot): This function as well as RaggedTensorTestCase could be
  # removed once tf.test.TestCase support RaggedTensor
  inputs = (d1, d2)
  if any(isinstance(d, tf.RaggedTensor) for d in inputs):
    d1, d2 = [  # Required to support list of np.array
        d if isinstance(d, tf.RaggedTensor) else tf.ragged.constant(d)
        for d in inputs
    ]
    return self.assertRaggedEqual(d1, d2)
  return super(SubTestCase, self).assertAllEqual(d1, d2, msg=msg)
123+
111124
def assertAllEqualNested(self, d1, d2):
112125
"""Same as assertAllEqual but compatible with nested dict."""
113126
if isinstance(d1, dict):

0 commit comments

Comments
 (0)