Skip to content

Commit 4c41c12

Browse files
author
jax authors
committed
Merge pull request #20514 from gnecula:callback_cache
PiperOrigin-RevId: 621160168
2 parents 431015a + bff24c6 commit 4c41c12

File tree

2 files changed

+41
-16
lines changed

2 files changed

+41
-16
lines changed

jax/_src/callback.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
"""Module for JAX callbacks."""
1515
from __future__ import annotations
1616

17+
import dataclasses
1718
from collections.abc import Sequence
1819
import logging
1920
import functools
2021
from typing import Any, Callable
2122

2223
import numpy as np
2324

25+
import jax
2426
from jax._src import core
2527
from jax._src import dispatch
2628
from jax._src import dtypes
@@ -46,10 +48,27 @@
4648
map, unsafe_map = util.safe_map, map
4749

4850

51+
@dataclasses.dataclass(frozen=True)
52+
class _FlatCallback:
53+
"""A Python function callable with flat arguments and results.
54+
55+
An instance of this class is used as a parameter for the callback primitives.
56+
We prefer it to an anonymous flattened function because it produces
57+
equal objects when we call the same Python function with the same argument
58+
structure.
59+
"""
60+
callback_func: Callable[..., Any]
61+
in_tree: tree_util.PyTreeDef # (args, kwargs) pytree for `callback_func`.
62+
63+
def __call__(self, *flat_args: jax.Array) -> Sequence[jax.Array]:
64+
args, kwargs = tree_util.tree_unflatten(self.in_tree, flat_args)
65+
return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
66+
67+
4968
def pure_callback_impl(
5069
*args,
5170
result_avals,
52-
callback: Callable[..., Any],
71+
callback: _FlatCallback,
5372
sharding: SingleDeviceSharding | None,
5473
vectorized: bool,
5574
):
@@ -68,7 +87,7 @@ def pure_callback_impl(
6887
@pure_callback_p.def_abstract_eval
6988
def pure_callback_abstract_eval(
7089
*avals,
71-
callback: Callable[..., Any],
90+
callback: _FlatCallback,
7291
result_avals,
7392
sharding: SingleDeviceSharding | None,
7493
vectorized: bool,
@@ -100,7 +119,7 @@ def pure_callback_batching_rule(
100119
args,
101120
dims,
102121
*,
103-
callback,
122+
callback: _FlatCallback,
104123
sharding: SingleDeviceSharding | None,
105124
vectorized: bool,
106125
result_avals: Sequence[core.ShapedArray],
@@ -193,7 +212,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None):
193212

194213

195214
def pure_callback_lowering(
196-
ctx, *args, callback, sharding: SingleDeviceSharding | None, **params
215+
ctx, *args, callback: _FlatCallback, sharding: SingleDeviceSharding | None, **params
197216
):
198217
def _callback(*flat_args):
199218
return tuple(
@@ -265,18 +284,14 @@ def pure_callback(
265284
266285
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
267286
"""
268-
def _flat_callback(*flat_args):
269-
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
270-
return tree_util.tree_leaves(callback(*args, **kwargs))
271-
272287
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
273288
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
274289
result_avals = tree_util.tree_map(
275290
lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)
276291
flat_result_avals, out_tree = tree_util.tree_flatten(result_avals)
277292
out_flat = pure_callback_p.bind(
278293
*flat_args,
279-
callback=_flat_callback,
294+
callback=_FlatCallback(callback, in_tree),
280295
result_avals=tuple(flat_result_avals),
281296
sharding=sharding,
282297
vectorized=vectorized,
@@ -378,7 +393,7 @@ class OrderedIOEffect(effects.Effect):
378393
def io_callback_impl(
379394
*args,
380395
result_avals,
381-
callback: Callable[..., Any],
396+
callback: _FlatCallback,
382397
sharding: SingleDeviceSharding | None,
383398
ordered: bool,
384399
):
@@ -397,7 +412,7 @@ def io_callback_impl(
397412
@io_callback_p.def_effectful_abstract_eval
398413
def io_callback_abstract_eval(
399414
*avals,
400-
callback: Callable[..., Any],
415+
callback: _FlatCallback,
401416
result_avals,
402417
sharding: SingleDeviceSharding | None,
403418
ordered: bool,
@@ -516,10 +531,6 @@ def io_callback(
516531
517532
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
518533
"""
519-
def _flat_callback(*flat_args):
520-
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
521-
return tree_util.tree_leaves(callback(*args, **kwargs))
522-
523534
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
524535
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
525536
flat_shape_dtypes, out_tree = tree_util.tree_flatten(result_shape_dtypes)
@@ -528,7 +539,7 @@ def _flat_callback(*flat_args):
528539
flat_args = map(core.raise_as_much_as_possible, flat_args)
529540
out_flat = io_callback_p.bind(
530541
*flat_args,
531-
callback=_flat_callback,
542+
callback=_FlatCallback(callback, in_tree),
532543
result_avals=tuple(flat_result_avals),
533544
sharding=sharding,
534545
ordered=ordered,

tests/python_callback_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,20 @@ def f(x):
586586
self.assertIn(f"jax.{api_name} failed", output)
587587
self.assertIn("Traceback (most recent call last)", output)
588588

589+
@with_pure_and_io_callbacks
590+
def test_compilation_caching(self, *, callback):
591+
def f_outside(x):
592+
return 2 * x
593+
594+
def fun(x):
595+
return callback(f_outside, x, x)
596+
597+
x = np.arange(6, dtype=np.int32).reshape((2, 3))
598+
with jtu.count_primitive_compiles() as count:
599+
for _ in range(3):
600+
self.assertAllClose(2 * x, fun(x))
601+
self.assertEqual(count[0], 1)
602+
589603

590604
class PureCallbackTest(jtu.JaxTestCase):
591605

0 commit comments

Comments
 (0)