Skip to content

Commit 8980220

Browse files
committed
Set auto_dce=False by default; enable it explicitly in the direct-linearize paths.
Also includes other small fixes.
1 parent ff7ff60 commit 8980220

File tree

11 files changed: +53 additions, -20 deletions

jax/_src/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,6 +2396,7 @@ def make_jaxpr(
23962396
{ lambda ; a:f32[]. let
23972397
b:f32[] = cos a
23982398
c:f32[] = sin a
2399+
_:f32[] = sin b
23992400
d:f32[] = cos b
24002401
e:f32[] = mul 1.0:f32[] d
24012402
f:f32[] = neg e

jax/_src/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval):
606606
return effectful_abstract_eval
607607

608608
def def_effectful_abstract_eval2(self, abstract_eval):
609-
self.abstract_eval = _generic_abstract_eval(abstract_eval)
609+
self.abstract_eval = _generic_effectful_abstract_eval(abstract_eval)
610610
return abstract_eval
611611

612612
def def_bind_with_trace(self, bind_with_trace):
@@ -639,8 +639,9 @@ class GenericEffect(Effect):
639639
generic_effect_set = {generic_effect}
640640
effects.lowerable_effects.add_type(GenericEffect)
641641
effects.control_flow_allowed_effects.add_type(GenericEffect)
642+
effects.custom_derivatives_allowed_effects.add_type(GenericEffect)
642643

643-
def _generic_abstract_eval(abstract_eval):
644+
def _generic_effectful_abstract_eval(abstract_eval):
644645
def abstract_eval_(*args, **kwargs):
645646
return abstract_eval(*args, **kwargs), generic_effect_set
646647
return abstract_eval_

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
9393
*primals, **params):
9494
source_info = source_info_util.current()
9595
with core.take_current_trace() as parent_trace:
96-
tangent_trace = pe.DynamicJaxprTrace(debug_info)
96+
tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
9797
tangent_trace.tag = _tag
9898
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=_tag)
9999
tracers = [LinearizeTracer(linearize_trace, p,
@@ -167,7 +167,7 @@ def _linearize_jaxpr(
167167
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]:
168168
dbg = jaxpr.jaxpr.debug_info
169169
primal_trace = pe.DynamicJaxprTrace(dbg)
170-
tangent_trace = pe.DynamicJaxprTrace(dbg)
170+
tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
171171
lin_trace = LinearizeTrace(primal_trace, tangent_trace)
172172
tangent_trace.tag = lin_trace.tag
173173

@@ -227,7 +227,7 @@ def direct_linearize(traceable: lu.WrappedFun,
227227
primals, kwargs, *, has_aux=False, tag=None):
228228
with core.take_current_trace() as parent_trace:
229229
source_info = source_info_util.current()
230-
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info)
230+
tangent_trace = pe.DynamicJaxprTrace(traceable.debug_info, auto_dce=True)
231231
tangents = [tangent_trace.new_arg(get_aval(p).to_tangent_aval(), source_info) for p in primals]
232232
tangents = [Zero.from_primal_value(t) if dtype(t) == float0 else t for t in tangents]
233233
linearize_trace = LinearizeTrace(parent_trace, tangent_trace, tag=tag)

jax/_src/interpreters/partial_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,7 +1937,7 @@ class DynamicJaxprTrace(core.Trace):
19371937
__slots__ = ("frame", "tag", "parent_trace")
19381938

19391939
def __init__(self, debug_info: core.DebugInfo, parent_trace=None, lower=False,
1940-
auto_dce=True):
1940+
auto_dce=False):
19411941
super().__init__()
19421942
self.requires_low = lower
19431943
self.frame = JaxprStackFrame(debug_info, auto_dce)
@@ -2316,7 +2316,7 @@ def trace_to_jaxpr_dynamic(
23162316
*,
23172317
keep_inputs: list[bool] | None = None,
23182318
lower: bool = False,
2319-
auto_dce: bool = True,
2319+
auto_dce: bool = False,
23202320
) -> tuple[Jaxpr, list[AbstractValue], list[Any]]:
23212321
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
23222322
parent_trace = core.trace_ctx.trace

jax/_src/lax/lax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8458,7 +8458,7 @@ def _propagate_mem_kind_copy(in_mem_kind):
84588458
return in_mem_kind
84598459
pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy
84608460

8461-
# the dce_sink_p primitive marks a value as "used" from the perspective of DCE
8461+
# The dce_sink_p primitive marks a value as "used" from the perspective of DCE
84628462
# so the computation producing it won't be eliminated.
84638463
def dce_sink(val):
84648464
tree_util.tree_map(dce_sink_p.bind, val)

jax/_src/pallas/pallas_call.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def _pallas_call_abstract_eval(
9999
# Report effects that will be introduced when running/lowering
100100
# mosaic_tpu_interpret.mosaic_tpu_interpret.interpret_pallas_call .
101101
effs = mosaic_tpu_interpret.get_interpret_effects()
102+
elif getattr(params.get('compiler_params', None), 'has_side_effects', False):
103+
effs = jax_core.generic_effect_set
102104
else:
103105
effs = jax_core.no_effects
104106

tests/api_test.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5122,18 +5122,45 @@ def g():
51225122
lambda: (c, jnp.sum(d), d))
51235123
self.assertLen(g().consts, 2)
51245124

