File tree Expand file tree Collapse file tree 3 files changed +33
-19
lines changed Expand file tree Collapse file tree 3 files changed +33
-19
lines changed Original file line number Diff line number Diff line change 25
25
from jax ._src .earray import (
26
26
EArray as EArray
27
27
)
28
-
29
- from jax import numpy as _array_api
30
-
31
-
32
- _deprecations = {
33
- # Deprecated 01 Aug 2024
34
- "array_api" : (
35
- "jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
36
- "jax.numpy supports the array API by default." ,
37
- _array_api
38
- ),
39
- }
40
-
41
- from jax ._src .deprecations import deprecation_getattr as _deprecation_getattr
42
- __getattr__ = _deprecation_getattr (__name__ , _deprecations )
43
- del _deprecation_getattr
44
- del _array_api
Original file line number Diff line number Diff line change
1
+ # Copyright 2024 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Note: import <name> as <name> is required for names to be exported.
16
+ # See PEP 484 & https://github.com/google/jax/issues/7570
17
+
18
+ import sys as _sys
19
+ import warnings as _warnings
20
+
21
+ import jax .numpy as _array_api
22
+
23
+ _warnings .warn (
24
+ "jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
25
+ "jax.numpy supports the array API by default." ,
26
+ DeprecationWarning , stacklevel = 2
27
+ )
28
+
29
+ _sys .modules ['jax.experimental.array_api' ] = _array_api
30
+
31
+ del _array_api , _sys , _warnings
Original file line number Diff line number Diff line change @@ -249,8 +249,8 @@ def test_array_namespace_method(self):
249
249
def test_deprecated_import (self ):
250
250
msg = "jax.experimental.array_api import is no longer required"
251
251
with self .assertWarnsRegex (DeprecationWarning , msg ):
252
- from jax .experimental import array_api
253
- self .assertIs (array_api , ARRAY_API_NAMESPACE )
252
+ import jax .experimental . array_api as nx
253
+ self .assertIs (nx , ARRAY_API_NAMESPACE )
254
254
255
255
256
256
class ArrayAPIInspectionUtilsTest (jtu .JaxTestCase ):
You can’t perform that action at this time.
0 commit comments