|
25 | 25 |
|
26 | 26 | from jax._src import traceback_util
|
27 | 27 | from jax._src.lib import pytree
|
| 28 | +from jax._src.lib import xla_extension_version |
28 | 29 | from jax._src.util import safe_zip
|
29 | 30 | from jax._src.util import unzip2
|
30 | 31 |
|
|
44 | 45 | default_registry.__module__ = __name__
|
45 | 46 | default_registry.__name__ = "default_registry"
|
46 | 47 |
|
| 48 | +# A copy of the default registry, where None is a leaf. |
| 49 | +none_leaf_registry = pytree.PyTreeRegistry( |
| 50 | + enable_none=False, enable_tuple=True, enable_namedtuple=True, |
| 51 | + enable_list=True, enable_dict=True) |
| 52 | +none_leaf_registry.__module__ = __name__ |
| 53 | +none_leaf_registry.__name__ = "none_leaf_registry" |
| 54 | + |
47 | 55 | # A special, internal pytree registry that includes everything in
|
48 | 56 | # `default_registry`, plus internal Python-defined types that we want
|
49 | 57 | # to teach the fast dispatch path ("C++ dispatch") how to flatten and
|
@@ -242,6 +250,7 @@ def register_pytree_node(nodetype: type[T],
|
242 | 250 | ``nodetype``.
|
243 | 251 | """
|
244 | 252 | default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
| 253 | + none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func) |
245 | 254 | dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
|
246 | 255 | _registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
247 | 256 |
|
@@ -374,21 +383,9 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
|
374 | 383 |
|
375 | 384 | def _replace_nones(sentinel, tree):
|
376 | 385 | """Replaces ``None`` in ``tree`` with ``sentinel``."""
|
377 |
| - if tree is None: |
378 |
| - return sentinel |
379 |
| - else: |
380 |
| - handler = _registry.get(type(tree)) |
381 |
| - if handler: |
382 |
| - children, metadata = handler.to_iter(tree) |
383 |
| - proc_children = [_replace_nones(sentinel, child) for child in children] |
384 |
| - return handler.from_iter(metadata, proc_children) |
385 |
| - elif isinstance(tree, tuple) and hasattr(tree, "_fields"): |
386 |
| - # handle namedtuple as a special case, based on heuristic |
387 |
| - children = iter(tree) |
388 |
| - proc_children = [_replace_nones(sentinel, child) for child in children] |
389 |
| - return type(tree)(*proc_children) |
390 |
| - else: |
391 |
| - return tree |
| 386 | + leaves, treedef = none_leaf_registry.flatten(tree) |
| 387 | + leaves = map(lambda x: sentinel if x is None else x, leaves) |
| 388 | + return treedef.unflatten(leaves) |
392 | 389 |
|
393 | 390 |
|
394 | 391 | no_initializer = object()
|
@@ -586,29 +583,50 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
586 | 583 | tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
|
587 | 584 | return result
|
588 | 585 |
|
589 |
| -def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]: |
590 |
| - """Flatten the given pytree node by one level. |
| 586 | +if xla_extension_version >= 248: |
| 587 | + def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: |
| 588 | + """Flatten the given pytree node by one level. |
591 | 589 |
|
592 |
| - Args: |
593 |
| - pytree: A valid pytree node, either built-in or registered via |
594 |
| - ``register_pytree_node`` or ``register_pytree_with_keys``. |
| 590 | + Args: |
| 591 | + pytree: A valid pytree node, either built-in or registered via |
| 592 | + ``register_pytree_node`` or ``register_pytree_with_keys``. |
595 | 593 |
|
596 |
| - Returns: |
597 |
| - A pair of the pytree's flattened children and its hashable metadata. |
| 594 | + Returns: |
| 595 | + A pair of the pytree's flattened children and its hashable metadata. |
598 | 596 |
|
599 |
| - Raises: |
600 |
| - ValueError: If the given pytree is not a built-in or registered container |
601 |
| - via ``register_pytree_node`` or ``register_pytree_with_keys``. |
602 |
| - """ |
603 |
| - handler = _registry.get(type(pytree)) |
604 |
| - if handler: |
605 |
| - children, meta = handler.to_iter(pytree) |
606 |
| - return list(children), meta |
607 |
| - elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'): |
608 |
| - # handle namedtuple as a special case, based on heuristic |
609 |
| - return [getattr(pytree, s) for s in pytree._fields], None |
610 |
| - else: |
611 |
| - raise ValueError(f"can't tree-flatten type: {type(pytree)}") |
| 597 | + Raises: |
| 598 | + ValueError: If the given pytree is not a built-in or registered container |
| 599 | + via ``register_pytree_node`` or ``register_pytree_with_keys``. |
| 600 | + """ |
| 601 | + out = default_registry.flatten_one_level(pytree) |
| 602 | + if out is None: |
| 603 | + raise ValueError(f"can't tree-flatten type: {type(pytree)}") |
| 604 | + else: |
| 605 | + return out |
| 606 | +else: |
| 607 | + def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]: |
| 608 | + """Flatten the given pytree node by one level. |
| 609 | +
|
| 610 | + Args: |
| 611 | + pytree: A valid pytree node, either built-in or registered via |
| 612 | + ``register_pytree_node`` or ``register_pytree_with_keys``. |
| 613 | +
|
| 614 | + Returns: |
| 615 | + A pair of the pytree's flattened children and its hashable metadata. |
| 616 | +
|
| 617 | + Raises: |
| 618 | + ValueError: If the given pytree is not a built-in or registered container |
| 619 | + via ``register_pytree_node`` or ``register_pytree_with_keys``. |
| 620 | + """ |
| 621 | + handler = _registry.get(type(pytree)) |
| 622 | + if handler: |
| 623 | + children, meta = handler.to_iter(pytree) |
| 624 | + return list(children), meta |
| 625 | + elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'): |
| 626 | + # handle namedtuple as a special case, based on heuristic |
| 627 | + return [getattr(pytree, s) for s in pytree._fields], None |
| 628 | + else: |
| 629 | + raise ValueError(f"can't tree-flatten type: {type(pytree)}") |
612 | 630 |
|
613 | 631 | def prefix_errors(prefix_tree: Any, full_tree: Any,
|
614 | 632 | is_leaf: Callable[[Any], bool] | None = None,
|
@@ -659,6 +677,8 @@ def _equality_errors(path, t1, t2, is_leaf):
|
659 | 677 | return # no more errors to find
|
660 | 678 | t1_children, t1_meta = flatten_one_level(t1)
|
661 | 679 | t2_children, t2_meta = flatten_one_level(t2)
|
| 680 | + t1_children = tuple(t1_children) |
| 681 | + t2_children = tuple(t2_children) |
662 | 682 | t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
|
663 | 683 | try:
|
664 | 684 | diff = ' '.join(repr(k.key) for k in
|
@@ -905,32 +925,64 @@ def generate_key_paths(
|
905 | 925 |
|
906 | 926 |
|
907 | 927 | # The overall logic should be same as PyTreeDef::FlattenIntoImpl
|
908 |
| -def _generate_key_paths_( |
909 |
| - key_path: KeyPath, |
910 |
| - tree: Any, |
911 |
| - is_leaf: Callable[[Any], bool] | None = None, |
912 |
| -) -> Iterable[tuple[KeyPath, Any]]: |
913 |
| - if is_leaf and is_leaf(tree): |
914 |
| - yield key_path, tree |
915 |
| - return |
916 |
| - key_handler = _registry_with_keypaths.get(type(tree)) |
917 |
| - handler = _registry.get(type(tree)) |
918 |
| - if key_handler: |
919 |
| - key_children, _ = key_handler.flatten_with_keys(tree) |
920 |
| - for k, c in key_children: |
921 |
| - yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
922 |
| - elif handler: |
923 |
| - children, _ = handler.to_iter(tree) |
924 |
| - for i, c in enumerate(children): |
| 928 | +if xla_extension_version >= 248: |
| 929 | + def _generate_key_paths_( |
| 930 | + key_path: KeyPath, |
| 931 | + tree: Any, |
| 932 | + is_leaf: Callable[[Any], bool] | None = None, |
| 933 | + ) -> Iterable[tuple[KeyPath, Any]]: |
| 934 | + if is_leaf and is_leaf(tree): |
| 935 | + yield key_path, tree |
| 936 | + return |
| 937 | + key_handler = _registry_with_keypaths.get(type(tree)) |
| 938 | + if key_handler: |
| 939 | + key_children, _ = key_handler.flatten_with_keys(tree) |
| 940 | + for k, c in key_children: |
| 941 | + yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
| 942 | + return |
| 943 | + |
| 944 | + flat = default_registry.flatten_one_level(tree) |
| 945 | + if flat is None: |
| 946 | + yield key_path, tree # strict leaf type |
| 947 | + return |
| 948 | + |
| 949 | + if (isinstance(tree, tuple) and hasattr(tree, '_fields') and |
| 950 | + flat[1] == type(tree)): |
| 951 | + # handle namedtuple as a special case, based on heuristic |
| 952 | + key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] |
| 953 | + for k, c in key_children: |
| 954 | + yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
| 955 | + return |
| 956 | + |
| 957 | + for i, c in enumerate(flat[0]): |
925 | 958 | k = FlattenedIndexKey(i)
|
926 | 959 | yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
927 |
| - elif isinstance(tree, tuple) and hasattr(tree, '_fields'): |
928 |
| - # handle namedtuple as a special case, based on heuristic |
929 |
| - key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] |
930 |
| - for k, c in key_children: |
931 |
| - yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
932 |
| - else: |
933 |
| - yield key_path, tree # strict leaf type |
| 960 | +else: |
| 961 | + def _generate_key_paths_( |
| 962 | + key_path: KeyPath, |
| 963 | + tree: Any, |
| 964 | + is_leaf: Callable[[Any], bool] | None = None, |
| 965 | + ) -> Iterable[tuple[KeyPath, Any]]: |
| 966 | + if is_leaf and is_leaf(tree): |
| 967 | + yield key_path, tree |
| 968 | + return |
| 969 | + key_handler = _registry_with_keypaths.get(type(tree)) |
| 970 | + if key_handler: |
| 971 | + key_children, _ = key_handler.flatten_with_keys(tree) |
| 972 | + for k, c in key_children: |
| 973 | + yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
| 974 | + elif handler := _registry.get(type(tree)): |
| 975 | + children, _ = handler.to_iter(tree) |
| 976 | + for i, c in enumerate(children): |
| 977 | + k = FlattenedIndexKey(i) |
| 978 | + yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
| 979 | + elif isinstance(tree, tuple) and hasattr(tree, '_fields'): |
| 980 | + # handle namedtuple as a special case, based on heuristic |
| 981 | + key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields] |
| 982 | + for k, c in key_children: |
| 983 | + yield from _generate_key_paths_((*key_path, k), c, is_leaf) |
| 984 | + else: |
| 985 | + yield key_path, tree # strict leaf type |
934 | 986 |
|
935 | 987 |
|
936 | 988 | def tree_map_with_path(f: Callable[..., Any],
|
@@ -1001,6 +1053,8 @@ def _prefix_error(
|
1001 | 1053 | # point, and because prefix_tree is not a leaf, each can be flattened once:
|
1002 | 1054 | prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
|
1003 | 1055 | full_tree_children, full_tree_meta = flatten_one_level(full_tree)
|
| 1056 | + prefix_tree_children = tuple(prefix_tree_children) |
| 1057 | + full_tree_children = tuple(full_tree_children) |
1004 | 1058 | prefix_tree_keys = _child_keys(prefix_tree)
|
1005 | 1059 | full_tree_keys = _child_keys(full_tree)
|
1006 | 1060 | # First we check special case types (list and tuple, though if they were
|
|
0 commit comments