Skip to content

Commit c7cf7fb

Browse files
committed
[array api] fix deprecation to support old import pattern
1 parent aa9e1e4 commit c7cf7fb

File tree

3 files changed

+33
-19
lines changed

3 files changed

+33
-19
lines changed

jax/experimental/__init__.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,3 @@
2525
from jax._src.earray import (
2626
EArray as EArray
2727
)
28-
29-
from jax import numpy as _array_api
30-
31-
32-
_deprecations = {
33-
# Deprecated 01 Aug 2024
34-
"array_api": (
35-
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
36-
"jax.numpy supports the array API by default.",
37-
_array_api
38-
),
39-
}
40-
41-
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
42-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
43-
del _deprecation_getattr
44-
del _array_api
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Note: import <name> as <name> is required for names to be exported.
16+
# See PEP 484 & https://github.com/google/jax/issues/7570
17+
18+
import sys as _sys
19+
import warnings as _warnings
20+
21+
import jax.numpy as _array_api
22+
23+
_warnings.warn(
24+
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
25+
"jax.numpy supports the array API by default.",
26+
DeprecationWarning, stacklevel=2
27+
)
28+
29+
_sys.modules['jax.experimental.array_api'] = _array_api
30+
31+
del _array_api, _sys, _warnings

tests/array_api_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def test_array_namespace_method(self):
249249
def test_deprecated_import(self):
250250
msg = "jax.experimental.array_api import is no longer required"
251251
with self.assertWarnsRegex(DeprecationWarning, msg):
252-
from jax.experimental import array_api
253-
self.assertIs(array_api, ARRAY_API_NAMESPACE)
252+
import jax.experimental.array_api as nx
253+
self.assertIs(nx, ARRAY_API_NAMESPACE)
254254

255255

256256
class ArrayAPIInspectionUtilsTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)