Skip to content

Commit 31cad9e

Browse files
committed
Add a primitive to foil DCE
1 parent 1a8e8be commit 31cad9e

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

jax/_src/checkify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,7 +842,7 @@ def new_body_f(*c_consts_and_vals):
842842
c_consts, vals = split_list(c_consts_and_vals, [c_consts_num])
843843
out = body_f(*vals)
844844
# This checks if the next cond application will error
845-
_ = cond_f(*c_consts, *out)
845+
lax.dce_sink(cond_f(*c_consts, *out))
846846
return out
847847
new_body_f_ = lu.wrap_init(new_body_f, debug_info=body_jaxpr.jaxpr.debug_info)
848848
c_consts_avals = cond_jaxpr.in_avals[:c_consts_num]

jax/_src/lax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@
6262
convert_element_type as convert_element_type,
6363
convert_element_type_p as convert_element_type_p,
6464
copy_p as copy_p,
65+
dce_sink_p as dce_sink_p,
6566
cos as cos,
67+
dce_sink as dce_sink,
6668
cos_p as cos_p,
6769
cosh as cosh,
6870
cosh_p as cosh_p,

jax/_src/lax/lax.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from jax._src import core
3737
from jax._src import dispatch
3838
from jax._src import dtypes
39+
from jax._src import effects
3940
from jax._src import linear_util as lu
4041
from jax._src import pjit
4142
from jax._src import pretty_printer as pp
@@ -8457,6 +8458,26 @@ def _propagate_mem_kind_copy(in_mem_kind):
84578458
return in_mem_kind
84588459
pxla.memory_kind_propagate_rule[copy_p] = _propagate_mem_kind_copy
84598460

8461+
# the dce_sink_p primitive marks a value as "used" from the perspective of DCE
8462+
# so the computation producing it won't be eliminated.
8463+
def dce_sink(val):
8464+
tree_util.tree_map(dce_sink_p.bind, val)
8465+
8466+
class NoDCEEffect(effects.Effect):
8467+
pass
8468+
no_dce_effect = NoDCEEffect()
8469+
effects.control_flow_allowed_effects.add_type(NoDCEEffect)
8470+
effects.lowerable_effects.add_type(NoDCEEffect)
8471+
8472+
dce_sink_p = core.Primitive('dce_sink')
8473+
dce_sink_p.def_impl(lambda _: [])
8474+
dce_sink_p.multiple_results = True
8475+
dce_sink_p.def_effectful_abstract_eval(lambda _: ([], {no_dce_effect}))
8476+
mlir.register_lowering(dce_sink_p, lambda ctx, _: [])
8477+
ad.deflinear(dce_sink_p, lambda _: [])
8478+
pe.def_trivial_padding(dce_sink_p)
8479+
batching.defvectorized(dce_sink_p)
8480+
84608481
def rng_bit_generator(key, shape, dtype=np.uint32,
84618482
algorithm=RandomAlgorithm.RNG_DEFAULT,
84628483
*, out_sharding=None):

0 commit comments

Comments
 (0)