Skip to content

Commit 32922f6

Browse files
superbobryjax authors
authored andcommitted
jax.debug.callback now requires a Callable[..., None]
This makes the "return value is ignored" behavior explicit in the type. PiperOrigin-RevId: 626430448
1 parent b2375fa commit 32922f6

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

jax/_src/debugging.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ class OrderedDebugEffect(effects.Effect):
7373
def debug_callback_impl(*args, callback: Callable[..., Any],
7474
effect: DebugEffect):
7575
del effect
76-
return callback(*args)
76+
callback(*args)
77+
return ()
7778

7879
@debug_callback_p.def_effectful_abstract_eval
7980
def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
@@ -136,13 +137,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
136137
sharding = None
137138

138139
def _callback(*flat_args):
139-
return tuple(
140-
debug_callback_p.impl(
141-
*flat_args, effect=effect, callback=callback, **params))
140+
debug_callback_p.impl(
141+
*flat_args, effect=effect, callback=callback, **params)
142+
return ()
142143
if effects.ordered_effects.contains(effect):
143144
token = ctx.tokens_in.get(effect)[0]
144145
result, token, _ = mlir.emit_python_callback(
145-
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, True)
146+
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, has_side_effect=True)
146147
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
147148
else:
148149
result, token, _ = mlir.emit_python_callback(
@@ -187,7 +188,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn):
187188
pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = (
188189
_debug_callback_partial_eval_custom)
189190

190-
def debug_callback(callback: Callable[..., Any], *args: Any,
191+
def debug_callback(callback: Callable[..., None], *args: Any,
191192
ordered: bool = False, **kwargs: Any) -> None:
192193
"""Calls a stageable Python callback.
193194
@@ -206,7 +207,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
206207
of the computation are duplicated or dropped.
207208
208209
Args:
209-
callback: A Python callable. Its return value will be ignored.
210+
callback: A Python callable returning None.
210211
*args: The positional arguments to the callback.
211212
ordered: A keyword only argument used to indicate whether or not the
212213
staged out computation will enforce ordering of this callback w.r.t.
@@ -231,7 +232,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
231232
def _flat_callback(*flat_args):
232233
args, kwargs = tree_util.tree_unflatten(in_tree, flat_args)
233234
callback(*args, **kwargs)
234-
return []
235+
return ()
235236
debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect)
236237

237238
class _DebugPrintFormatChecker(string.Formatter):

0 commit comments

Comments
 (0)