14
14
"""Module for JAX callbacks."""
15
15
from __future__ import annotations
16
16
17
+ import dataclasses
17
18
from collections .abc import Sequence
18
19
import logging
19
20
import functools
20
21
from typing import Any , Callable
21
22
22
23
import numpy as np
23
24
25
+ import jax
24
26
from jax ._src import core
25
27
from jax ._src import dispatch
26
28
from jax ._src import dtypes
46
48
map , unsafe_map = util .safe_map , map
47
49
48
50
51
+ @dataclasses .dataclass (frozen = True )
52
+ class _FlatCallback :
53
+ """A Python function callable with flat arguments and results.
54
+
55
+ An instance of this class is used as a parameter for the callback primitives.
56
+ We prefer it to an anonymous flattened function because it produces
57
+ equal objects when we call the same Python function with the same argument
58
+ structure.
59
+ """
60
+ callback_func : Callable [..., Any ]
61
+ in_tree : tree_util .PyTreeDef # (args, kwargs) pytree for `callback_func`.
62
+
63
+ def __call__ (self , * flat_args : jax .Array ) -> Sequence [jax .Array ]:
64
+ args , kwargs = tree_util .tree_unflatten (self .in_tree , flat_args )
65
+ return tree_util .tree_leaves (self .callback_func (* args , ** kwargs ))
66
+
67
+
49
68
def pure_callback_impl (
50
69
* args ,
51
70
result_avals ,
52
- callback : Callable [..., Any ] ,
71
+ callback : _FlatCallback ,
53
72
sharding : SingleDeviceSharding | None ,
54
73
vectorized : bool ,
55
74
):
@@ -68,7 +87,7 @@ def pure_callback_impl(
68
87
@pure_callback_p .def_abstract_eval
69
88
def pure_callback_abstract_eval (
70
89
* avals ,
71
- callback : Callable [..., Any ] ,
90
+ callback : _FlatCallback ,
72
91
result_avals ,
73
92
sharding : SingleDeviceSharding | None ,
74
93
vectorized : bool ,
@@ -100,7 +119,7 @@ def pure_callback_batching_rule(
100
119
args ,
101
120
dims ,
102
121
* ,
103
- callback ,
122
+ callback : _FlatCallback ,
104
123
sharding : SingleDeviceSharding | None ,
105
124
vectorized : bool ,
106
125
result_avals : Sequence [core .ShapedArray ],
@@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
193
212
194
213
195
214
def pure_callback_lowering (
196
- ctx , * args , callback , sharding : SingleDeviceSharding | None , ** params
215
+ ctx , * args , callback : _FlatCallback , sharding : SingleDeviceSharding | None , ** params
197
216
):
198
217
def _callback (* flat_args ):
199
218
return tuple (
@@ -265,18 +284,14 @@ def pure_callback(
265
284
266
285
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
267
286
"""
268
- def _flat_callback (* flat_args ):
269
- args , kwargs = tree_util .tree_unflatten (in_tree , flat_args )
270
- return tree_util .tree_leaves (callback (* args , ** kwargs ))
271
-
272
287
flat_args , in_tree = tree_util .tree_flatten ((args , kwargs ))
273
288
tree_util .tree_map (_check_shape_dtype , result_shape_dtypes )
274
289
result_avals = tree_util .tree_map (
275
290
lambda x : core .ShapedArray (x .shape , x .dtype ), result_shape_dtypes )
276
291
flat_result_avals , out_tree = tree_util .tree_flatten (result_avals )
277
292
out_flat = pure_callback_p .bind (
278
293
* flat_args ,
279
- callback = _flat_callback ,
294
+ callback = _FlatCallback ( callback , in_tree ) ,
280
295
result_avals = tuple (flat_result_avals ),
281
296
sharding = sharding ,
282
297
vectorized = vectorized ,
@@ -378,7 +393,7 @@ class OrderedIOEffect(effects.Effect):
378
393
def io_callback_impl (
379
394
* args ,
380
395
result_avals ,
381
- callback : Callable [..., Any ] ,
396
+ callback : _FlatCallback ,
382
397
sharding : SingleDeviceSharding | None ,
383
398
ordered : bool ,
384
399
):
@@ -397,7 +412,7 @@ def io_callback_impl(
397
412
@io_callback_p .def_effectful_abstract_eval
398
413
def io_callback_abstract_eval (
399
414
* avals ,
400
- callback : Callable [..., Any ] ,
415
+ callback : _FlatCallback ,
401
416
result_avals ,
402
417
sharding : SingleDeviceSharding | None ,
403
418
ordered : bool ,
@@ -516,10 +531,6 @@ def io_callback(
516
531
517
532
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
518
533
"""
519
- def _flat_callback (* flat_args ):
520
- args , kwargs = tree_util .tree_unflatten (in_tree , flat_args )
521
- return tree_util .tree_leaves (callback (* args , ** kwargs ))
522
-
523
534
flat_args , in_tree = tree_util .tree_flatten ((args , kwargs ))
524
535
tree_util .tree_map (_check_shape_dtype , result_shape_dtypes )
525
536
flat_shape_dtypes , out_tree = tree_util .tree_flatten (result_shape_dtypes )
@@ -528,7 +539,7 @@ def _flat_callback(*flat_args):
528
539
flat_args = map (core .raise_as_much_as_possible , flat_args )
529
540
out_flat = io_callback_p .bind (
530
541
* flat_args ,
531
- callback = _flat_callback ,
542
+ callback = _FlatCallback ( callback , in_tree ) ,
532
543
result_avals = tuple (flat_result_avals ),
533
544
sharding = sharding ,
534
545
ordered = ordered ,
0 commit comments