JAX v0.6.0

@hawkinsp hawkinsp released this 17 Apr 00:04
  • Breaking changes

    • jax.numpy.array no longer accepts None. This behavior had been
      deprecated since November 2023 and has now been removed.
    • Removed the config.jax_data_dependent_tracing_fallback config option,
      which was added temporarily in v0.4.36 to allow users to opt out of the
      new "stackless" tracing machinery.
    • Removed the config.jax_eager_pmap config option.
    • The lower and trace AOT APIs can no longer be called on the result
      of jax.jit if other wrappers have subsequently been applied to it.
      Previously this worked, but silently ignored the wrappers.
      The workaround is to apply jax.jit last among the wrappers,
      and similarly for jax.pmap.
      See #27873.
    • The cuda12_pip extra for jax has been removed; use pip install jax[cuda12]
      instead.
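
As a rough illustration of the jax.jit workaround above, the sketch below applies jax.jit last, on top of a hypothetical clip_to_one decorator (an illustrative name, not part of JAX), so that the AOT lower API sees the fully wrapped function rather than ignoring the wrapper.

```python
import functools

import jax
import jax.numpy as jnp


def clip_to_one(fn):
    """Hypothetical wrapper (not a JAX API): caps the output of fn at 1.0."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return jnp.minimum(fn(*args, **kwargs), 1.0)
    return wrapper


def scale(x):
    return 2.0 * x


# jax.jit is applied last, so the wrapper is part of what gets traced.
jitted = jax.jit(clip_to_one(scale))

# AOT path: lower for a concrete input, compile, then run.
x = jnp.ones(3)
compiled = jitted.lower(x).compile()
result = compiled(x)
```

Because the wrapper sits inside the jitted function, the compiled result reflects the clipping as well as the scaling.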
  • Changes

    • The minimum cuDNN version is now v9.8.
    • JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
      supported.
    • JAX package extras now use dashes instead of underscores, to align
      with PEP 685. For instance, if you were previously using
      pip install jax[cuda12_local] to install JAX, run
      pip install jax[cuda12-local] instead.
    • jax.jit now requires fun to be passed by position, and additional
      arguments to be passed by keyword. Doing otherwise will result in a
      DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
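
A minimal sketch of the jax.jit calling convention described above, assuming a hypothetical function named power: the function goes first by position, and every other option (here static_argnames) is given by keyword.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    # n is treated as a static Python value, so it can drive the exponent.
    return x ** n


# fun by position, all remaining jit options by keyword.
fast_power = jax.jit(power, static_argnames=("n",))

out = fast_power(jnp.arange(4.0), n=2)
```

Code that already passes only fun positionally should need no change.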
  • Deprecations

    • jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten
      instead.
    • Implemented host callback handlers for CPU and GPU devices using XLA's FFI
      and removed existing CPU/GPU handlers using XLA's custom call.
    • All APIs in jax.lib.xla_extension are now deprecated.
    • jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect,
      which were accidental exports, have been removed. If needed, they are
      available from jax.extend.mlir.
    • jax.interpreters.mlir.custom_call is deprecated. The APIs provided by
      jax.ffi should be used instead.
    • The deprecated use of jax.ffi.ffi_call with inline arguments is no
      longer supported. jax.ffi.ffi_call now unconditionally returns a
      callable.
    • The following exports in jax.lib.xla_client are deprecated:
      get_topology_for_devices, heap_profile, mlir_api_version, Client,
      CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding,
      Traceback.
    • The following internal APIs in jax.util are deprecated:
      HashableFunction, as_hashable_function, cache, safe_map, safe_zip,
      split_dict, split_list, split_list_checked, split_merge, subvals,
      toposort, unzip2, wrap_name, and wraps.
    • jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX
      Array directly to the from_dlpack function of another framework. If you
      need the functionality of to_dlpack, use the __dlpack__ attribute of an
      array.
    • jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and
      jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
    • Several previously-deprecated APIs have been removed, including:
      • From jax.lib.xla_client: ArrayImpl, FftType, PaddingType,
        PrimitiveType, XlaBuilder, dtype_to_etype,
        ops, register_custom_call_target, shape_from_pyval, Shape,
        XlaComputation.
      • From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
      • From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map,
        jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and
        jax.tree_unflatten. Replacements can be found in jax.tree or
        jax.tree_util.
      • From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType,
        Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx,
        Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval,
        dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower,
        get_referent, jaxpr_as_fun, join_effects, lattice_join,
        leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped,
        raise_to_shaped_mappings, reset_trace_state, str_eqn_compact,
        substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most
        have no public replacement, though a few are available at jax.extend.core.
      • The vectorized argument to jax.pure_callback and
        jax.ffi.ffi_call. Use the vmap_method parameter instead.