Skip to content

Commit 921e6fe

Browse files
committed
mark some pallas/mosaic primitives as effectful to avoid DCE
1 parent 31cad9e commit 921e6fe

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

jax/_src/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,10 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval):
605605
self.abstract_eval = effectful_abstract_eval
606606
return effectful_abstract_eval
607607

608+
def def_effectful_abstract_eval2(self, abstract_eval):
609+
self.abstract_eval = _generic_abstract_eval(abstract_eval)
610+
return abstract_eval
611+
608612
def def_bind_with_trace(self, bind_with_trace):
609613
self.bind_with_trace = bind_with_trace
610614
return bind_with_trace
@@ -629,6 +633,17 @@ def abstract_eval_(*args, **kwargs):
629633
return abstract_eval(*args, **kwargs), no_effects
630634
return abstract_eval_
631635

636+
class GenericEffect(Effect):
637+
pass
638+
generic_effect = GenericEffect()
639+
generic_effect_set = {generic_effect}
640+
effects.lowerable_effects.add_type(GenericEffect)
641+
642+
def _generic_abstract_eval(abstract_eval):
643+
def abstract_eval_(*args, **kwargs):
644+
return abstract_eval(*args, **kwargs), generic_effect_set
645+
return abstract_eval_
646+
632647
# -------------------- lifting --------------------
633648

634649
# TODO(mattjj): replace this approach with a primitive-keyed table of rules

jax/_src/pallas/mosaic/primitives.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def do_discharge_src_sem(src_sem=src_sem):
485485
dma_wait_p = jax_core.Primitive('dma_wait')
486486
dma_wait_p.multiple_results = True
487487

488-
@dma_wait_p.def_abstract_eval
488+
@dma_wait_p.def_effectful_abstract_eval2
489489
def _dma_wait_abstract_eval(*args, tree, device_id_type):
490490
del args, tree, device_id_type
491491
return []
@@ -658,7 +658,7 @@ def async_remote_copy(
658658

659659
get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore')
660660

661-
@get_barrier_semaphore_p.def_abstract_eval
661+
@get_barrier_semaphore_p.def_effectful_abstract_eval2
662662
def _get_barrier_semaphore_abstract_eval():
663663
return state.AbstractRef(
664664
jax_core.ShapedArray((), pl_core.BarrierSemaphore()),
@@ -692,7 +692,7 @@ def get_barrier_semaphore():
692692
delay_p.multiple_results = True
693693

694694

695-
@delay_p.def_abstract_eval
695+
@delay_p.def_effectful_abstract_eval2
696696
def _delay_abstract_eval(nanos):
697697
del nanos
698698
return []
@@ -708,7 +708,7 @@ def delay(nanos):
708708
prng_seed_p.multiple_results = True
709709

710710

711-
@prng_seed_p.def_abstract_eval
711+
@prng_seed_p.def_effectful_abstract_eval2
712712
def _prng_seed_abstract_eval(*_):
713713
return []
714714

jax/_src/pallas/primitives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def semaphore_read(sem_or_view):
10831083
flat_args, args_tree = tree_util.tree_flatten(args)
10841084
return semaphore_read_p.bind(*flat_args, args_tree=args_tree)
10851085

1086-
@semaphore_read_p.def_abstract_eval
1086+
@semaphore_read_p.def_effectful_abstract_eval2
10871087
def _semaphore_read_abstract_eval(
10881088
*avals,
10891089
args_tree,
@@ -1155,7 +1155,7 @@ def _semaphore_signal_abstract_eval(
11551155
check_sem_avals(sem_aval, sem_transforms_avals, "signal")
11561156
if value_aval.dtype != jnp.dtype("int32"):
11571157
raise ValueError("Must signal an int32 value.")
1158-
effs = set()
1158+
effs = {jax_core.generic_effect}
11591159
if device_id_avals is not None:
11601160
device_id_flat_avals = tree_util.tree_leaves(device_id_avals)
11611161
for aval in device_id_flat_avals:
@@ -1230,7 +1230,7 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
12301230
flat_args, args_tree = tree_util.tree_flatten(args)
12311231
semaphore_wait_p.bind(*flat_args, args_tree=args_tree)
12321232

1233-
@semaphore_wait_p.def_abstract_eval
1233+
@semaphore_wait_p.def_effectful_abstract_eval2
12341234
def _semaphore_wait_abstract_eval(*avals, args_tree):
12351235
sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten(
12361236
args_tree, avals

0 commit comments

Comments
 (0)