Skip to content

Commit d7badca

Browse files
Merge pull request #30062 from jax-ml:auto-dce-dynamic-jaxpr-trace
PiperOrigin-RevId: 781744820
2 parents 7a06933 + 8980220 commit d7badca

File tree

22 files changed

+309
-166
lines changed

22 files changed

+309
-166
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,14 @@ def f_(*args):
448448
return f(*args, **kwargs)
449449

450450
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
451-
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
451+
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args),
452452
return_shape=True)(*in_leaves)
453453
assert isinstance(out, tuple)
454-
jaxpr_, out_shape = out
454+
jaxpr_, out_shape_ = out
455455
jaxpr = jaxpr_.jaxpr
456+
out_shape = out_shape_[1]
457+
num_res = tree_structure(out_shape).num_leaves
458+
jaxpr = jaxpr.replace(outvars=jaxpr.outvars[len(jaxpr.outvars) - num_res:])
456459
out_tree = lambda: tree_structure(out_shape)
457460
assert len(jaxpr.invars) == len(in_leaves)
458461
return _saved_residuals(jaxpr, debug_info.arg_names)

jax/_src/checkify.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def _reduce_any_error(error: Error):
472472
## check_p primitive
473473

474474
check_p = core.Primitive('check')
475+
check_p.is_effectful = lambda _: True # type: ignore
475476
check_p.multiple_results = True # zero results
476477

477478

@@ -841,12 +842,12 @@ def new_body_f(*c_consts_and_vals):
841842
c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
842843
out = body_f(*vals)
843844
# This checks if the next cond application will error
844-
_ = cond_f(*c_consts, *out)
845+
lax.dce_sink(cond_f(*c_consts, *out))
845846
return out
846847
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
847848
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
848-
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
849-
*body_jaxpr.in_avals])
849+
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
850+
new_body_f_, [*c_consts_avals, *body_jaxpr.in_avals])
850851
closed_jaxpr = pe.close_jaxpr(jaxpr)
851852
err_vals, err_tree = jtu.tree_flatten(error)
852853
err_vals = map(core.get_aval, err_vals)
@@ -1232,9 +1233,8 @@ def checked_fun(*args, **kwargs):
12321233
closed_f = lambda: f(*args, **kwargs)
12331234
# stage:
12341235
debug = api_util.debug_info("checkify", f, args, kwargs)
1235-
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
1236-
debug_info=debug),
1237-
in_tree)
1236+
fun_, out_tree = api_util.flatten_fun(
1237+
lu.wrap_init(closed_f, debug_info=debug), in_tree)
12381238
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, ())
12391239
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
12401240
# checkify:

jax/_src/core.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,10 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval):
605605
self.abstract_eval = effectful_abstract_eval
606606
return effectful_abstract_eval
607607

608+
def def_effectful_abstract_eval2(self, abstract_eval):
609+
self.abstract_eval = _generic_effectful_abstract_eval(abstract_eval)
610+
return abstract_eval
611+
608612
def def_bind_with_trace(self, bind_with_trace):
609613
self.bind_with_trace = bind_with_trace
610614
return bind_with_trace
@@ -629,6 +633,19 @@ def abstract_eval_(*args, **kwargs):
629633
return abstract_eval(*args, **kwargs), no_effects
630634
return abstract_eval_
631635

636+
class GenericEffect(Effect):
637+
pass
638+
generic_effect = GenericEffect()
639+
generic_effect_set = {generic_effect}
640+
effects.lowerable_effects.add_type(GenericEffect)
641+
effects.control_flow_allowed_effects.add_type(GenericEffect)
642+
effects.custom_derivatives_allowed_effects.add_type(GenericEffect)
643+
644+
def _generic_effectful_abstract_eval(abstract_eval):
645+
def abstract_eval_(*args, **kwargs):
646+
return abstract_eval(*args, **kwargs), generic_effect_set
647+
return abstract_eval_
648+
632649
# -------------------- lifting --------------------
633650

634651
# TODO(mattjj): replace this approach with a primitive-keyed table of rules

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
9393
*primals, **params):
9494
source_info = source_info_util.current()
9595
with core.take_current_trace() as parent_trace:
96-
tangent_trace = pe.DynamicJaxprTrace(debug_info)
96+
tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
9797
tangent_trace.tag = _tag
9898
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
9999
tracers = [LinearizeTracer(linearize_trace, p,
@@ -167,7 +167,7 @@ def _linearize_jaxpr(
167167
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]:
168168
dbg = jaxpr.jaxpr.debug_info
169169
primal_trace = pe.DynamicJaxprTrace(dbg)
170-
tangent_trace = pe.DynamicJaxprTrace(dbg)
170+
tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
171171
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
172172
tangent_trace.tag = lin_trace.tag
173173

@@ -227,7 +227,7 @@ def direct_linearize(traceable: lu.WrappedFun,
227227
primals, kwargs, *, has_aux=False, tag=None):
228228
with core.take_current_trace() as parent_trace:
229229
source_info = source_info_util.current()
230-
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info)
230+
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info, auto_dce=True)
231231
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals]
232232
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
233233
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)

0 commit comments

Comments
 (0)