Skip to content

Commit d29acba

Browse files
author
jax authors
committed
Merge pull request #19958 from jakevdp:jax-tree
PiperOrigin-RevId: 610533158
2 parents 57e34e1 + cddee46 commit d29acba

24 files changed

+87
-102
lines changed

tests/api_test.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def f(x):
10421042
self.assertEqual(
10431043
obj.in_avals,
10441044
((core.ShapedArray([], expected_dtype, weak_type=True),), {}))
1045-
self.assertEqual(obj.in_tree, jax.tree_util.tree_flatten(((0,), {}))[1])
1045+
self.assertEqual(obj.in_tree, jax.tree.flatten(((0,), {}))[1])
10461046

10471047
def test_jit_lower_duck_typing(self):
10481048
f_jit = jit(lambda x: 2 * x)
@@ -2490,7 +2490,7 @@ def fun(x, y):
24902490
x = (jnp.ones(2), jnp.ones(2))
24912491
y = 3.
24922492
out_shape = api.eval_shape(fun, x, y)
2493-
out_shape = tree_util.tree_map(np.shape, out_shape)
2493+
out_shape = jax.tree.map(np.shape, out_shape)
24942494

24952495
self.assertEqual(out_shape, {'hi': (2,)})
24962496

@@ -3004,7 +3004,7 @@ def test_vmap_in_axes_tree_prefix_error(self):
30043004
ValueError,
30053005
"vmap in_axes specification must be a tree prefix of the corresponding "
30063006
r"value, got specification \(\[0\],\) for value tree "
3007-
+ re.escape(f"{tree_util.tree_structure((value_tree,))}."),
3007+
+ re.escape(f"{jax.tree.structure((value_tree,))}."),
30083008
lambda: api.vmap(lambda x: x, in_axes=([0],))(value_tree)
30093009
)
30103010

@@ -7013,8 +7013,8 @@ def foo_jvp(primals, tangents):
70137013
"must produce primal and tangent outputs "
70147014
"with equal container (pytree) structures, but got "
70157015
"{} and {} respectively.".format(
7016-
tree_util.tree_structure((1,)),
7017-
tree_util.tree_structure([1, 2]))
7016+
jax.tree.structure((1,)),
7017+
jax.tree.structure([1, 2]))
70187018
),
70197019
lambda: api.jvp(f, (2.,), (1.,)))
70207020

@@ -7729,9 +7729,9 @@ def _unpack(x):
77297729

77307730
def _vmap(fun):
77317731
def _fun(*args):
7732-
args = tree_util.tree_map(_pack, args)
7732+
args = jax.tree.map(_pack, args)
77337733
out = jax.vmap(fun)(*args)
7734-
out = tree_util.tree_map(_unpack, out)
7734+
out = jax.tree.map(_unpack, out)
77357735
return out
77367736
return _fun
77377737

@@ -8242,8 +8242,8 @@ def foo_bwd(_, g):
82428242
"and in particular must produce a tuple of length equal to the "
82438243
"number of arguments to the primal function, but got VJP output "
82448244
"structure {} for primal input structure {}.".format(
8245-
tree_util.tree_structure((1, 1)),
8246-
tree_util.tree_structure((1,)))
8245+
jax.tree.structure((1, 1)),
8246+
jax.tree.structure((1,)))
82478247
),
82488248
lambda: api.grad(f)(2.))
82498249

@@ -9017,9 +9017,9 @@ def _unpack(x):
90179017

90189018
def _vmap(fun):
90199019
def _fun(*args):
9020-
args = tree_util.tree_map(_pack, args)
9020+
args = jax.tree.map(_pack, args)
90219021
out = jax.vmap(fun)(*args)
9022-
out = tree_util.tree_map(_unpack, out)
9022+
out = jax.tree.map(_unpack, out)
90239023
return out
90249024
return _fun
90259025

@@ -9281,7 +9281,7 @@ def custom_transpose(example_out):
92819281
return _custom_transpose(out_type, example_out)
92829282
return partial(
92839283
_custom_transpose,
9284-
tree_util.tree_map(
9284+
jax.tree.map(
92859285
lambda x: core.get_aval(x).at_least_vspace(), example_out))
92869286

92879287

@@ -10139,21 +10139,21 @@ def rule(axis_size, in_batched, xs):
1013910139
f_jvp, in_axes=(None, 0), out_axes=(None, 0))(x, txs))
1014010140

1014110141
def test_tree(self):
10142-
tree_sin = partial(tree_util.tree_map, jnp.sin)
10143-
tree_cos = partial(tree_util.tree_map, jnp.cos)
10142+
tree_sin = partial(jax.tree.map, jnp.sin)
10143+
tree_cos = partial(jax.tree.map, jnp.cos)
1014410144

