Skip to content

Commit d866288

Browse files
Jake VanderPlasjax authors
authored andcommitted
Register maps module deprecation outside of module
PiperOrigin-RevId: 617194807
1 parent b48aec5 commit d866288

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
# TODO(jakevdp): remove this when jax/config.py is removed.
180180
from jax._src.deprecations import register as _register_deprecation
181181
_register_deprecation("jax.config", "config-module")
182+
_register_deprecation("jax.experimental", "maps-module")
182183
del _register_deprecation
183184

184185
_deprecations = {

jax/experimental/maps.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@
4141
" jax.experimental.shard_map nor jax.vmap are suitable for your use case."
4242
)
4343

44-
deprecations.register("jax.experimental.maps", "maps-module")
45-
46-
if deprecations.is_accelerated("jax.experimental.maps", "maps-module"):
44+
if deprecations.is_accelerated("jax.experimental", "maps-module"):
4745
raise ImportError(_msg)
4846
else:
4947
warnings.warn(_msg, DeprecationWarning, stacklevel=2)

0 commit comments

Comments
 (0)