Skip to content

Commit a510f03

Browse files
committed
[callback] Add a flag to implement host_callback in terms of io_callback.
The host_callbacks APIs are deprecated and will be removed. In order to help the transition to the new APIs, we add a flag (`JAX_HOST_CALLBACK_LEGACY`) that when set to `False` will use `io_callback` (and `pure_callback` and `jax.debug.callback`) to implement the host_callback APIs. See issue #20385 for more details. We change the tests to accomodate slightly different results when using the new callbacks. The tests that use `tap_with_device` and `call_with_device` are disabled when using the new callbacks.
1 parent 2512843 commit a510f03

File tree

6 files changed

+477
-178
lines changed

6 files changed

+477
-178
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Remember to align the itemized text with the first line of an item within a list
5050
`spmd_axis_name` argument for expressing SPMD device-parallel computations.
5151
* The `jax.experimental.host_callback` module is deprecated.
5252
Use instead the [new JAX external callbacks](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html).
53+
Added `JAX_HOST_CALLBACK_LEGACY` flag to assist in the transition to the
54+
new callbacks. See {jax-issue}`#20385` for a discussion.
5355
* Passing arguments to {func}`jax.numpy.array_equal` and {func}`jax.numpy.array_equiv`
5456
that cannot be converted to a JAX array now results in an exception.
5557
* The deprecated flag `jax_parallel_functions_output_gda` has been removed.
@@ -1451,7 +1453,7 @@ Changes:
14511453
special autodiff handling for hcb.id_tap and id_print.
14521454
From now on, only the primals are tapped. The old behavior can be
14531455
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
1454-
environment variable, or the ```--flax_host_callback_ad_transforms``` flag.
1456+
environment variable, or the ```--jax_host_callback_ad_transforms``` flag.
14551457
Additionally, added documentation for how to implement the old behavior
14561458
using JAX custom AD APIs ({jax-issue}`#8678`).
14571459
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the

jax/BUILD

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -997,7 +997,11 @@ pytype_library(
997997

998998
pytype_library(
999999
name = "experimental_host_callback",
1000-
srcs = ["experimental/host_callback.py"],
1000+
srcs = [
1001+
"experimental/__init__.py", # To support JAX_HOST_CALLBACK_LEGACY=False
1002+
"experimental/host_callback.py",
1003+
"experimental/x64_context.py", # To support JAX_HOST_CALLBACK_LEGACY=False
1004+
],
10011005
visibility = ["//visibility:public"],
10021006
deps = [
10031007
":jax",

jax/experimental/host_callback.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
The host_callback APIs are deprecated as of March 20, 2024.
1818
The functionality is subsumed by the
1919
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
20+
See https://github.com/google/jax/issues/20385.
2021
2122
This module introduces the host callback functions :func:`call`,
2223
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@@ -501,6 +502,7 @@ def power3_with_cotangents(x):
501502
from __future__ import annotations
502503

503504
import atexit
505+
import enum
504506
from collections.abc import Sequence
505507
import functools
506508
import itertools
@@ -510,13 +512,15 @@ def power3_with_cotangents(x):
510512
import traceback
511513
from typing import Any, Callable, cast
512514

515+
import jax
513516
from jax._src import api
514517
from jax._src import core
515518
from jax._src import config
516519
from jax import custom_derivatives
517520
from jax._src import dtypes
518521
from jax import lax
519522
from jax.experimental import pjit
523+
from jax.experimental import io_callback
520524
from jax._src.interpreters import ad, batching, pxla
521525
from jax._src.interpreters import mlir
522526
from jax._src.interpreters import partial_eval as pe
@@ -560,6 +564,15 @@ def power3_with_cotangents(x):
560564
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
561565
)
562566
)
567+
_HOST_CALLBACK_LEGACY = config.DEFINE_bool(
568+
'jax_host_callback_legacy',
569+
config.bool_env('JAX_HOST_CALLBACK_LEGACY', True),
570+
help=(
571+
'Use old implementation of host_callback, documented in the module docstring.'
572+
'If False, use the jax.experimental.io_callback implementation. '
573+
'See https://github.com/google/jax/issues/20385.'
574+
)
575+
)
563576

564577
logger = logging.getLogger(__name__)
565578

@@ -591,20 +604,31 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
591604
XlaLocalClient = xla_client.Client
592605
DType = Any
593606

607+
class CallbackFlavor(enum.Enum):
608+
"""Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
609+
610+
See https://github.com/google/jax/issues/20385.
611+
"""
612+
IO_CALLBACK = 1 # uses jax.experimental.io_callback
613+
PURE = 2 # uses jax.pure_callback
614+
DEBUG = 3 # uses jax.debug.callback, valid only when there are no results
615+
594616

595617
def _deprecated_id_tap(tap_func,
596618
arg,
597619
*,
598620
result=None,
599621
tap_with_device=False,
600622
device_index=0,
623+
callback_flavor=CallbackFlavor.IO_CALLBACK,
601624
**kwargs):
602625
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
603626
604627
.. warning::
605628
The host_callback APIs are deprecated as of March 20, 2024.
606629
The functionality is subsumed by the
607630
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
631+
See https://github.com/google/jax/issues/20385.
608632
609633
``id_tap`` behaves semantically like the identity function but has the
610634
side-effect that a user-defined Python function is called with the runtime
@@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func,
628652
device_index: specifies from which device the tap function is invoked in a
629653
SPMD program. Works only when using the outfeed implementation mechanism,
630654
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
655+
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
656+
the flavor of callback to use.
657+
See https://github.com/google/jax/issues/20385.
631658
632659
Returns:
633660
``arg``, or ``result`` if given.
@@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func,
660687
call_with_device=tap_with_device,
661688
result_shape=None,
662689
identity=True,
663-
device_index=device_index)
690+
device_index=device_index,
691+
callback_flavor=callback_flavor)
664692

