Skip to content

Commit 154403c

Browse files
Jake VanderPlasjax authors
authored andcommitted
Finalize deprecations of jax.interpreters.ad config & source_info_util
These have been raising a DeprecationWarning since JAX 0.4.19, released 2023 Oct 19. I've left the undefined symbols in place for now, as they will raise an informative AttributeError. PiperOrigin-RevId: 616931120
1 parent bc363de commit 154403c

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ Remember to align the itemized text with the first line of an item within a list
1818
that cannot be converted to a JAX array now results in an exception.
1919
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
2020
This flag was long deprecated and did nothing; its use was a no-op.
21+
* The previously-deprecated imports `jax.interpreters.ad.config` and
22+
`jax.interpreters.ad.source_info_util` have now been removed. Use `jax.config`
23+
and `jax.extend.source_info_util` instead.
2124

2225
## jaxlib 0.4.26
2326

jax/interpreters/ad.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -73,32 +73,18 @@
7373
zeros_like_p as zeros_like_p,
7474
)
7575

76-
from jax import config as _deprecated_config
77-
from jax._src import source_info_util as _deprecated_source_info_util
7876
_deprecations = {
79-
# Added Oct 13, 2023:
77+
# Finalized Mar 18, 2024; remove after June 18, 2024
8078
"config": (
8179
"jax.interpreters.ad.config is deprecated. Use jax.config directly.",
82-
_deprecated_config,
80+
None,
8381
),
8482
"source_info_util": (
8583
"jax.interpreters.ad.source_info_util is deprecated. Use jax.extend.source_info_util.",
86-
_deprecated_source_info_util,
84+
None,
8785
),
8886
}
8987

90-
import typing
91-
if typing.TYPE_CHECKING:
92-
config = _deprecated_config
93-
source_info_util = _deprecated_source_info_util
94-
else:
95-
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
96-
__getattr__ = _deprecation_getattr(__name__, _deprecations)
97-
del _deprecation_getattr
98-
del typing
99-
del _deprecated_config
100-
del _deprecated_source_info_util
101-
10288
def backward_pass(jaxpr, reduce_axes, transform_stack,
10389
consts, primals_in, cotangents_in):
10490
if reduce_axes:

0 commit comments

Comments
 (0)