Skip to content

Commit c9df14e

Browse files
Reverts d7badca
PiperOrigin-RevId: 781796815
1 parent d7badca commit c9df14e

File tree

22 files changed

+166
-309
lines changed

22 files changed

+166
-309
lines changed

jax/_src/ad_checkpoint.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -448,14 +448,11 @@ def f_(*args):
448448
return f(*args, **kwargs)
449449

450450
debug_info = api_util.debug_info("saved_residuals", f, args, kwargs)
451-
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args),
451+
out = api.make_jaxpr(lambda *args: api.linearize(f_, *args)[1],
452452
return_shape=True)(*in_leaves)
453453
assert isinstance(out, tuple)
454-
jaxpr_, out_shape_ = out
454+
jaxpr_, out_shape = out
455455
jaxpr = jaxpr_.jaxpr
456-
out_shape = out_shape_[1]
457-
num_res = tree_structure(out_shape).num_leaves
458-
jaxpr = jaxpr.replace(outvars=jaxpr.outvars[len(jaxpr.outvars) - num_res:])
459456
out_tree = lambda: tree_structure(out_shape)
460457
assert len(jaxpr.invars) == len(in_leaves)
461458
return _saved_residuals(jaxpr, debug_info.arg_names)

jax/_src/checkify.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ def _reduce_any_error(error: Error):
472472
## check_p primitive
473473

474474
check_p = core.Primitive('check')
475-
check_p.is_effectful = lambda _: True # type: ignore
476475
check_p.multiple_results = True # zero results
477476

478477

@@ -842,12 +841,12 @@ def new_body_f(*c_consts_and_vals):
842841
c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
843842
out = body_f(*vals)
844843
# This checks if the next cond application will error
845-
lax.dce_sink(cond_f(*c_consts, *out))
844+
_ = cond_f(*c_consts, *out)
846845
return out
847846
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
848847
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]
849-
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(
850-
new_body_f_, [*c_consts_avals, *body_jaxpr.in_avals])
848+
jaxpr, _, () = pe.trace_to_jaxpr_dynamic(new_body_f_, [*c_consts_avals,
849+
*body_jaxpr.in_avals])
851850
closed_jaxpr = pe.close_jaxpr(jaxpr)
852851
err_vals, err_tree = jtu.tree_flatten(error)
853852
err_vals = map(core.get_aval, err_vals)
@@ -1233,8 +1232,9 @@ def checked_fun(*args, **kwargs):
12331232
closed_f = lambda: f(*args, **kwargs)
12341233
# stage:
12351234
debug = api_util.debug_info("checkify", f, args, kwargs)
1236-
fun_, out_tree = api_util.flatten_fun(
1237-
lu.wrap_init(closed_f, debug_info=debug), in_tree)
1235+
fun_, out_tree = api_util.flatten_fun(lu.wrap_init(closed_f,
1236+
debug_info=debug),
1237+
in_tree)
12381238
jaxpr_, _, consts = pe.trace_to_jaxpr_dynamic(fun_, ())
12391239
jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_))
12401240
# checkify:

jax/_src/core.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,6 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval):
605605
self.abstract_eval = effectful_abstract_eval
606606
return effectful_abstract_eval
607607

608-
def def_effectful_abstract_eval2(self, abstract_eval):
609-
self.abstract_eval = _generic_effectful_abstract_eval(abstract_eval)
610-
return abstract_eval
611-
612608
def def_bind_with_trace(self, bind_with_trace):
613609
self.bind_with_trace = bind_with_trace
614610
return bind_with_trace
@@ -633,19 +629,6 @@ def abstract_eval_(*args, **kwargs):
633629
return abstract_eval(*args, **kwargs), no_effects
634630
return abstract_eval_
635631

636-
class GenericEffect(Effect):
637-
pass
638-
generic_effect = GenericEffect()
639-
generic_effect_set = {generic_effect}
640-
effects.lowerable_effects.add_type(GenericEffect)
641-
effects.control_flow_allowed_effects.add_type(GenericEffect)
642-
effects.custom_derivatives_allowed_effects.add_type(GenericEffect)
643-
644-
def _generic_effectful_abstract_eval(abstract_eval):
645-
def abstract_eval_(*args, **kwargs):
646-
return abstract_eval(*args, **kwargs), generic_effect_set
647-
return abstract_eval_
648-
649632
# -------------------- lifting --------------------
650633

651634
# TODO(mattjj): replace this approach with a primitive-keyed table of rules

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
9393
*primals, **params):
9494
source_info = source_info_util.current()
9595
with core.take_current_trace() as parent_trace:
96-
tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
96+
tangent_trace = pe.DynamicJaxprTrace(debug_info)
9797
tangent_trace.tag = _tag
9898
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
9999
tracers = [LinearizeTracer(linearize_trace, p,
@@ -167,7 +167,7 @@ def _linearize_jaxpr(
167167
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]:
168168
dbg = jaxpr.jaxpr.debug_info
169169
primal_trace = pe.DynamicJaxprTrace(dbg)
170-
tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
170+
tangent_trace = pe.DynamicJaxprTrace(dbg)
171171
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
172172
tangent_trace.tag = lin_trace.tag
173173

@@ -227,7 +227,7 @@ def direct_linearize(traceable: lu.WrappedFun,
227227
primals, kwargs, *, has_aux=False, tag=None):
228228
with core.take_current_trace() as parent_trace:
229229
source_info = source_info_util.current()
230-
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info, auto_dce=True)
230+
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info)
231231
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals]
232232
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
233233
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)

0 commit comments

Comments
 (0)