File tree Expand file tree Collapse file tree 5 files changed +31
-0
lines changed Expand file tree Collapse file tree 5 files changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -1024,6 +1024,7 @@ pytype_strict_library(
1024
1024
visibility = [":jax_extend_users" ],
1025
1025
deps = [
1026
1026
"//jax/extend" ,
1027
+ "//jax/extend:backend" ,
1027
1028
"//jax/extend:core" ,
1028
1029
"//jax/extend:linear_util" ,
1029
1030
"//jax/extend:random" ,
Original file line number Diff line number Diff line change @@ -26,6 +26,7 @@ pytype_strict_library(
26
26
name = "extend" ,
27
27
srcs = ["__init__.py" ],
28
28
deps = [
29
+ ":backend" ,
29
30
":core" ,
30
31
":linear_util" ,
31
32
":random" ,
@@ -45,6 +46,12 @@ pytype_strict_library(
45
46
deps = ["//jax:core" ],
46
47
)
47
48
49
+ pytype_strict_library (
50
+ name = "backend" ,
51
+ srcs = ["backend.py" ],
52
+ deps = ["//jax" ],
53
+ )
54
+
48
55
pytype_strict_library (
49
56
name = "random" ,
50
57
srcs = ["random.py" ],
Original file line number Diff line number Diff line change 29
29
"""
30
30
31
31
from jax .extend import (
32
+ backend as backend ,
32
33
core as core ,
33
34
linear_util as linear_util ,
34
35
random as random ,
Original file line number Diff line number Diff line change
1
+ # Copyright 2024 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Note: import <name> as <name> is required for names to be exported.
16
+ # See PEP 484 & https://github.com/google/jax/issues/7570
17
+
18
+ from jax ._src .api import (
19
+ clear_backends as clear_backends ,
20
+ )
Original file line number Diff line number Diff line change 18
18
import jax .extend as jex
19
19
import jax .numpy as jnp
20
20
21
+ from jax ._src import api
21
22
from jax ._src import abstract_arrays
22
23
from jax ._src import linear_util
23
24
from jax ._src import prng
@@ -39,6 +40,7 @@ def test_symbols(self):
39
40
self .assertIs (jex .random .unsafe_rbg_prng_impl , prng .unsafe_rbg_prng_impl )
40
41
41
42
# Assume these are tested elsewhere, only check equivalence
43
+ self .assertIs (jex .backend .clear_backends , api .clear_backends )
42
44
self .assertIs (jex .core .array_types , abstract_arrays .array_types )
43
45
self .assertIs (jex .linear_util .StoreException , linear_util .StoreException )
44
46
self .assertIs (jex .linear_util .WrappedFun , linear_util .WrappedFun )
You can’t perform that action at this time.
0 commit comments