Skip to content

Commit 5532e55

Browse files
hawkinsp authored and jax authors committed
[XLA:Python] Add a C++ implementation of flatten_one_level.
Also add a copy of the default registry in which None is not registered as a pytree node and is therefore treated as a leaf; this is slightly faster than passing an is_leaf function. This mostly just resolves an old TODO. PiperOrigin-RevId: 617988496
1 parent 05e61ed commit 5532e55

File tree

2 files changed

+118
-64
lines changed

2 files changed

+118
-64
lines changed

jax/_src/pjit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,11 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
363363

364364
user_specified_in_shardings = (in_shardings is not None and
365365
not is_unspecified(in_shardings))
366-
is_none = lambda x: x is None
367-
in_shardings_leaves, in_shardings_treedef = tree_flatten(
368-
in_shardings, is_leaf=is_none)
369-
out_shardings_leaves, out_shardings_treedef = tree_flatten(
370-
out_shardings, is_leaf=is_none)
366+
none_leaf_registry = tree_util.none_leaf_registry
367+
in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten(
368+
in_shardings)
369+
out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten(
370+
out_shardings)
371371

372372
fun_sourceinfo = api_util.fun_sourceinfo(fun)
373373
fun_signature = api_util.fun_signature(fun)

jax/_src/tree_util.py

Lines changed: 113 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from jax._src import traceback_util
2727
from jax._src.lib import pytree
28+
from jax._src.lib import xla_extension_version
2829
from jax._src.util import safe_zip
2930
from jax._src.util import unzip2
3031

@@ -44,6 +45,13 @@
4445
default_registry.__module__ = __name__
4546
default_registry.__name__ = "default_registry"
4647