1014510145
x, xs = jnp.array(1.), jnp.arange(3)
1014610146
x = (x, [x + 1, x + 2], [x + 3], x + 4)
1014710147
xs = (xs, [xs + 1, xs + 2], [xs + 3], xs + 4)
10148-
in_batched_ref = tree_util.tree_map(lambda _: True, x)
10148+
in_batched_ref = jax.tree.map(lambda _: True, x)
1014910149

1015010150
@jax.custom_batching.custom_vmap
1015110151
def f(xs): return tree_sin(xs)
1015210152

1015310153
@f.def_vmap
1015410154
def rule(axis_size, in_batched, xs):
1015510155
self.assertEqual(in_batched, [in_batched_ref])
10156-
sz, = {z.shape[0] for z in tree_util.tree_leaves(xs)}
10156+
sz, = {z.shape[0] for z in jax.tree.leaves(xs)}
1015710157
self.assertEqual(axis_size, sz)
1015810158
return tree_cos(xs), in_batched[0]
1015910159

@@ -10163,21 +10163,21 @@ def rule(axis_size, in_batched, xs):
1016310163
self.assertAllClose(ys, tree_cos(xs))
1016410164

1016510165
def test_tree_with_nones(self):
10166-
tree_sin = partial(tree_util.tree_map, jnp.sin)
10167-
tree_cos = partial(tree_util.tree_map, jnp.cos)
10166+
tree_sin = partial(jax.tree.map, jnp.sin)
10167+
tree_cos = partial(jax.tree.map, jnp.cos)
1016810168

1016910169
x, xs = jnp.array(1.), jnp.arange(3)
1017010170
x = (x, [x + 1, None], [x + 3], None)
1017110171
xs = (xs, [xs + 1, None], [xs + 3], None)
10172-
in_batched_ref = tree_util.tree_map(lambda _: True, x)
10172+
in_batched_ref = jax.tree.map(lambda _: True, x)
1017310173

1017410174
@jax.custom_batching.custom_vmap
1017510175
def f(xs): return tree_sin(xs)
1017610176

1017710177
@f.def_vmap
1017810178
def rule(axis_size, in_batched, xs):
1017910179
self.assertEqual(in_batched, [in_batched_ref])
10180-
sz, = {z.shape[0] for z in tree_util.tree_leaves(xs)}
10180+
sz, = {z.shape[0] for z in jax.tree.leaves(xs)}
1018110181
self.assertEqual(axis_size, sz)
1018210182
return tree_cos(xs), in_batched[0]
1018310183

tests/core_test.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from jax import jvp, linearize, vjp, jit, make_jaxpr
2929
from jax.api_util import flatten_fun_nokwargs
3030
from jax import config
31-
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce
3231

3332
from jax._src import core
3433
from jax._src import linear_util as lu
@@ -49,17 +48,17 @@ def call(f, *args):
4948

5049
@util.curry
5150
def core_call(f, *args):
52-
args, in_tree = tree_flatten(args)
51+
args, in_tree = jax.tree.flatten(args)
5352
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
5453
out = core.call_p.bind(f, *args)
55-
return tree_unflatten(out_tree(), out)
54+
return jax.tree.unflatten(out_tree(), out)
5655

5756
@util.curry
5857
def core_closed_call(f, *args):
59-
args, in_tree = tree_flatten(args)
58+
args, in_tree = jax.tree.flatten(args)
6059
f, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
6160
out = core.closed_call_p.bind(f, *args)
62-
return tree_unflatten(out_tree(), out)
61+
return jax.tree.unflatten(out_tree(), out)
6362

6463
def simple_fun(x, y):
6564
return jnp.sin(x * y)
@@ -175,24 +174,24 @@ def test_tree_map(self):
175174
zs = ({'a': 11}, [22, 33])
176175

177176
f = lambda x, y: x + y
178-
assert tree_map(f, xs, ys) == zs
177+
assert jax.tree.map(f, xs, ys) == zs
179178
try:
180-
tree_map(f, xs, ys_bad)
179+
jax.tree.map(f, xs, ys_bad)
181180
assert False
182181
except (TypeError, ValueError):
183182
pass
184183

185184
def test_tree_flatten(self):
186-
flat, _ = tree_flatten(({'a': 1}, [2, 3], 4))
185+
flat, _ = jax.tree.flatten(({'a': 1}, [2, 3], 4))
187186
assert flat == [1, 2, 3, 4]
188187

