17
17
The host_callback APIs are deprecated as of March 20, 2024.
18
18
The functionality is subsumed by the
19
19
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
20
+ See https://github.com/google/jax/issues/20385.
20
21
21
22
This module introduces the host callback functions :func:`call`,
22
23
:func:`id_tap`, and :func:`id_print`, that send their arguments from the device
@@ -501,6 +502,7 @@ def power3_with_cotangents(x):
501
502
from __future__ import annotations
502
503
503
504
import atexit
505
+ import enum
504
506
from collections .abc import Sequence
505
507
import functools
506
508
import itertools
@@ -510,13 +512,15 @@ def power3_with_cotangents(x):
510
512
import traceback
511
513
from typing import Any , Callable , cast
512
514
515
+ import jax
513
516
from jax ._src import api
514
517
from jax ._src import core
515
518
from jax ._src import config
516
519
from jax import custom_derivatives
517
520
from jax ._src import dtypes
518
521
from jax import lax
519
522
from jax .experimental import pjit
523
+ from jax .experimental import io_callback
520
524
from jax ._src .interpreters import ad , batching , pxla
521
525
from jax ._src .interpreters import mlir
522
526
from jax ._src .interpreters import partial_eval as pe
@@ -560,6 +564,15 @@ def power3_with_cotangents(x):
560
564
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
561
565
)
562
566
)
567
+ _HOST_CALLBACK_LEGACY = config .DEFINE_bool (
568
+ 'jax_host_callback_legacy' ,
569
+ config .bool_env ('JAX_HOST_CALLBACK_LEGACY' , True ),
570
+ help = (
571
+ 'Use old implementation of host_callback, documented in the module docstring.'
572
+ 'If False, use the jax.experimental.io_callback implementation. '
573
+ 'See https://github.com/google/jax/issues/20385.'
574
+ )
575
+ )
563
576
564
577
logger = logging .getLogger (__name__ )
565
578
@@ -591,20 +604,31 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
591
604
XlaLocalClient = xla_client .Client
592
605
DType = Any
593
606
607
+ class CallbackFlavor (enum .Enum ):
608
+ """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False.
609
+
610
+ See https://github.com/google/jax/issues/20385.
611
+ """
612
+ IO_CALLBACK = 1 # uses jax.experimental.io_callback
613
+ PURE = 2 # uses jax.pure_callback
614
+ DEBUG = 3 # uses jax.debug.callback, valid only when there are no results
615
+
594
616
595
617
def _deprecated_id_tap (tap_func ,
596
618
arg ,
597
619
* ,
598
620
result = None ,
599
621
tap_with_device = False ,
600
622
device_index = 0 ,
623
+ callback_flavor = CallbackFlavor .IO_CALLBACK ,
601
624
** kwargs ):
602
625
"""Host-callback tap primitive, like identity function with a call to ``tap_func``.
603
626
604
627
.. warning::
605
628
The host_callback APIs are deprecated as of March 20, 2024.
606
629
The functionality is subsumed by the
607
630
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
631
+ See https://github.com/google/jax/issues/20385.
608
632
609
633
``id_tap`` behaves semantically like the identity function but has the
610
634
side-effect that a user-defined Python function is called with the runtime
@@ -628,6 +652,9 @@ def _deprecated_id_tap(tap_func,
628
652
device_index: specifies from which device the tap function is invoked in a
629
653
SPMD program. Works only when using the outfeed implementation mechanism,
630
654
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
655
+ callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
656
+ the flavor of callback to use.
657
+ See https://github.com/google/jax/issues/20385.
631
658
632
659
Returns:
633
660
``arg``, or ``result`` if given.
@@ -660,7 +687,8 @@ def _deprecated_id_tap(tap_func,
660
687
call_with_device = tap_with_device ,
661
688
result_shape = None ,
662
689
identity = True ,
663
- device_index = device_index )
690
+ device_index = device_index ,
691
+ callback_flavor = callback_flavor )
664
692
665
693
if result is not None :
666
694
return result
@@ -675,13 +703,15 @@ def _deprecated_id_print(arg,
675
703
device_index = 0 ,
676
704
output_stream = None ,
677
705
threshold = None ,
706
+ callback_flavor = CallbackFlavor .IO_CALLBACK ,
678
707
** kwargs ):
679
708
"""Like :func:`id_tap` with a printing tap function.
680
709
681
710
.. warning::
682
711
The host_callback APIs are deprecated as of March 20, 2024.
683
712
The functionality is subsumed by the
684
713
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
714
+ See https://github.com/google/jax/issues/20385.
685
715
686
716
On each invocation of the printing tap, the ``kwargs`` if present
687
717
will be printed first (sorted by keys). Then arg will be printed,
@@ -697,6 +727,9 @@ def _deprecated_id_print(arg,
697
727
built-in ``print``. The string will be passed as
698
728
``output_stream.write(s)``.
699
729
* ``threshold`` is passed to ``numpy.array2string``.
730
+ * ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
731
+ the flavor of callback to use.
732
+ See https://github.com/google/jax/issues/20385.
700
733
701
734
For more details see the :mod:`jax.experimental.host_callback` module documentation.
702
735
"""
@@ -708,19 +741,22 @@ def _deprecated_id_print(arg,
708
741
arg ,
709
742
result = result ,
710
743
tap_with_device = tap_with_device ,
711
- device_index = device_index )
744
+ device_index = device_index ,
745
+ callback_flavor = callback_flavor )
712
746
713
747
714
748
def _deprecated_call (callback_func : Callable , arg , * ,
715
749
result_shape = None ,
716
750
call_with_device = False ,
717
- device_index = 0 ):
751
+ device_index = 0 ,
752
+ callback_flavor = CallbackFlavor .IO_CALLBACK ):
718
753
"""Make a call to the host, and expect a result.
719
754
720
755
.. warning::
721
756
The host_callback APIs are deprecated as of March 20, 2024.
722
757
The functionality is subsumed by the
723
758
`new JAX external callbacks <https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html>`_
759
+ See https://github.com/google/jax/issues/20385.
724
760
725
761
Args:
726
762
callback_func: The Python function to invoke on the host as
@@ -748,14 +784,26 @@ def _deprecated_call(callback_func: Callable, arg, *,
748
784
device_index: specifies from which device the tap function is invoked in a
749
785
SPMD program. Works only when using the outfeed implementation mechanism,
750
786
i.e., does not work on CPU unless --jax_host_callback_outfeed=True.
787
+ callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies
788
+ the flavor of callback to use.
789
+ See https://github.com/google/jax/issues/20385.
790
+
751
791
Returns:
752
792
the result of the ``callback_func`` invocation.
753
793
754
794
For more details see the :mod:`jax.experimental.host_callback` module documentation.
755
795
"""
796
+ if (not _HOST_CALLBACK_LEGACY .value and
797
+ callback_flavor is CallbackFlavor .DEBUG and
798
+ result_shape is not None ):
799
+ raise NotImplementedError (
800
+ "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` "
801
+ "flavor of callback only when the `result_shape` is None. "
802
+ "See https://github.com/google/jax/issues/20385."
803
+ )
756
804
return _call (callback_func , arg , result_shape = result_shape ,
757
805
call_with_device = call_with_device , identity = False ,
758
- device_index = device_index )
806
+ device_index = device_index , callback_flavor = callback_flavor )
759
807
760
808
761
809
# We need the wrapper function to have hash and equality defined since it is
@@ -766,6 +814,11 @@ def __init__(self, callback_func, identity, call_with_device):
766
814
self .callback_func = callback_func
767
815
self .identity = identity
768
816
self .call_with_device = call_with_device
817
+ if not _HOST_CALLBACK_LEGACY .value and call_with_device :
818
+ raise NotImplementedError (
819
+ "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs"
820
+ " do not support `tap_with_device` and `call_with_device`. "
821
+ "See https://github.com/google/jax/issues/20385." )
769
822
770
823
def __hash__ (self ):
771
824
return hash ((self .callback_func , self .identity , self .call_with_device ))
@@ -775,7 +828,16 @@ def __eq__(self, other):
775
828
self .identity == other .identity and
776
829
self .call_with_device == other .call_with_device )
777
830
778
- def __call__ (self , arg , device , transforms ):
831
+ def __call__ (self , * args , ** kwargs ):
832
+ if _HOST_CALLBACK_LEGACY .value :
833
+ return self ._call_legacy (* args , ** kwargs )
834
+ else :
835
+ if self .identity :
836
+ # For id_tap, we pass empty transforms, for backwards compatibility
837
+ return self .callback_func (args [0 ], ())
838
+ return self .callback_func (* args , ** kwargs )
839
+
840
+ def _call_legacy (self , arg , device , transforms ):
779
841
if self .identity :
780
842
# For id_tap, we pass the transforms, for backwards compatibility
781
843
if self .call_with_device :
@@ -797,14 +859,16 @@ def _call(callback_func: Callable,
797
859
result_shape = None ,
798
860
call_with_device = False ,
799
861
device_index = 0 ,
800
- identity = False ):
801
- # Lazy initialization
802
- _initialize_outfeed_receiver (
803
- max_callback_queue_size_bytes = _HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE .value )
862
+ identity = False ,
863
+ callback_flavor = CallbackFlavor .IO_CALLBACK ):
864
+ if _HOST_CALLBACK_LEGACY .value :
865
+ # Lazy initialization
866
+ _initialize_outfeed_receiver (
867
+ max_callback_queue_size_bytes = _HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE .value )
804
868
api .check_callable (callback_func )
805
869
flat_args , arg_treedef = tree_util .tree_flatten (arg )
806
- for arg in flat_args :
807
- dispatch .check_arg (arg )
870
+ for arg_ in flat_args :
871
+ dispatch .check_arg (arg_ )
808
872
# See definition of outside_call_p for what parameters it takes
809
873
params : dict [str , Any ] = {}
810
874
# TODO: wrap function
@@ -829,8 +893,27 @@ def _call(callback_func: Callable,
829
893
830
894
params ["result_treedef" ] = result_treedef
831
895
params ["flat_results_aval" ] = tuple (flat_results_aval )
832
- flat_results = outside_call_p .bind (* flat_args , ** params )
833
- return result_treedef .unflatten (flat_results ) if not identity else arg_treedef .unflatten (flat_results )
896
+
897
+ if _HOST_CALLBACK_LEGACY .value :
898
+ flat_results = outside_call_p .bind (* flat_args , ** params )
899
+ return result_treedef .unflatten (flat_results ) if not identity else arg_treedef .unflatten (flat_results )
900
+ else :
901
+ callback_device = jax .local_devices ()[device_index ]
902
+ sharding = jax .sharding .SingleDeviceSharding (callback_device )
903
+ callback_func = _CallbackWrapper (callback_func , identity ,
904
+ call_with_device )
905
+ if callback_flavor is CallbackFlavor .DEBUG :
906
+ assert identity
907
+ jax .debug .callback (callback_func , arg )
908
+ return arg
909
+ elif callback_flavor is CallbackFlavor .PURE :
910
+ call_res = jax .pure_callback (callback_func , result_shape , arg ,
911
+ sharding = sharding )
912
+ else :
913
+ call_res = io_callback (callback_func , result_shape , arg ,
914
+ sharding = sharding ,
915
+ ordered = True )
916
+ return call_res if not identity else arg
834
917
835
918
836
919
# We need the lock for when we use the CustomCall implementation of callbacks.
@@ -855,7 +938,6 @@ def _print_tap_func(
855
938
threshold: the value of numpy.array2string threshold parameter.
856
939
**kwargs: all other keyword args are printed before printing `arg`.
857
940
"""
858
-
859
941
def emit_str (s : str ):
860
942
if output_stream is not None :
861
943
output_stream .write (s + "\n " )
@@ -1844,6 +1926,10 @@ def _deprecated_barrier_wait(logging_name: str | None = None):
1844
1926
1845
1927
For more details see the :mod:`jax.experimental.host_callback` module documentation.
1846
1928
"""
1929
+ if not _HOST_CALLBACK_LEGACY .value :
1930
+ jax .effects_barrier ()
1931
+ return
1932
+
1847
1933
logging_name = logging_name or ""
1848
1934
logger .debug ("barrier_wait[%s]: start" , logging_name )
1849
1935
@@ -1907,7 +1993,7 @@ def _deprecated_stop_outfeed_receiver():
1907
1993
_deprecation_msg = (
1908
1994
"The host_callback APIs are deprecated as of March 20, 2024. The functionality "
1909
1995
"is subsumed by the new JAX external callbacks. "
1910
- "See https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html ." )
1996
+ "See https://github.com/google/jax/issues/20385 ." )
1911
1997
1912
1998
_deprecations = {
1913
1999
# Added March 20, 2024
0 commit comments