Skip to content

Commit 6e23c14

Browse files
superbobryjax authors
authored andcommitted
jax.debug.callback now passes arguments as jax.Arrays
Prior to this change the behavior in eager and under jax.jit was inconsistent >>> (lambda *args: jax.debug.callback(print, *args))([42]) [42] >>> jax.jit(lambda *args: jax.debug.callback(print, *args))([42]) [array(42, dtype=int32)] It was also inconsistent with other callback APIs, which cast the arguments to jax.Arrays. Closes #20627. PiperOrigin-RevId: 626461904
1 parent 32922f6 commit 6e23c14

File tree

4 files changed

+28
-16
lines changed

4 files changed

+28
-16
lines changed

CHANGELOG.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ Remember to align the itemized text with the first line of an item within a list
1414
adopted by NumPy.
1515

1616
* Changes
17-
* {func}`jax.pure_callback` and {func}`jax.experimental.io_callback`
18-
now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover
19-
the old behavior by transforming the arguments via
20-
`jax.tree.map(np.asarray, args)` before passing them to the callback.
17+
* {func}`jax.pure_callback`, {func}`jax.experimental.io_callback`
18+
and {func}`jax.debug.callback` now use {class}`jax.Array` instead
19+
of {class}`np.ndarray`. You can recover the old behavior by transforming
20+
the arguments via `jax.tree.map(np.asarray, args)` before passing them
21+
to the callback.
2122
* `complex_arr.astype(bool)` now follows the same semantics as NumPy, returning
2223
False where `complex_arr` is equal to `0 + 0j`, and True otherwise.
2324
* Async dispatch expensive computations on the CPU backend. This only applies

jax/_src/debugging.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
import importlib.util
1919
from collections.abc import Sequence
2020
import functools
21+
import logging
2122
import string
2223
import sys
2324
from typing import Any, Callable, Union
2425
import weakref
2526

2627
import numpy as np
2728

29+
import jax
2830
import jax.numpy as jnp
2931
from jax import lax
3032

@@ -45,6 +47,8 @@
4547
from jax._src.sharding import Sharding
4648
from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding
4749

50+
logger = logging.getLogger(__name__)
51+
4852
class DebugEffect(effects.Effect):
4953
__str__ = lambda self: "Debug"
5054
debug_effect = DebugEffect()
@@ -73,7 +77,14 @@ class OrderedDebugEffect(effects.Effect):
7377
def debug_callback_impl(*args, callback: Callable[..., Any],
7478
effect: DebugEffect):
7579
del effect
76-
callback(*args)
80+
cpu_device, *_ = jax.local_devices(backend="cpu")
81+
args = jax.device_put(args, cpu_device)
82+
with jax.default_device(cpu_device):
83+
try:
84+
callback(*args)
85+
except BaseException:
86+
logger.exception("jax.debug_callback failed")
87+
raise
7788
return ()
7889

7990
@debug_callback_p.def_effectful_abstract_eval

tests/debugger_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def f(x):
110110
return y
111111
expected = _format_multiline(r"""
112112
Entering jdb:
113-
(jdb) array(2., dtype=float32)
113+
(jdb) Array(2., dtype=float32)
114114
(jdb) """)
115115
f(jnp.array(2., jnp.float32))
116116
jax.effects_barrier()
@@ -126,7 +126,7 @@ def f(x):
126126
return y
127127
expected = _format_multiline(r"""
128128
Entering jdb:
129-
(jdb) (array(2., dtype=float32), array(3., dtype=float32))
129+
(jdb) (Array(2., dtype=float32), Array(3., dtype=float32))
130130
(jdb) """)
131131
f(jnp.array(2., jnp.float32))
132132
jax.effects_barrier()
@@ -196,7 +196,7 @@ def g\(x\):
196196
-> y = f\(x\)
197197
return jnp\.exp\(y\)
198198
.*
199-
\(jdb\) array\(2\., dtype=float32\)
199+
\(jdb\) Array\(2\., dtype=float32\)
200200
\(jdb\) > .*debugger_test\.py\([0-9]+\)
201201
def f\(x\):
202202
y = jnp\.sin\(x\)
@@ -225,9 +225,9 @@ def g(x):
225225
return jnp.exp(y)
226226
expected = _format_multiline(r"""
227227
Entering jdb:
228-
(jdb) array(3., dtype=float32)
228+
(jdb) Array(3., dtype=float32)
229229
(jdb) Entering jdb:
230-
(jdb) array(6., dtype=float32)
230+
(jdb) Array(6., dtype=float32)
231231
(jdb) """)
232232
g(jnp.array(2., jnp.float32))
233233
jax.effects_barrier()
@@ -249,9 +249,9 @@ def g(x):
249249
return jnp.exp(y)
250250
expected = _format_multiline(r"""
251251
Entering jdb:
252-
(jdb) array(1., dtype=float32)
252+
(jdb) Array(1., dtype=float32)
253253
(jdb) Entering jdb:
254-
(jdb) array(2., dtype=float32)
254+
(jdb) Array(2., dtype=float32)
255255
(jdb) """)
256256
g(jnp.arange(2., dtype=jnp.float32))
257257
jax.effects_barrier()
@@ -274,9 +274,9 @@ def g(x):
274274
return jnp.exp(y)
275275
expected = _format_multiline(r"""
276276
Entering jdb:
277-
\(jdb\) array\(.*, dtype=float32\)
277+
\(jdb\) Array\(.*, dtype=float32\)
278278
\(jdb\) Entering jdb:
279-
\(jdb\) array\(.*, dtype=float32\)
279+
\(jdb\) Array\(.*, dtype=float32\)
280280
\(jdb\) """)
281281
g(jnp.arange(2., dtype=jnp.float32))
282282
jax.effects_barrier()
@@ -302,7 +302,7 @@ def g(x):
302302
out_shardings=jax.sharding.PartitionSpec("dev"),
303303
)
304304
with jax.sharding.Mesh(np.array(jax.devices()), ["dev"]):
305-
arr = (1 + np.arange(8)).astype(np.int32)
305+
arr = (1 + jnp.arange(8)).astype(np.int32)
306306
expected = _format_multiline(r"""
307307
Entering jdb:
308308
\(jdb\) {}

tests/debugging_primitives_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def f(x):
170170
with jtu.capture_stdout() as output:
171171
f(np.array(2, np.int32))
172172
jax.effects_barrier()
173-
self.assertEqual(output(), f"x: {str(dict(foo=np.array(2, np.int32)))}\n")
173+
self.assertEqual(output(), f"x: {str(dict(foo=jnp.array(2, np.int32)))}\n")
174174

175175
def test_debug_print_should_use_default_layout(self):
176176
data = np.array(

0 commit comments

Comments
 (0)