JAX v0.6.1
New features:
- Added jax.lax.axis_size, which returns the size of the mapped axis given its name.
Changes
- Additional checking for the versions of CUDA package dependencies was re-enabled, having been accidentally disabled in a previous release.
- JAX nightly packages are now published to artifact registry. To install these packages, see the JAX installation guide.
- jax.sharding.PartitionSpec no longer inherits from a tuple.
- jax.ShapeDtypeStruct is now immutable. Please use the .update method to update your ShapeDtypeStruct instead of doing in-place updates.
Deprecations
- jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated and will be removed in JAX v0.7.0.