- Breaking changes
  - `jax.numpy.array` no longer accepts `None`. This behavior had been
    deprecated since November 2023 and has now been removed.
  - Removed the `config.jax_data_dependent_tracing_fallback` config option,
    which was added temporarily in v0.4.36 to allow users to opt out of the
    new "stackless" tracing machinery.
  - Removed the `config.jax_eager_pmap` config option.
  - Calling the `lower` and `trace` AOT APIs on the result of `jax.jit` is now
    disallowed if other wrappers have subsequently been applied to it.
    Previously this worked, but silently ignored the wrappers.
    The workaround is to apply `jax.jit` last among the wrappers,
    and similarly for `jax.pmap`.
    See #27873.
  - The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]`
    instead.
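A minimal sketch of the workaround for the AOT change, assuming a hypothetical user-defined decorator `with_logging` built with `functools.wraps` and a hypothetical `loss` function. `jax.jit` is applied last, so `lower` and `compile` operate on the function including the inner wrapper; the disallowed ordering is shown only as a comment.

```python
import functools

import jax
import jax.numpy as jnp


def with_logging(f):
    """Hypothetical user-defined wrapper; any decorator would do."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Under jax.jit this print runs at trace time, not on every call.
        print(f"tracing {f.__name__}")
        return f(*args, **kwargs)
    return wrapper


def loss(x):
    return jnp.sum(jnp.sin(x) ** 2)


# Disallowed ordering: jax.jit first, another wrapper applied afterwards.
# Calling .lower(...) or .trace(...) on `bad` is no longer permitted.
# bad = with_logging(jax.jit(loss))

# Workaround: make jax.jit the outermost wrapper, so the AOT APIs
# lower the function together with every wrapper inside it.
good = jax.jit(with_logging(loss))

x = jnp.ones(4)
lowered = good.lower(x)       # ahead-of-time lowering for this shape/dtype
compiled = lowered.compile()  # compile the lowered computation
print(compiled(x))
```

The same ordering rule applies to `jax.pmap`: keep it as the outermost transformation whenever its AOT APIs are needed.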
- Changes
  - The minimum cuDNN version is v9.8.
  - JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
    supported.
  - JAX package extras now use dashes instead of underscores to align with
    PEP 685. For instance, if you were previously using `pip install jax[cuda12_local]`
    to install JAX, run `pip install jax[cuda12-local]` instead.
  - `jax.jit` now requires `fun` to be passed by position, and additional
    arguments to be passed by keyword. Doing otherwise will result in a
    `DeprecationWarning` in v0.6.X, and an error starting in v0.7.X.
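A minimal sketch of the `jax.jit` calling convention described above, using hypothetical functions `scale` and `power`: the function is the only positional argument, and options such as `static_argnames` or `static_argnums` are passed by keyword, including in decorator form via `functools.partial`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# `fun` is positional; every other jax.jit option is a keyword.
scale_jit = jax.jit(scale, static_argnames="factor")


# Decorator form: functools.partial keeps the options as keywords,
# and jax.jit then receives the decorated function positionally.
@partial(jax.jit, static_argnums=1)
def power(x, n):
    return x ** n


# Patterns that warn in v0.6.X and fail in v0.7.X, per the note above:
#   jax.jit(fun=scale)   -> `fun` passed by keyword
#   any option other than `fun` passed by position

x = jnp.arange(3.0)
print(scale_jit(x, factor=2.0))
print(power(x, 2))
```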
- Deprecations
  - `jax.tree_util.build_tree` is deprecated. Use `jax.tree.unflatten`
    instead.
  - Implemented host callback handlers for CPU and GPU devices using XLA's FFI
    and removed existing CPU/GPU handlers using XLA's custom call.
  - All APIs in `jax.lib.xla_extension` are now deprecated.
  - `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`,
    which were accidental exports, have been removed. If needed, they are
    available from `jax.extend.mlir`.
  - `jax.interpreters.mlir.custom_call` is deprecated. The APIs provided by
    `jax.ffi` should be used instead.
  - The deprecated use of `jax.ffi.ffi_call` with inline arguments is no
    longer supported. `jax.ffi.ffi_call` now unconditionally returns a
    callable.
  - The following exports in `jax.lib.xla_client` are deprecated:
    `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`,
    `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`,
    `Traceback`.
  - The following internal APIs in `jax.util` are deprecated:
    `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`,
    `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`,
    `toposort`, `unzip2`, `wrap_name`, and `wraps`.
  - `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX
    `Array` directly to the `from_dlpack` function of another framework. If you
    need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an
    array.
  - `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and
    `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
  - Several previously-deprecated APIs have been removed, including:
    - From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`,
      `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`,
      `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`,
      `XlaComputation`.
    - From `jax.lib.xla_extension`: `ArrayImpl`, `XlaRuntimeError`.
    - From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`,
      `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and
      `jax.tree_unflatten`. Replacements can be found in `jax.tree` or
      `jax.tree_util`.
    - From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`,
      `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`,
      `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`,
      `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`,
      `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`,
      `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`,
      `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`,
      `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most
      have no public replacement, though a few are available at `jax.extend.core`.
    - The `vectorized` argument to `jax.pure_callback` and
      `jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
    - From