File tree Expand file tree Collapse file tree 3 files changed +76
-10
lines changed Expand file tree Collapse file tree 3 files changed +76
-10
lines changed Original file line number Diff line number Diff line change 700
700
" \n " ,
701
701
" # Now we register the XLA compilation rule with JAX\n " ,
702
702
" # TODO: for GPU? and TPU?\n " ,
703
- " from jax import xla\n " ,
703
+ " from jax.interpreters import xla\n " ,
704
704
" xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
705
705
],
706
706
"execution_count" : 0 ,
876
876
"colab" : {}
877
877
},
878
878
"source" : [
879
- " from jax import ad\n " ,
879
+ " from jax.interpreters import ad\n " ,
880
880
" \n " ,
881
881
" \n " ,
882
882
" @trace(\" multiply_add_value_and_jvp\" )\n " ,
1529
1529
"colab" : {}
1530
1530
},
1531
1531
"source" : [
1532
- " from jax import batching\n " ,
1532
+ " from jax.interpreters import batching\n " ,
1533
1533
" \n " ,
1534
1534
" \n " ,
1535
1535
" @trace(\" multiply_add_batch\" )\n " ,
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import os
16
- os .environ .setdefault ('TF_CPP_MIN_LOG_LEVEL' , '1' )
17
-
18
15
from jax .version import __version__
19
- from jax .api import *
16
+ from .config import config
17
+ from .api import (
18
+ ad , # TODO(phawkins): update users to avoid this.
19
+ argnums_partial , # TODO(phawkins): update Haiku to not use this.
20
+ checkpoint ,
21
+ curry , # TODO(phawkins): update users to avoid this.
22
+ custom_gradient ,
23
+ custom_jvp ,
24
+ custom_vjp ,
25
+ custom_transforms ,
26
+ defjvp ,
27
+ defjvp_all ,
28
+ defvjp ,
29
+ defvjp_all ,
30
+ device_count ,
31
+ device_get ,
32
+ device_put ,
33
+ devices ,
34
+ disable_jit ,
35
+ eval_shape ,
36
+ flatten_fun_nokwargs , # TODO(phawkins): update users to avoid this.
37
+ grad ,
38
+ hessian ,
39
+ host_count ,
40
+ host_id ,
41
+ host_ids ,
42
+ jacobian ,
43
+ jacfwd ,
44
+ jacrev ,
45
+ jit ,
46
+ jvp ,
47
+ local_device_count ,
48
+ local_devices ,
49
+ linearize ,
50
+ make_jaxpr ,
51
+ mask ,
52
+ partial , # TODO(phawkins): update callers to use functools.partial.
53
+ pmap ,
54
+ pxla , # TODO(phawkins): update users to avoid this.
55
+ remat ,
56
+ shapecheck ,
57
+ ShapedArray ,
58
+ ShapeDtypeStruct ,
59
+ soft_pmap ,
60
+ # TODO(phawkins): hide tree* functions from jax, update callers to use
61
+ # jax.tree_util.
62
+ treedef_is_leaf ,
63
+ tree_flatten ,
64
+ tree_leaves ,
65
+ tree_map ,
66
+ tree_multimap ,
67
+ tree_structure ,
68
+ tree_transpose ,
69
+ tree_unflatten ,
70
+ value_and_grad ,
71
+ vjp ,
72
+ vmap ,
73
+ xla , # TODO(phawkins): update users to avoid this.
74
+ xla_computation ,
75
+ )
20
76
from jax import nn
21
77
from jax import random
22
- import jax .numpy as np # side-effecting import sets up operator overloads
78
+
79
+ # TODO(phawkins): remove the `np` name.
80
+ import jax .numpy as np # side-effecting import sets up operator overloads
81
+
82
+
83
+ def _init ():
84
+ import os
85
+ os .environ .setdefault ('TF_CPP_MIN_LOG_LEVEL' , '1' )
86
+
87
+ _init ()
88
+ del _init
Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
-
15
+ from functools import partial
16
16
import numpy as np
17
17
18
18
from absl .testing import absltest
19
19
from absl .testing import parameterized
20
20
21
21
from jax import numpy as jnp
22
- from jax import test_util as jtu , jit , partial
22
+ from jax import test_util as jtu , jit
23
23
24
24
from jax .config import config
25
25
config .parse_flags_with_absl ()
You can’t perform that action at this time.
0 commit comments