Skip to content

Commit fdbee31

Browse files
hawkinspjax authors
authored andcommitted
Make JAX tests that check for errors from dict key comparators in pytrees more relaxed, in preparation for openxla/xla#9529.
PiperOrigin-RevId: 610819296
1 parent e00149c commit fdbee31

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

tests/api_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1329,7 +1329,9 @@ class E(enum.Enum):
13291329
def f(d) -> float:
13301330
return d[E.A]
13311331

1332-
with self.assertRaisesRegex(TypeError, "'<' not supported.*"):
1332+
with self.assertRaisesRegex(
1333+
(TypeError, ValueError),
1334+
"('<' not supported|Comparator raised exception).*"):
13331335
f({E.A: 1.0, E.B: 2.0})
13341336

13351337
def test_jit_static_argnums_requires_type_equality(self):

tests/tree_util_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,9 @@ def testPickleRoundTrip(self, tree):
582582

583583
def testDictKeysSortable(self):
584584
d = {"a": 1, 2: "b"}
585-
with self.assertRaisesRegex(TypeError, "'<' not supported"):
585+
with self.assertRaisesRegex(
586+
(TypeError, ValueError),
587+
"('<' not supported|Comparator raised exception).*"):
586588
_, _ = tree_util.tree_flatten(d)
587589

588590
def testFlattenDictKeyOrder(self):

0 commit comments

Comments
 (0)