5125+
# TODO(mattjj,dougalm): this test was flaky on CI; figure out how to re-enable it.
5126+
# @jtu.run_on_devices('cpu')
5127+
# def test_implicit_dce_linearize(self):
5128+
# def foo(x):
5129+
# const = np.zeros((300,))
5130+
# x * const
5131+
# r = weakref.ref(const)
5132+
# del const
5133+
# assert r() is None, "oops, the constant wasn't DCE'd"
5134+
# return x
5135+
# with config.use_direct_linearize(True):
5136+
# _ = jax.grad(foo)(3.)
51255137

5126-
def test_implicit_dce(self):
5127-
@api.jit
5138+
@jtu.run_on_devices('cpu')
5139+
def test_implicit_dce_linearize_jaxpr(self):
51285140
def foo(x):
51295141
const = np.zeros((300,))
5142+
x * const
51305143
r = weakref.ref(const)
5131-
jnp.sin(const) + const
51325144
del const
5133-
assert r() is None, "oops, the constant wasn't DCE'd"
5134-
return x + x
5145+
return x
51355146

5136-
foo(1.0)
5147+
with config.use_direct_linearize(True):
5148+
_, f_vjp = jax.vjp(foo, 3.)
5149+
5150+
self.assertNotIn('mul', str(f_vjp))
5151+
5152+
# TODO(mattjj,dougalm): re-enable when we set auto_dce=True by default
5153+
# @jtu.run_on_devices('cpu')
5154+
# def test_implicit_dce(self):
5155+
# @api.jit
5156+
# def foo(x):
5157+
# const = np.zeros((300,))
5158+
# r = weakref.ref(const)
5159+
# jnp.sin(const) + const
5160+
# del const
5161+
# assert r() is None, "oops, the constant wasn't DCE'd"
5162+
# return x + x
5163+
# foo(1.0)
51375164

51385165
class RematTest(jtu.JaxTestCase):
51395166

@@ -6208,7 +6235,7 @@ def test_vjp_caching(self):
62086235
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
62096236
for _ in range(20):
62106237
f_vjp(1.)[0].block_until_ready()
6211-
self.assertEqual(count(), 1) # backward_pass on bwd
6238+
self.assertLessEqual(count(), 2)
62126239

62136240
def test_vjp_caching_static_argnums(self):
62146241
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
@@ -6217,7 +6244,7 @@ def test_vjp_caching_static_argnums(self):
62176244
with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841
62186245
for _ in range(20):
62196246
f_vjp(1.)[0].block_until_ready()
6220-
self.assertEqual(count(), 1) # backward_pass on bwd
6247+
self.assertLessEqual(count(), 2)
62216248

62226249
def test_fwd_caching(self):
62236250
# see above test also
@@ -7044,7 +7071,7 @@ def body(c, _):
70447071
return out
70457072
jaxpr = api.make_jaxpr(f)([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]).jaxpr
70467073
self.assertLen(jaxpr.eqns, 1)
7047-
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 4)
7074+
self.assertLen(jaxpr.eqns[0].params['jaxpr'].jaxpr.eqns, 5)
70487075

70497076
# If we use the value at index 8 only, all the hidden sequence must be kept
70507077
# and no eqns can be pruned.

tests/checkify_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ def sin_bwd(x2, g):
729729
err, y = checkify.checkify(jax.grad(sin),
730730
errors=checkify.float_checks)(jnp.inf)
731731
self.assertIsNotNone(err.get())
732-
self.assertStartsWith(err.get(), "nan generated by primitive: cos")
732+
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
733733

734734
def test_scan_consts(self):
735735
def f(xs):

tests/pallas/mosaic_gpu_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,6 +1320,7 @@ def cond(acc):
13201320

13211321
def body(acc):
13221322
del acc # Unused.
1323+
o_ref[...] = o_ref[...] # side-effect to prevent DCE
13231324

13241325
# We deliberately do a cast here to trigger a layout mismatch.
13251326
return plgpu.layout_cast(

tests/pjit_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7854,7 +7854,7 @@ def f(x, y, a, b):
78547854
f_bar(x, y, a, b) # doesn't crash
78557855

78567856
grad_jaxpr = f_bar.trace(x, y, a, b).jaxpr
7857-
reshard_eqn = grad_jaxpr.eqns[2].params['jaxpr'].eqns[0]
7857+
reshard_eqn = grad_jaxpr.eqns[4].params['jaxpr'].eqns[0]
78587858
self.assertEqual(reshard_eqn.params['dst_sharding'].spec.reduced,
78597859
frozenset('y'))
78607860
self.assertEqual(reshard_eqn.params['dst_sharding'].spec.unreduced,

0 commit comments

Comments
 (0)