@@ -73,7 +73,8 @@ class OrderedDebugEffect(effects.Effect):
73
73
def debug_callback_impl (* args , callback : Callable [..., Any ],
74
74
effect : DebugEffect ):
75
75
del effect
76
- return callback (* args )
76
+ callback (* args )
77
+ return ()
77
78
78
79
@debug_callback_p .def_effectful_abstract_eval
79
80
def debug_callback_abstract_eval (* flat_avals , callback : Callable [..., Any ],
@@ -136,13 +137,13 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
136
137
sharding = None
137
138
138
139
def _callback (* flat_args ):
139
- return tuple (
140
- debug_callback_p . impl (
141
- * flat_args , effect = effect , callback = callback , ** params ) )
140
+ debug_callback_p . impl (
141
+ * flat_args , effect = effect , callback = callback , ** params )
142
+ return ( )
142
143
if effects .ordered_effects .contains (effect ):
143
144
token = ctx .tokens_in .get (effect )[0 ]
144
145
result , token , _ = mlir .emit_python_callback (
145
- ctx , _callback , token , list (args ), ctx .avals_in , ctx .avals_out , True )
146
+ ctx , _callback , token , list (args ), ctx .avals_in , ctx .avals_out , has_side_effect = True )
146
147
ctx .set_tokens_out (mlir .TokenSet ({effect : (token ,)}))
147
148
else :
148
149
result , token , _ = mlir .emit_python_callback (
@@ -187,7 +188,7 @@ def _debug_callback_partial_eval_custom(saveable, unks_in, inst_in, eqn):
187
188
pe .partial_eval_jaxpr_custom_rules [debug_callback_p ] = (
188
189
_debug_callback_partial_eval_custom )
189
190
190
- def debug_callback (callback : Callable [..., Any ], * args : Any ,
191
+ def debug_callback (callback : Callable [..., None ], * args : Any ,
191
192
ordered : bool = False , ** kwargs : Any ) -> None :
192
193
"""Calls a stageable Python callback.
193
194
@@ -206,7 +207,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
206
207
of the computation are duplicated or dropped.
207
208
208
209
Args:
209
- callback: A Python callable. Its return value will be ignored .
210
+ callback: A Python callable returning None .
210
211
*args: The positional arguments to the callback.
211
212
ordered: A keyword only argument used to indicate whether or not the
212
213
staged out computation will enforce ordering of this callback w.r.t.
@@ -231,7 +232,7 @@ def debug_callback(callback: Callable[..., Any], *args: Any,
231
232
def _flat_callback (* flat_args ):
232
233
args , kwargs = tree_util .tree_unflatten (in_tree , flat_args )
233
234
callback (* args , ** kwargs )
234
- return []
235
+ return ()
235
236
debug_callback_p .bind (* flat_args , callback = _flat_callback , effect = effect )
236
237
237
238
class _DebugPrintFormatChecker (string .Formatter ):
0 commit comments