Skip to content

Commit 147c363

Browse files
yueshengysjax authors
authored andcommitted
Deprecate jax.clear_backends.
`jax.clear_backends` does not necessarily do what its name suggests and can lead to unexpected consequences, e.g., it will not destroy existing backends and release corresponding owned resources. Use `jax.clear_caches` if you only want to clean up compilation caches. For backward compatibilty or you really need to switch/reinitialize the default backend, use `jax.extend.backend.clear_backends`. PiperOrigin-RevId: 616946337
1 parent 154403c commit 147c363

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ Remember to align the itemized text with the first line of an item within a list
1111
* Deprecations & Removals
1212
* {func}`jax.tree_map` is deprecated; use `jax.tree.map` instead, or for backward
1313
compatibility with older JAX versions, use {func}`jax.tree_util.tree_map`.
14+
* {func}`jax.clear_backends` is deprecated as it does not necessarily do what
15+
its name suggests and can lead to unexpected consequences, e.g., it will not
16+
destroy existing backends and release corresponding owned resources. Use
17+
{func}`jax.clear_caches` if you only want to clean up compilation caches.
18+
For backward compatibility or you really need to switch/reinitialize the
19+
default backend, use {func}`jax.extend.backend.clear_backends`.
1420
* The `jax.experimental.maps` module and `jax.experimental.maps.xmap` are
1521
deprecated. Use `jax.experimental.shard_map` or `jax.vmap` with the
1622
`spmd_axis_name` argument for expressing SPMD device-parallel computations.

jax/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@
8181
from jax._src.api import block_until_ready as block_until_ready
8282
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
8383
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
84-
from jax._src.api import clear_backends as clear_backends
84+
from jax._src.api import clear_backends as _deprecated_clear_backends
8585
from jax._src.api import clear_caches as clear_caches
8686
from jax._src.custom_derivatives import closure_convert as closure_convert
8787
from jax._src.custom_derivatives import custom_gradient as custom_gradient
@@ -218,10 +218,16 @@
218218
"or jax.tree_util.tree_map (any JAX version).",
219219
_deprecated_tree_map
220220
),
221+
# Added Mar 18, 2024
222+
"clear_backends": (
223+
"jax.clear_backends is deprecated.",
224+
_deprecated_clear_backends
225+
),
221226
}
222227

223228
import typing as _typing
224229
if _typing.TYPE_CHECKING:
230+
from jax._src.api import clear_backends as clear_backends
225231
from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf
226232
from jax._src.tree_util import tree_flatten as tree_flatten
227233
from jax._src.tree_util import tree_leaves as tree_leaves

0 commit comments

Comments
 (0)