Skip to content

Commit 0ea22b7

Browse files
authored
Use a whitelist to restrict visibility in top-level jax namespace. (#2982)
* Use a whitelist to restrict visibility in top-level jax namespace. The goal of this change is to capture the way the world is (i.e., not break users), and separately we will work on fixing users to avoid accidentally-exported APIs.
1 parent 9f04d98 commit 0ea22b7

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

docs/notebooks/How_JAX_primitives_work.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@
700700
"\n",
701701
"# Now we register the XLA compilation rule with JAX\n",
702702
"# TODO: for GPU? and TPU?\n",
703-
"from jax import xla\n",
703+
"from jax.interpreters import xla\n",
704704
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
705705
],
706706
"execution_count": 0,
@@ -876,7 +876,7 @@
876876
"colab": {}
877877
},
878878
"source": [
879-
"from jax import ad\n",
879+
"from jax.interpreters import ad\n",
880880
"\n",
881881
"\n",
882882
"@trace(\"multiply_add_value_and_jvp\")\n",
@@ -1529,7 +1529,7 @@
15291529
"colab": {}
15301530
},
15311531
"source": [
1532-
"from jax import batching\n",
1532+
"from jax.interpreters import batching\n",
15331533
"\n",
15341534
"\n",
15351535
"@trace(\"multiply_add_batch\")\n",

jax/__init__.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,77 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
16-
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
17-
1815
from jax.version import __version__
19-
from jax.api import *
16+
from .config import config
17+
from .api import (
18+
ad, # TODO(phawkins): update users to avoid this.
19+
argnums_partial, # TODO(phawkins): update Haiku to not use this.
20+
checkpoint,
21+
curry, # TODO(phawkins): update users to avoid this.
22+
custom_gradient,
23+
custom_jvp,
24+
custom_vjp,
25+
custom_transforms,
26+
defjvp,
27+
defjvp_all,
28+
defvjp,
29+
defvjp_all,
30+
device_count,
31+
device_get,
32+
device_put,
33+
devices,
34+
disable_jit,
35+
eval_shape,
36+
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
37+
grad,
38+
hessian,
39+
host_count,
40+
host_id,
41+
host_ids,
42+
jacobian,
43+
jacfwd,
44+
jacrev,
45+
jit,
46+
jvp,
47+
local_device_count,
48+
local_devices,
49+
linearize,
50+
make_jaxpr,
51+
mask,
52+
partial, # TODO(phawkins): update callers to use functools.partial.
53+
pmap,
54+
pxla, # TODO(phawkins): update users to avoid this.
55+
remat,
56+
shapecheck,
57+
ShapedArray,
58+
ShapeDtypeStruct,
59+
soft_pmap,
60+
# TODO(phawkins): hide tree* functions from jax, update callers to use
61+
# jax.tree_util.
62+
treedef_is_leaf,
63+
tree_flatten,
64+
tree_leaves,
65+
tree_map,
66+
tree_multimap,
67+
tree_structure,
68+
tree_transpose,
69+
tree_unflatten,
70+
value_and_grad,
71+
vjp,
72+
vmap,
73+
xla, # TODO(phawkins): update users to avoid this.
74+
xla_computation,
75+
)
2076
from jax import nn
2177
from jax import random
22-
import jax.numpy as np # side-effecting import sets up operator overloads
78+
79+
# TODO(phawkins): remove the `np` name.
80+
import jax.numpy as np # side-effecting import sets up operator overloads
81+
82+
83+
def _init():
84+
import os
85+
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
86+
87+
_init()
88+
del _init

tests/polynomial_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
15+
from functools import partial
1616
import numpy as np
1717

1818
from absl.testing import absltest
1919
from absl.testing import parameterized
2020

2121
from jax import numpy as jnp
22-
from jax import test_util as jtu, jit, partial
22+
from jax import test_util as jtu, jit
2323

2424
from jax.config import config
2525
config.parse_flags_with_absl()

0 commit comments

Comments
 (0)