Skip to content

Commit 930aaa5

Browse files
superbobryjax authors
authored andcommitted
Deprecated the jax.experimental.maps submodule
PiperOrigin-RevId: 614082251
1 parent 0302e4c commit 930aaa5

File tree

3 files changed

+28
-2
lines changed

3 files changed

+28
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ Remember to align the itemized text with the first line of an item within a list
1111
* Deprecations
1212
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
1313
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
14+
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
15+
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
16+
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
1417
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
1518
that cannot be converted to a JAX array now results in an exception.
1619

jax/experimental/maps.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,40 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import warnings
16+
17+
from jax._src import deprecations
1518
from jax._src.maps import (
1619
AxisName as AxisName,
1720
ResourceSet as ResourceSet,
1821
SerialLoop as SerialLoop,
22+
_prepare_axes as _prepare_axes,
1923
make_xmap_callable as make_xmap_callable,
2024
serial_loop as serial_loop,
21-
xmap as xmap,
2225
xmap_p as xmap_p,
23-
_prepare_axes as _prepare_axes,
26+
xmap as xmap,
2427
)
2528
from jax._src.mesh import (
2629
EMPTY_ENV as EMPTY_ENV,
2730
ResourceEnv as ResourceEnv,
2831
thread_resources as thread_resources,
2932
)
33+
34+
# Added March 7, 2024.
35+
_msg = (
36+
"jax.experimental.maps and jax.experimental.maps.xmap are deprecated and"
37+
" will be removed in a future release. Use jax.experimental.shard_map or"
38+
" jax.vmap with the spmd_axis_name argument for expressing SPMD"
39+
" device-parallel computations. Please file an issue on"
40+
" https://github.com/google/jax/issues if neither"
41+
" jax.experimental.shard_map nor jax.vmap are suitable for your use case."
42+
)
43+
44+
deprecations.register("jax.experimental.maps", "maps-module")
45+
46+
if deprecations.is_accelerated("jax.experimental.maps", "maps-module"):
47+
raise ImportError(_msg)
48+
else:
49+
warnings.warn(_msg, DeprecationWarning, stacklevel=2)
50+
51+
del deprecations, warnings, _msg

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ filterwarnings = [
8484
"ignore:Special cases found for .* but none were parsed.*:UserWarning",
8585
# end array_api_tests-related warnings
8686
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
87+
"ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning",
8788
]
8889
doctest_optionflags = [
8990
"NUMBER",

0 commit comments

Comments
 (0)