189188
def test_tree_unflatten(self):
190189
tree = [(1, 2), {"roy": (3, [4, 5, ()])}]
191-
flat, treedef = tree_flatten(tree)
190+
flat, treedef = jax.tree.flatten(tree)
192191
assert flat == [1, 2, 3, 4, 5]
193-
tree2 = tree_unflatten(treedef, flat)
194-
nodes_equal = tree_map(operator.eq, tree, tree2)
195-
assert tree_reduce(operator.and_, nodes_equal)
192+
tree2 = jax.tree.unflatten(treedef, flat)
193+
nodes_equal = jax.tree.map(operator.eq, tree, tree2)
194+
assert jax.tree.reduce(operator.and_, nodes_equal)
196195

197196
@jtu.sample_product(
198197
dtype=[*jtu.dtypes.all, object, [('i', 'i4'), ('f', 'f4')]]

tests/custom_linear_solve_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from jax import lax
2626
from jax.ad_checkpoint import checkpoint
2727
from jax._src import test_util as jtu
28-
from jax import tree_util
2928
import jax.numpy as jnp # scan tests use numpy
3029
import jax.scipy as jsp
3130

@@ -160,7 +159,7 @@ def linear_solve_aux(a, b):
160159
# vmap test
161160
c = rng.randn(3, 2)
162161
expected = jnp.linalg.solve(a, c)
163-
expected_aux = tree_util.tree_map(partial(np.repeat, repeats=2), array_aux)
162+
expected_aux = jax.tree.map(partial(np.repeat, repeats=2), array_aux)
164163
actual_vmap, vmap_aux = jax.vmap(linear_solve_aux, (None, 1), -1)(a, c)
165164

166165
self.assertAllClose(expected, actual_vmap)
@@ -473,7 +472,7 @@ def solve(mv, b):
473472
return mv(b), aux
474473

475474
def solve_aux(x):
476-
matvec = lambda y: tree_util.tree_map(partial(jnp.dot, A), y)
475+
matvec = lambda y: jax.tree.map(partial(jnp.dot, A), y)
477476
return lax.custom_linear_solve(matvec, (x, x), solve, solve, symmetric=True, has_aux=True)
478477

479478
rng = self.rng()

tests/custom_root_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import jax
2323
from jax import lax
2424
from jax._src import test_util as jtu
25-
from jax import tree_util
2625
import jax.numpy as jnp # scan tests use numpy
2726
import jax.scipy as jsp
2827

@@ -227,7 +226,7 @@ def pos_def_solve(g, b):
227226
expected_fwd_val = expected_fwd(a, b)
228227
self.assertAllClose(fwd_val, expected_fwd_val, rtol={np.float32: 5E-6, np.float64: 5E-12})
229228

230-
jtu.check_close(fwd_aux, tree_util.tree_map(jnp.zeros_like, fwd_aux))
229+
jtu.check_close(fwd_aux, jax.tree.map(jnp.zeros_like, fwd_aux))
231230

232231
def test_custom_root_errors(self):
233232
with self.assertRaisesRegex(TypeError, re.escape("f() output pytree")):

tests/export_harnesses_multi_platform_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def export_and_compare_to_native(
155155
if device.platform in skip_run_on_platforms:
156156
logging.info("Skipping running on %s", device)
157157
continue
158-
device_args = jax.tree_util.tree_map(
158+
device_args = jax.tree.map(
159159
lambda x: jax.device_put(x, device), args
160160
)
161161
logging.info("Running harness natively on %s", device)

tests/export_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import jax
2626
from jax import lax
2727
from jax import numpy as jnp
28-
from jax import tree_util
2928
from jax.experimental import export
3029
from jax.experimental.export import _export
3130
from jax.experimental import pjit
@@ -185,7 +184,7 @@ def my_fun(x):
185184
self.assertEqual("my_fun", exp.fun_name)
186185
self.assertEqual((export.default_lowering_platform(),),
187186
exp.lowering_platforms)
188-
self.assertEqual(tree_util.tree_flatten(((1,), {}))[1], exp.in_tree)
187+
self.assertEqual(jax.tree.flatten(((1,), {}))[1], exp.in_tree)
189188
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.in_avals)
190189
self.assertEqual((core.ShapedArray((4,), dtype=np.float32),), exp.out_avals)
191190

@@ -201,9 +200,9 @@ def f(a_b_pair, *, a, b):
201200
self.assertEqual(exp.lowering_platforms, ("cpu",))
202201
args = ((a, b),)
203202
kwargs = dict(a=a, b=b)
204-
self.assertEqual(exp.in_tree, tree_util.tree_flatten((args, kwargs))[1])
203+
self.assertEqual(exp.in_tree, jax.tree.flatten((args, kwargs))[1])
205204
self.assertEqual(exp.in_avals, (a_aval, b_aval, a_aval, b_aval))
206-
self.assertEqual(exp.out_tree, tree_util.tree_flatten(f(*args, **kwargs))[1])
205+
self.assertEqual(exp.out_tree, jax.tree.flatten(f(*args, **kwargs))[1])
207206
self.assertEqual(exp.out_avals, (a_aval, b_aval, a_aval, b_aval, a_aval, b_aval))
208207

209208
def test_basic(self):

tests/filecheck/jax_filecheck_helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,15 @@
1515
# Helpers for writing JAX filecheck tests.
1616

1717
import jax
18-
import jax.tree_util as tree_util
1918
import numpy as np
2019

2120
def print_ir(*prototypes):
2221
def lower(f):
2322
"""Prints the MLIR IR that results from lowering `f`.
2423
2524
The arguments to `f` are taken to be arrays shaped like `prototypes`."""
26-
inputs = tree_util.tree_map(np.array, prototypes)
27-
flat_inputs, _ = tree_util.tree_flatten(inputs)
25+
inputs = jax.tree.map(np.array, prototypes)
26+
flat_inputs, _ = jax.tree.flatten(inputs)
2827
shape_strs = " ".join([f"{x.dtype.name}[{','.join(map(str, x.shape))}]"
2928
for x in flat_inputs])
3029
name = f.func.__name__ if hasattr(f, "func") else f.__name__

tests/host_callback_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from jax import dtypes
3535
from jax import lax
3636
from jax import numpy as jnp
37-
from jax import tree_util
3837
from jax.experimental import host_callback as hcb
3938
from jax.experimental import pjit
4039
from jax.sharding import PartitionSpec as P
@@ -887,7 +886,7 @@ def func(x):
887886
# making the Jaxpr does not print anything
888887
hcb.barrier_wait()
889888

890-
treedef = tree_util.tree_structure(arg)
889+
treedef = jax.tree.structure(arg)
891890
assertMultiLineStrippedEqual(
892891
self, f"""
893892
{{ lambda ; a:f32[]. let
@@ -1027,7 +1026,7 @@ def make_ct(res):
10271026
return res
10281027
ct_dtype = core.primal_dtype_to_tangent_dtype(res_dtype)
10291028
return np.ones(np.shape(res), dtype=ct_dtype)
1030-
cts = tree_util.tree_map(make_ct, res_f_of_args)
1029+
cts = jax.tree.map(make_ct, res_f_of_args)
10311030
def f_vjp(args, cts):
10321031
res, pullback = jax.vjp(f, *args)
10331032
return pullback(cts)

tests/infeed_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def f(x):
7272

7373
device = jax.local_devices()[0]
7474
# We must transfer the flattened data, as a tuple!!!
75-
flat_to_infeed, _ = jax.tree_util.tree_flatten(to_infeed)
75+
flat_to_infeed, _ = jax.tree.flatten(to_infeed)
7676
device.transfer_to_infeed(tuple(flat_to_infeed))
7777
self.assertAllClose(f(x), to_infeed)
7878

tests/jet_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class JetTest(jtu.JaxTestCase):
5858
def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
5959
check_dtypes=True):
6060
# Convert to jax arrays to ensure dtype canonicalization.
61-
primals = jax.tree_util.tree_map(jnp.asarray, primals)
62-
series = jax.tree_util.tree_map(jnp.asarray, series)
61+
primals = jax.tree.map(jnp.asarray, primals)
62+
series = jax.tree.map(jnp.asarray, series)
6363

6464
y, terms = jet(fun, primals, series)
6565
expected_y, expected_terms = jvp_taylor(fun, primals, series)
@@ -73,8 +73,8 @@ def check_jet(self, fun, primals, series, atol=1e-5, rtol=1e-5,
7373
def check_jet_finite(self, fun, primals, series, atol=1e-5, rtol=1e-5,
7474
check_dtypes=True):
7575
# Convert to jax arrays to ensure dtype canonicalization.
76-
primals = jax.tree_util.tree_map(jnp.asarray, primals)
77-
series = jax.tree_util.tree_map(jnp.asarray, series)
76+
primals = jax.tree.map(jnp.asarray, primals)
77+
series = jax.tree.map(jnp.asarray, series)
7878

7979
y, terms = jet(fun, primals, series)
8080
expected_y, expected_terms = jvp_taylor(fun, primals, series)

0 commit comments

Comments
 (0)