What is the purpose of jax.tree_util.tree_all? #7169
Why was `tree_all` added? Its definition is:

    def tree_all(tree):
        return all(tree_leaves(tree))

Its only use in the JAX library is in `jax.test_util`, where it seems to be superfluous:

    def check_eq(xs, ys, err_msg=''):
        assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
        tree_all(tree_multimap(assert_close, xs, ys))

could be just

    def check_eq(xs, ys, err_msg=''):
        assert_close = partial(_assert_numpy_allclose, err_msg=err_msg)
        tree_multimap(assert_close, xs, ys)

Unlike most of the other JAX functions (

Perhaps a better definition would be:

    def tree_all(tree: Any) -> bool | BooleanArray:
        return tree_reduce(jnp.logical_and, tree, True)

With this definition, a reasonable definition of `tree_allclose` would be:

    def tree_allclose(x: T, y: T) -> bool | BooleanArray:
        return tree_all(tree_multimap(jnp.allclose, x, y))
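For anyone who wants to try the proposed definitions, here is a minimal sketch. It rests on a few assumptions: a recent JAX release where `tree_map` accepts several trees (so it stands in for `tree_multimap`), type annotations dropped for brevity, and made-up example trees `a`, `b`, `c`. The practical difference from the built-in `all` is that `all` calls `bool()` on every leaf, which fails on values traced by `jax.jit`, while a reduction with `jnp.logical_and` stays traceable.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_reduce


def tree_all(tree):
    # Combine all leaves with logical_and; the result may be a traced boolean array.
    return tree_reduce(jnp.logical_and, tree, True)


def tree_allclose(x, y):
    # tree_map over two trees applies jnp.allclose to each pair of leaves
    # (assumption: it replaces the older tree_multimap used in the post).
    return tree_all(tree_map(jnp.allclose, x, y))


# Hypothetical example trees with matching structure.
a = {"w": jnp.ones(3), "b": (jnp.zeros(2), jnp.arange(4.0))}
b = tree_map(lambda leaf: leaf + 0.0, a)  # same values as a
c = tree_map(lambda leaf: leaf + 1.0, a)  # every leaf shifted by 1

# Both calls run under jit, where the leaves are tracers.
same = jax.jit(tree_allclose)(a, b)
different = jax.jit(tree_allclose)(a, c)
```

Here `same` and `different` come back as zero-dimensional boolean arrays, so `bool(same)` can be used once you are outside of `jit`.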
Replies: 1 comment 2 replies
`tree_all` has been around since the initial commit of JAX on GitHub: a30e858

Folks like @froystig & @mattjj may remember more context on this, but it appears it has always been basically a convenience routine used in testing. You're probably right that it might be cleaner to remove it from the public API at this point, since it's so rarely used and so easily implemented.