665693
if result is not None:
666694
return result
@@ -675,13 +703,15 @@ def _deprecated_id_print(arg,
675703
device_index=0,
676704
output_stream=None,
677705
threshold=None,
706+
callback_flavor=CallbackFlavor.IO_CALLBACK,
678707
**kwargs):
679708
"""Like :func:`id_tap` with a printing tap function.
680709
681710
.. warning::
682711
The host_callback APIs are deprecated as of March 20, 2024.
683712
The functionality is subsumed by the
684713
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
714+
See https://github.com/google/jax/issues/20385.
685715
686716
On each invocation of the printing tap, the ``kwargs`` if present
687717
will be printed first (sorted by keys). Then arg will be printed,
@@ -697,6 +727,9 @@ def _deprecated_id_print(arg,
697727
built-in ``print``. The string will be passed as
698728
``output_stream.write(s)``.
699729
* ``threshold`` is passed to ``numpy.array2string``.
730+
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
731+
the flavor of callback to use.
732+
See https://github.com/google/jax/issues/20385.
700733
701734
For more details see the :mod:`jax.experimental.host_callback` module documentation.
702735
"""
@@ -708,19 +741,22 @@ def _deprecated_id_print(arg,
708741
arg,
709742
result=result,
710743
tap_with_device=tap_with_device,
711-
device_index=device_index)
744+
device_index=device_index,
745+
callback_flavor=callback_flavor)
712746

713747

714748
def _deprecated_call(callback_func: Callable, arg, *,
715749
result_shape=None,
716750
call_with_device=False,
717-
device_index=0):
751+
device_index=0,
752+
callback_flavor=CallbackFlavor.IO_CALLBACK):
718753
"""Make a call to the host, and expect a result.
719754
720755
.. warning::
721756
The host_callback APIs are deprecated as of March 20, 2024.
722757
The functionality is subsumed by the
723758
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
759+
See https://github.com/google/jax/issues/20385.
724760
725761
Args:
726762
callback_func: The Python function to invoke on the host as
@@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
748784
device_index: specifies from which device the tap function is invoked in a
749785
SPMD program. Works only when using the outfeed implementation mechanism,
750786
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
787+
callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
788+
the flavor of callback to use.
789+
See https://github.com/google/jax/issues/20385.
790+
751791
Returns:
752792
the result of the ``callback_func`` invocation.
753793
754794
For more details see the :mod:`jax.experimental.host_callback` module documentation.
755795
"""
796+
if (not _HOST_CALLBACK_LEGACY.value and
797+
callback_flavor is CallbackFlavor.DEBUG and
798+
result_shape is not None):
799+
raise NotImplementedError(
800+
"When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
801+
"flavor of callback only when the `result_shape` is None. "
802+
"See https://github.com/google/jax/issues/20385."
803+
)
756804
return _call(callback_func, arg, result_shape=result_shape,
757805
call_with_device=call_with_device, identity=False,
758-
device_index=device_index)
806+
device_index=device_index, callback_flavor=callback_flavor)
759807

760808

