17
17
import functools
18
18
import pickle
19
19
import re
20
+ import unittest
20
21
21
22
from absl .testing import absltest
22
23
from absl .testing import parameterized
25
26
from jax import tree_util
26
27
from jax import flatten_util
27
28
from jax ._src import test_util as jtu
29
+ from jax ._src .lib import xla_extension_version
28
30
from jax ._src .tree_util import prefix_errors , flatten_one_level
29
31
import jax .numpy as jnp
30
32
@@ -42,6 +44,19 @@ class ANamedTupleSubclass(ATuple):
42
44
tree_util .register_pytree_node (ATuple2 , lambda o : ((o .foo ,), o .bar ),
43
45
lambda bar , foo : ATuple2 (foo [0 ], bar ))
44
46
47
+ BadFlattenNonTuple = collections .namedtuple ("ATuple2" , ("foo" , "bar" ))
48
+ tree_util .register_pytree_node (BadFlattenNonTuple , lambda o : "hello" ,
49
+ lambda bar , foo : ATuple2 (foo [0 ], bar ))
50
+
51
+ BadFlattenBadArityTuple = collections .namedtuple ("ATuple2" , ("foo" , "bar" ))
52
+ tree_util .register_pytree_node (BadFlattenBadArityTuple , lambda o : (2 , 3 , 4 ),
53
+ lambda bar , foo : ATuple2 (foo [0 ], bar ))
54
+
55
+ BadFlattenNonIterableLeaves = collections .namedtuple ("ATuple2" , ("foo" , "bar" ))
56
+ tree_util .register_pytree_node (BadFlattenNonIterableLeaves , lambda o : (7 , 7 ),
57
+ lambda bar , foo : ATuple2 (foo [0 ], bar ))
58
+
59
+
45
60
class AnObject :
46
61
47
62
def __init__ (self , x , y , z ):
@@ -762,6 +777,37 @@ def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
762
777
leaves , _ = tree_util .tree_flatten_with_path (ATuple2 (1 , 'hi' ))
763
778
self .assertLen (leaves , 1 )
764
779
780
+ @unittest .skipIf (xla_extension_version < 247 , "Requires jaxlib>=0.4.26" )
781
+ def testBadFlattenNonTuple (self ):
782
+ t = BadFlattenNonTuple (3 , 4 )
783
+ with self .assertRaisesRegex (
784
+ ValueError ,
785
+ "The to_iterable function for a custom PyTree node should return a"
786
+ r" \(children, aux_data\) tuple, got 'hello'" ,
787
+ ):
788
+ tree_util .tree_flatten (t )
789
+
790
+ @unittest .skipIf (xla_extension_version < 247 , "Requires jaxlib>=0.4.26" )
791
+ def testBadFlattenBadArityTuple (self ):
792
+ t = BadFlattenBadArityTuple (3 , 4 )
793
+ with self .assertRaisesRegex (
794
+ ValueError ,
795
+ "The to_iterable function for a custom PyTree node should return a"
796
+ r" \(children, aux_data\) tuple, got \(2, 3, 4\)" ,
797
+ ):
798
+ tree_util .tree_flatten (t )
799
+
800
+ @unittest .skipIf (xla_extension_version < 247 , "Requires jaxlib>=0.4.26" )
801
+ def testBadFlattenNonIterableLeaves (self ):
802
+ t = BadFlattenNonIterableLeaves (3 , 4 )
803
+ with self .assertRaisesRegex (
804
+ ValueError ,
805
+ "The to_iterable function for a custom PyTree node should return a"
806
+ r" \(children, aux_data\) tuple where 'children' is iterable, got "
807
+ r"\(7, 7\)" ,
808
+ ):
809
+ tree_util .tree_flatten (t )
810
+
765
811
766
812
class StaticTest (parameterized .TestCase ):
767
813
0 commit comments