Skip to content

Commit 07440f4

Browse files
yashk2810hawkinsp
authored andcommitted
Prepare for JAX release 0.5.1
1 parent b0cfcb8 commit 07440f4

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Remember to align the itemized text with the first line of an item within a list
1414
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1515
-->
1616

17-
## Unreleased
17+
## jax 0.5.1 (Feb 24, 2025)
1818

1919
* New Features
2020
* Added an experimental {func}`jax.experimental.custom_dce.custom_dce`
@@ -47,6 +47,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
4747
See https://github.com/jax-ml/jax/issues/26480 for more detail.
4848

4949
* Bug fixes
50+
* TPU runtime startup and shutdown time should be significantly improved on
51+
TPU v5e and newer (from around 17s to around 8s). If not already set, you may
52+
need to enable transparent hugepages in your VM image
53+
(`sudo sh -c 'echo always > /sys/kernel/mm/transparent_hugepage/enabled'`).
54+
We hope to improve this further in future releases.
5055
* Persistent compilation cache no longer writes access time file if
5156
JAX_COMPILATION_CACHE_MAX_SIZE is unset or set to -1, i.e. if the LRU
5257
eviction policy isn't enabled. This should improve performance when using

jax/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def make_release_tree(self, base_dir, files):
146146

147147

148148
__version__ = _get_version_string()
149-
_minimum_jaxlib_version = "0.5.0"
149+
_minimum_jaxlib_version = "0.5.1"
150150

151151
def _version_as_tuple(version_str):
152152
return tuple(int(i) for i in version_str.split(".") if i.isdigit())

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919

2020
project_name = 'jax'
2121

22-
_current_jaxlib_version = '0.5.0'
22+
_current_jaxlib_version = '0.5.1'
2323
# The following should be updated after each new jaxlib release.
2424
_latest_jaxlib_version_on_pypi = '0.5.0'
2525

26-
_libtpu_version = '0.0.8'
26+
_libtpu_version = '0.0.10'
2727

2828
def load_version_module(pkg_path):
2929
spec = importlib.util.spec_from_file_location(

0 commit comments

Comments
 (0)