Skip to content

Commit 1cef1d9

Browse files
yueshengysjax authors
authored andcommitted
jax.clear_backends() is not doing what it is intended to do, users should try to avoid using it.
We decide to move it into `jax.extend`. This CL is the first step which adds a new module `jax.extend.backend`. PiperOrigin-RevId: 615934218
1 parent 7578e10 commit 1cef1d9

File tree

5 files changed

+31
-0
lines changed

5 files changed

+31
-0
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,7 @@ pytype_strict_library(
10241024
visibility = [":jax_extend_users"],
10251025
deps = [
10261026
"//jax/extend",
1027+
"//jax/extend:backend",
10271028
"//jax/extend:core",
10281029
"//jax/extend:linear_util",
10291030
"//jax/extend:random",

jax/extend/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pytype_strict_library(
2626
name = "extend",
2727
srcs = ["__init__.py"],
2828
deps = [
29+
":backend",
2930
":core",
3031
":linear_util",
3132
":random",
@@ -45,6 +46,12 @@ pytype_strict_library(
4546
deps = ["//jax:core"],
4647
)
4748

49+
pytype_strict_library(
50+
name = "backend",
51+
srcs = ["backend.py"],
52+
deps = ["//jax"],
53+
)
54+
4855
pytype_strict_library(
4956
name = "random",
5057
srcs = ["random.py"],

jax/extend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"""
3030

3131
from jax.extend import (
32+
backend as backend,
3233
core as core,
3334
linear_util as linear_util,
3435
random as random,

jax/extend/backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: import <name> as <name> is required for names to be exported.
16+
# See PEP 484 & https://github.com/google/jax/issues/7570
17+
18+
from jax._src.api import (
19+
clear_backends as clear_backends,
20+
)

tests/extend_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import jax.extend as jex
1919
import jax.numpy as jnp
2020

21+
from jax._src import api
2122
from jax._src import abstract_arrays
2223
from jax._src import linear_util
2324
from jax._src import prng
@@ -39,6 +40,7 @@ def test_symbols(self):
3940
self.assertIs(jex.random.unsafe_rbg_prng_impl, prng.unsafe_rbg_prng_impl)
4041

4142
# Assume these are tested elsewhere, only check equivalence
43+
self.assertIs(jex.backend.clear_backends, api.clear_backends)
4244
self.assertIs(jex.core.array_types, abstract_arrays.array_types)
4345
self.assertIs(jex.linear_util.StoreException, linear_util.StoreException)
4446
self.assertIs(jex.linear_util.WrappedFun, linear_util.WrappedFun)

0 commit comments

Comments
 (0)