Skip to content

Commit d95df14

Browse files
author
jax authors
committed
Merge pull request #20314 from jakevdp:dep-maps
PiperOrigin-RevId: 617212897
2 parents e5a16a0 + 84e49bd commit d95df14

11 files changed

+21
-19
lines changed

jax/experimental/jax2tf/jax2tf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from jax import numpy as jnp
3737
from jax import tree_util
3838
from jax import sharding
39-
from jax.experimental import maps
4039
from jax.experimental import export
4140
from jax.experimental.export import _export
4241
from jax.experimental.export import _shape_poly
@@ -54,6 +53,8 @@
5453
from jax._src import linear_util as lu
5554
from jax._src import op_shardings
5655
from jax._src import sharding_impls
56+
from jax._src import maps
57+
from jax._src import mesh
5758
from jax._src import pjit
5859
from jax._src import prng
5960
from jax._src import random as random_internal
@@ -3503,7 +3504,7 @@ def _pjit(*args: TfVal,
35033504
jaxpr: core.ClosedJaxpr,
35043505
in_shardings: Sequence[sharding.XLACompatibleSharding],
35053506
out_shardings: Sequence[sharding.XLACompatibleSharding],
3506-
resource_env: maps.ResourceEnv,
3507+
resource_env: mesh.ResourceEnv,
35073508
donated_invars,
35083509
name: str,
35093510
keep_unused: bool,
@@ -3535,7 +3536,7 @@ def _pjit(*args: TfVal,
35353536

35363537
def _pjit_sharding_constraint(arg: TfVal, *,
35373538
sharding: sharding.XLACompatibleSharding,
3538-
resource_env: maps.ResourceEnv,
3539+
resource_env: mesh.ResourceEnv,
35393540
_in_avals: Sequence[core.ShapedArray],
35403541
_out_aval: core.ShapedArray,
35413542
**kwargs) -> TfVal:

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@
3434
from jax import sharding
3535
from jax._src import config
3636
from jax._src import core
37+
from jax._src.maps import xmap
3738
from jax._src import source_info_util
3839
from jax._src import test_util as jtu
3940
from jax._src import xla_bridge as xb
4041
from jax.experimental import jax2tf
4142
from jax.experimental import export
4243
from jax.experimental.jax2tf.tests import tf_test_util
43-
from jax.experimental.maps import xmap
4444
from jax.experimental.shard_map import shard_map
4545
from jax.experimental import pjit
4646
from jax.sharding import PartitionSpec as P

jax/experimental/jax2tf/tests/sharding_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@
3434
from jax._src import compiler
3535
from jax._src import config
3636
from jax._src import maps
37+
from jax._src.maps import xmap
3738
from jax._src import test_util as jtu
3839
from jax._src import xla_bridge
3940
from jax import lax
4041
from jax.experimental import jax2tf
4142
from jax.experimental import pjit
42-
from jax.experimental.maps import xmap
4343
from jax.experimental.shard_map import shard_map
4444
from jax.sharding import NamedSharding
4545
from jax.sharding import Mesh

tests/compilation_cache_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@
3333
from jax._src import compiler
3434
from jax._src import config
3535
from jax._src import distributed
36+
from jax._src.maps import xmap
3637
from jax._src import monitoring
3738
from jax._src import test_util as jtu
3839
from jax._src import xla_bridge
3940
from jax._src.lib import xla_client
40-
from jax.experimental.maps import xmap
4141
from jax.experimental.pjit import pjit
4242
from jax.sharding import PartitionSpec as P
4343
import numpy as np

tests/debug_nans_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from jax._src import api
2424
from jax._src import test_util as jtu
2525
from jax import numpy as jnp
26-
from jax.experimental import pjit, maps
26+
from jax.experimental import pjit
27+
from jax._src.maps import xmap
2728

2829
from jax import config
2930
config.parse_flags_with_absl()
@@ -125,7 +126,7 @@ def testPmapNoNaN(self):
125126
@jtu.ignore_warning(message=".*is an experimental.*")
126127
def testXmap(self):
127128

128-
f = maps.xmap(
129+
f = xmap(
129130
lambda x: 0. / x,
130131
in_axes=["i"],
131132
out_axes=["i"],

tests/debugging_primitives_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
import jax
2121
from jax import lax
2222
from jax import config
23-
from jax.experimental import maps
2423
from jax.experimental import pjit
2524
from jax.interpreters import pxla
2625
from jax._src import ad_checkpoint
2726
from jax._src import debugging
2827
from jax._src import dispatch
2928
from jax._src import test_util as jtu
29+
from jax._src.maps import xmap
3030
import jax.numpy as jnp
3131
import numpy as np
3232

@@ -795,7 +795,7 @@ def foo(x):
795795
idx = lax.axis_index('foo')
796796
debug_print("{idx}: {x}", idx=idx, x=x)
797797
return jnp.mean(x, axis=['foo'])
798-
out = maps.xmap(foo, in_axes=['foo'], out_axes=[...])(x)
798+
out = xmap(foo, in_axes=['foo'], out_axes=[...])(x)
799799
debug_print("Out: {}", out)
800800
return out
801801
mesh = jax.sharding.Mesh(np.array(jax.devices()), ['dev'])
@@ -813,8 +813,8 @@ def foo(x):
813813
def test_unordered_print_with_xmap(self):
814814
def f(x):
815815
debug_print("{}", x, ordered=False)
816-
f = maps.xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
817-
axis_resources={'a': 'dev'})
816+
f = xmap(f, in_axes=['a'], out_axes=None, backend='cpu',
817+
axis_resources={'a': 'dev'})
818818
with jax.sharding.Mesh(np.array(jax.devices()), ['dev']):
819819
with jtu.capture_stdout() as output:
820820
f(np.arange(40))

tests/jaxpr_effects_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import jax
2020
import jax.numpy as jnp
2121
from jax import lax
22-
from jax.experimental import maps
2322
from jax.experimental import pjit
2423
from jax._src import ad_checkpoint
2524
from jax._src import dispatch
@@ -32,6 +31,7 @@
3231
from jax._src.interpreters import ad
3332
from jax._src.interpreters import mlir
3433
from jax._src.interpreters import partial_eval as pe
34+
from jax._src.maps import xmap
3535
import numpy as np
3636

3737
config.parse_flags_with_absl()
@@ -275,7 +275,7 @@ def f(x):
275275
effect_p.bind(effect=foo_effect)
276276
effect_p.bind(effect=bar_effect)
277277
return x
278-
f = maps.xmap(f, in_axes=['a'], out_axes=['a'])
278+
f = xmap(f, in_axes=['a'], out_axes=['a'])
279279
jaxpr = jax.make_jaxpr(f)(jnp.arange(jax.local_device_count()))
280280
self.assertSetEqual(jaxpr.effects, {foo_effect, bar_effect})
281281

tests/lax_control_flow_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@
3535
from jax._src import test_util as jtu
3636
from jax import tree_util
3737
from jax._src.util import unzip2
38-
from jax.experimental import maps
3938
from jax.ad_checkpoint import checkpoint as new_checkpoint, checkpoint_policies
4039
import jax.numpy as jnp # scan tests use numpy
4140
import jax.scipy as jsp
4241
from jax._src.lax import control_flow as lax_control_flow
4342
from jax._src.lax.control_flow import for_loop
43+
from jax._src.maps import xmap
4444

4545
from jax import config
4646
config.parse_flags_with_absl()
@@ -2712,7 +2712,7 @@ def body(carry):
27122712
i, x = carry
27132713
return i + 1, x + lax.psum(y, 'b')
27142714
return lax.while_loop(cond, body, (0, z))[1]
2715-
maps.xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.)
2715+
xmap(f, axis_sizes=dict(a=2, b=10), out_axes=(['a']), in_axes={})(1.)
27162716

27172717
def test_while_loop_fixed_point_with_batched_pred_and_consts(self):
27182718
def f(i, x):

tests/pjit_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from jax.lax import with_sharding_constraint
4242
from jax._src import prng
4343
from jax.sharding import PartitionSpec as P
44-
from jax.experimental.maps import xmap
4544
from jax.experimental import multihost_utils
4645
from jax.experimental.custom_partitioning import custom_partitioning
4746
from jax._src import array
@@ -52,6 +51,7 @@
5251
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
5352
SingleDeviceSharding, parse_flatten_op_sharding)
5453
import jax._src.pjit as pjit_lib
54+
from jax._src.maps import xmap
5555
from jax._src.pjit import pjit, pjit_p
5656
from jax._src import mesh as mesh_lib
5757
from jax._src.interpreters import pxla

tests/python_callback_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727
from jax._src import core
2828
from jax._src import dispatch
2929
from jax._src import maps
30+
from jax._src.maps import xmap
3031
from jax._src import test_util as jtu
3132
from jax._src import util
3233
from jax._src.lib import xla_client
3334
from jax._src.lib import xla_extension_version
3435
from jax.experimental import io_callback
3536
from jax.experimental import pjit
36-
from jax.experimental.maps import xmap
3737
from jax.experimental.shard_map import shard_map
3838
import jax.numpy as jnp
3939
from jax.sharding import Mesh

0 commit comments

Comments
 (0)