761809
# We need the wrapper function to have hash and equality defined since it is
@@ -766,6 +814,11 @@ def __init__(self, callback_func, identity, call_with_device):
766814
self.callback_func = callback_func
767815
self.identity = identity
768816
self.call_with_device = call_with_device
817+
if not _HOST_CALLBACK_LEGACY.value and call_with_device:
818+
raise NotImplementedError(
819+
"When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
820+
" do not support `tap_with_device` and `call_with_device`. "
821+
"See https://github.com/google/jax/issues/20385.")
769822

770823
def __hash__(self):
771824
return hash((self.callback_func, self.identity, self.call_with_device))
@@ -775,7 +828,16 @@ def __eq__(self, other):
775828
self.identity == other.identity and
776829
self.call_with_device == other.call_with_device)
777830

778-
def __call__(self, arg, device, transforms):
831+
def __call__(self, *args, **kwargs):
832+
if _HOST_CALLBACK_LEGACY.value:
833+
return self._call_legacy(*args, **kwargs)
834+
else:
835+
if self.identity:
836+
# For id_tap, we pass empty transforms, for backwards compatibility
837+
return self.callback_func(args[0], ())
838+
return self.callback_func(*args, **kwargs)
839+
840+
def _call_legacy(self, arg, device, transforms):
779841
if self.identity:
780842
# For id_tap, we pass the transforms, for backwards compatibility
781843
if self.call_with_device:
@@ -797,14 +859,16 @@ def _call(callback_func: Callable,
797859
result_shape=None,
798860
call_with_device=False,
799861
device_index=0,
800-
identity=False):
801-
# Lazy initialization
802-
_initialize_outfeed_receiver(
803-
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
862+
identity=False,
863+
callback_flavor=CallbackFlavor.IO_CALLBACK):
864+
if _HOST_CALLBACK_LEGACY.value:
865+
# Lazy initialization
866+
_initialize_outfeed_receiver(
867+
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
804868
api.check_callable(callback_func)
805869
flat_args, arg_treedef = tree_util.tree_flatten(arg)
806-
for arg in flat_args:
807-
dispatch.check_arg(arg)
870+
for arg_ in flat_args:
871+
dispatch.check_arg(arg_)
808872
# See definition of outside_call_p for what parameters it takes
809873
params: dict[str, Any] = {}
810874
# TODO: wrap function
@@ -829,8 +893,27 @@ def _call(callback_func: Callable,
829893

830894
params["result_treedef"] = result_treedef
831895
params["flat_results_aval"] = tuple(flat_results_aval)
832-
flat_results = outside_call_p.bind(*flat_args, **params)
833-
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
896+
897+
if _HOST_CALLBACK_LEGACY.value:
898+
flat_results = outside_call_p.bind(*flat_args, **params)
899+
return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
900+
else:
901+
callback_device = jax.local_devices()[device_index]
902+
sharding = jax.sharding.SingleDeviceSharding(callback_device)
903+
callback_func = _CallbackWrapper(callback_func, identity,
904+
call_with_device)
905+
if callback_flavor is CallbackFlavor.DEBUG:
906+
assert identity
907+
jax.debug.callback(callback_func, arg)
908+
return arg
909+
elif callback_flavor is CallbackFlavor.PURE:
910+
call_res = jax.pure_callback(callback_func, result_shape, arg,
911+
sharding=sharding)
912+
else:
913+
call_res = io_callback(callback_func, result_shape, arg,
914+
sharding=sharding,
915+
ordered=True)
916+
return call_res if not identity else arg
834917

835918

836919
# We need the lock for when we use the CustomCall implementation of callbacks.
@@ -855,7 +938,6 @@ def _print_tap_func(
855938
threshold: the value of numpy.array2string threshold parameter.
856939
**kwargs: all other keyword args are printed before printing `arg`.
857940
"""
858-
859941
def emit_str(s: str):
860942
if output_stream is not None:
861943
output_stream.write(s + "\n")
@@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
18441926
18451927
For more details see the :mod:`jax.experimental.host_callback` module documentation.
18461928
"""
1929+
if not _HOST_CALLBACK_LEGACY.value:
1930+
jax.effects_barrier()
1931+
return
1932+
18471933
logging_name = logging_name or ""
18481934
logger.debug("barrier_wait[%s]: start", logging_name)
18491935

@@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver():
19071993
_deprecation_msg = (
19081994
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
19091995
"is subsumed by the new JAX external callbacks. "
1910-
"See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.")
1996+
"See https://github.com/google/jax/issues/20385.")
19111997

19121998
_deprecations = {
19131999
# Added March 20, 2024

0 commit comments

Comments
 (0)