48+
# A copy of the default registry, except that None is not registered as a node, so None is treated as a leaf.
49+
none_leaf_registry = pytree.PyTreeRegistry(
50+
enable_none=False, enable_tuple=True, enable_namedtuple=True,
51+
enable_list=True, enable_dict=True)
52+
none_leaf_registry.__module__ = __name__
53+
none_leaf_registry.__name__ = "none_leaf_registry"
54+
4755
# A special, internal pytree registry that includes everything in
4856
# `default_registry`, plus internal Python-defined types that we want
4957
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
@@ -242,6 +250,7 @@ def register_pytree_node(nodetype: type[T],
242250
``nodetype``.
243251
"""
244252
default_registry.register_node(nodetype, flatten_func, unflatten_func)
253+
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
245254
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
246255
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
247256

@@ -374,21 +383,9 @@ def tree_transpose(outer_treedef: PyTreeDef, inner_treedef: PyTreeDef | None,
374383

375384
def _replace_nones(sentinel, tree):
376385
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
377-
if tree is None:
378-
return sentinel
379-
else:
380-
handler = _registry.get(type(tree))
381-
if handler:
382-
children, metadata = handler.to_iter(tree)
383-
proc_children = [_replace_nones(sentinel, child) for child in children]
384-
return handler.from_iter(metadata, proc_children)
385-
elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
386-
# handle namedtuple as a special case, based on heuristic
387-
children = iter(tree)
388-
proc_children = [_replace_nones(sentinel, child) for child in children]
389-
return type(tree)(*proc_children)
390-
else:
391-
return tree
386+
leaves, treedef = none_leaf_registry.flatten(tree)
387+
leaves = map(lambda x: sentinel if x is None else x, leaves)
388+
return treedef.unflatten(leaves)
392389

393390

394391
no_initializer = object()
@@ -586,29 +583,50 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
586583
tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
587584
return result
588585

589-
def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]:
590-
"""Flatten the given pytree node by one level.
586+
if xla_extension_version >= 248:
587+
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
588+
"""Flatten the given pytree node by one level.
591589
592-
Args:
593-
pytree: A valid pytree node, either built-in or registered via
594-
``register_pytree_node`` or ``register_pytree_with_keys``.
590+
Args:
591+
pytree: A valid pytree node, either built-in or registered via
592+
``register_pytree_node`` or ``register_pytree_with_keys``.
595593
596-
Returns:
597-
A pair of the pytree's flattened children and its hashable metadata.
594+
Returns:
595+
A pair of the pytree's flattened children and its hashable metadata.
598596
599-
Raises:
600-
ValueError: If the given pytree is not a built-in or registered container
601-
via ``register_pytree_node`` or ``register_pytree_with_keys``.
602-
"""
603-
handler = _registry.get(type(pytree))
604-
if handler:
605-
children, meta = handler.to_iter(pytree)
606-
return list(children), meta
607-
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
608-
# handle namedtuple as a special case, based on heuristic
609-
return [getattr(pytree, s) for s in pytree._fields], None
610-
else:
611-
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
597+
Raises:
598+
ValueError: If the given pytree is not a built-in or registered container
599+
via ``register_pytree_node`` or ``register_pytree_with_keys``.
600+
"""
601+
out = default_registry.flatten_one_level(pytree)
602+
if out is None:
603+
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
604+
else:
605+
return out
606+
else:
607+
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
608+
"""Flatten the given pytree node by one level.
609+
610+
Args:
611+
pytree: A valid pytree node, either built-in or registered via
612+
``register_pytree_node`` or ``register_pytree_with_keys``.
613+
614+
Returns:
615+
A pair of the pytree's flattened children and its hashable metadata.
616+
617+
Raises:
618+
ValueError: If the given pytree is not a built-in or registered container
619+
via ``register_pytree_node`` or ``register_pytree_with_keys``.
620+
"""
621+
handler = _registry.get(type(pytree))
622+
if handler:
623+
children, meta = handler.to_iter(pytree)
624+
return list(children), meta
625+
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
626+
# handle namedtuple as a special case, based on heuristic
627+
return [getattr(pytree, s) for s in pytree._fields], None
628+
else:
629+
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
612630

613631
def prefix_errors(prefix_tree: Any, full_tree: Any,
614632
is_leaf: Callable[[Any], bool] | None = None,
@@ -659,6 +677,8 @@ def _equality_errors(path, t1, t2, is_leaf):
659677
return # no more errors to find
660678
t1_children, t1_meta = flatten_one_level(t1)
661679
t2_children, t2_meta = flatten_one_level(t2)
680+
t1_children = tuple(t1_children)
681+
t2_children = tuple(t2_children)
662682
t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
663683
try:
664684
diff = ' '.join(repr(k.key) for k in
@@ -905,32 +925,64 @@ def generate_key_paths(
905925

906926

907927
# The overall logic should be the same as PyTreeDef::FlattenIntoImpl
908-
def _generate_key_paths_(
909-
key_path: KeyPath,
910-
tree: Any,
911-
is_leaf: Callable[[Any], bool] | None = None,
912-
) -> Iterable[tuple[KeyPath, Any]]:
913-
if is_leaf and is_leaf(tree):
914-
yield key_path, tree
915-
return
916-
key_handler = _registry_with_keypaths.get(type(tree))
917-
handler = _registry.get(type(tree))
918-
if key_handler:
919-
key_children, _ = key_handler.flatten_with_keys(tree)
920-
for k, c in key_children:
921-
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
922-
elif handler:
923-
children, _ = handler.to_iter(tree)
924-
for i, c in enumerate(children):
928+
if xla_extension_version >= 248:
929+
def _generate_key_paths_(
930+
key_path: KeyPath,
931+
tree: Any,
932+
is_leaf: Callable[[Any], bool] | None = None,
933+
) -> Iterable[tuple[KeyPath, Any]]:
934+
if is_leaf and is_leaf(tree):
935+
yield key_path, tree
936+
return
937+
key_handler = _registry_with_keypaths.get(type(tree))
938+
if key_handler:
939+
key_children, _ = key_handler.flatten_with_keys(tree)
940+
for k, c in key_children:
941+
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
942+
return
943+
944+
flat = default_registry.flatten_one_level(tree)
945+
if flat is None:
946+
yield key_path, tree # strict leaf type
947+
return
948+
949+
if (isinstance(tree, tuple) and hasattr(tree, '_fields') and
950+
flat[1] == type(tree)):
951+
# handle namedtuple as a special case, based on heuristic
952+
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
953+
for k, c in key_children:
954+
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
955+
return
956+
957+
for i, c in enumerate(flat[0]):
925958
k = FlattenedIndexKey(i)
926959
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
927-
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
928-
# handle namedtuple as a special case, based on heuristic
929-
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
930-
for k, c in key_children:
931-
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
932-
else:
933-
yield key_path, tree # strict leaf type
960+
else:
961+
def _generate_key_paths_(
962+
key_path: KeyPath,
963+
tree: Any,
964+
is_leaf: Callable[[Any], bool] | None = None,
965+
) -> Iterable[tuple[KeyPath, Any]]:
966+
if is_leaf and is_leaf(tree):
967+
yield key_path, tree
968+
return
969+
key_handler = _registry_with_keypaths.get(type(tree))
970+
if key_handler:
971+
key_children, _ = key_handler.flatten_with_keys(tree)
972+
for k, c in key_children:
973+
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
974+
elif handler := _registry.get(type(tree)):
975+
children, _ = handler.to_iter(tree)
976+
for i, c in enumerate(children):
977+
k = FlattenedIndexKey(i)
978+
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
979+
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
980+
# handle namedtuple as a special case, based on heuristic
981+
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
982+
for k, c in key_children:
983+
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
984+
else:
985+
yield key_path, tree # strict leaf type
934986

935987

936988
def tree_map_with_path(f: Callable[..., Any],
@@ -1001,6 +1053,8 @@ def _prefix_error(
10011053
# point, and because prefix_tree is not a leaf, each can be flattened once:
10021054
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
10031055
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
1056+
prefix_tree_children = tuple(prefix_tree_children)
1057+
full_tree_children = tuple(full_tree_children)
10041058
prefix_tree_keys = _child_keys(prefix_tree)
10051059
full_tree_keys = _child_keys(full_tree)
10061060
# First we check special case types (list and tuple, though if they were

0 commit comments

Comments
 (0)