Skip to content

Commit f759452

Browse files
hawkinspjax authors
authored andcommitted
[XLA:Python] Improve error checking for the return value of the to_iterable function of custom pytree nodes.
PiperOrigin-RevId: 617066587
1 parent 0e1b3e5 commit f759452

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

tests/tree_util_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import functools
1818
import pickle
1919
import re
20+
import unittest
2021

2122
from absl.testing import absltest
2223
from absl.testing import parameterized
@@ -25,6 +26,7 @@
2526
from jax import tree_util
2627
from jax import flatten_util
2728
from jax._src import test_util as jtu
29+
from jax._src.lib import xla_extension_version
2830
from jax._src.tree_util import prefix_errors, flatten_one_level
2931
import jax.numpy as jnp
3032

@@ -42,6 +44,19 @@ class ANamedTupleSubclass(ATuple):
4244
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
4345
lambda bar, foo: ATuple2(foo[0], bar))
4446

47+
BadFlattenNonTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
48+
tree_util.register_pytree_node(BadFlattenNonTuple, lambda o: "hello",
49+
lambda bar, foo: ATuple2(foo[0], bar))
50+
51+
BadFlattenBadArityTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
52+
tree_util.register_pytree_node(BadFlattenBadArityTuple, lambda o: (2, 3, 4),
53+
lambda bar, foo: ATuple2(foo[0], bar))
54+
55+
BadFlattenNonIterableLeaves = collections.namedtuple("ATuple2", ("foo", "bar"))
56+
tree_util.register_pytree_node(BadFlattenNonIterableLeaves, lambda o: (7, 7),
57+
lambda bar, foo: ATuple2(foo[0], bar))
58+
59+
4560
class AnObject:
4661

4762
def __init__(self, x, y, z):
@@ -762,6 +777,37 @@ def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
762777
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
763778
self.assertLen(leaves, 1)
764779

780+
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
781+
def testBadFlattenNonTuple(self):
782+
t = BadFlattenNonTuple(3, 4)
783+
with self.assertRaisesRegex(
784+
ValueError,
785+
"The to_iterable function for a custom PyTree node should return a"
786+
r" \(children, aux_data\) tuple, got 'hello'",
787+
):
788+
tree_util.tree_flatten(t)
789+
790+
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
791+
def testBadFlattenBadArityTuple(self):
792+
t = BadFlattenBadArityTuple(3, 4)
793+
with self.assertRaisesRegex(
794+
ValueError,
795+
"The to_iterable function for a custom PyTree node should return a"
796+
r" \(children, aux_data\) tuple, got \(2, 3, 4\)",
797+
):
798+
tree_util.tree_flatten(t)
799+
800+
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
801+
def testBadFlattenNonIterableLeaves(self):
802+
t = BadFlattenNonIterableLeaves(3, 4)
803+
with self.assertRaisesRegex(
804+
ValueError,
805+
"The to_iterable function for a custom PyTree node should return a"
806+
r" \(children, aux_data\) tuple where 'children' is iterable, got "
807+
r"\(7, 7\)",
808+
):
809+
tree_util.tree_flatten(t)
810+
765811

766812
class StaticTest(parameterized.TestCase):
767813

0 commit comments

Comments
 (0)