Skip to content

Commit 60d3378

Browse files
committed
mark some pallas/mosaic primitives as effectful to avoid DCE
1 parent b7ce110 commit 60d3378

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

jax/_src/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,10 @@ def def_effectful_abstract_eval(self, effectful_abstract_eval):
605605
self.abstract_eval = effectful_abstract_eval
606606
return effectful_abstract_eval
607607

608+
def def_effectful_abstract_eval2(self, abstract_eval):
609+
self.abstract_eval = _generic_abstract_eval(abstract_eval)
610+
return abstract_eval
611+
608612
def def_bind_with_trace(self, bind_with_trace):
609613
self.bind_with_trace = bind_with_trace
610614
return bind_with_trace
@@ -629,6 +633,17 @@ def abstract_eval_(*args, **kwargs):
629633
return abstract_eval(*args, **kwargs), no_effects
630634
return abstract_eval_
631635

636+
class GenericEffect(Effect):
637+
pass
638+
generic_effect = GenericEffect()
639+
generic_effect_set = {generic_effect}
640+
effects.lowerable_effects.add_type(GenericEffect)
641+
642+
def _generic_abstract_eval(abstract_eval):
643+
def abstract_eval_(*args, **kwargs):
644+
return abstract_eval(*args, **kwargs), generic_effect_set
645+
return abstract_eval_
646+
632647
# -------------------- lifting --------------------
633648

634649
# TODO(mattjj): replace this approach with a primitive-keyed table of rules

jax/_src/pallas/mosaic/primitives.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ def do_discharge_src_sem(src_sem=src_sem):
485485
dma_wait_p = jax_core.Primitive('dma_wait')
486486
dma_wait_p.multiple_results = True
487487

488-
@dma_wait_p.def_abstract_eval
488+
@dma_wait_p.def_effectful_abstract_eval2
489489
def _dma_wait_abstract_eval(*args, tree, device_id_type):
490490
del args, tree, device_id_type
491491
return []
@@ -641,7 +641,7 @@ def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
641641

642642
get_barrier_semaphore_p = jax_core.Primitive('get_barrier_semaphore')
643643

644-
@get_barrier_semaphore_p.def_abstract_eval
644+
@get_barrier_semaphore_p.def_effectful_abstract_eval2
645645
def _get_barrier_semaphore_abstract_eval():
646646
return state.AbstractRef(
647647
jax_core.ShapedArray((), pl_core.BarrierSemaphore()),
@@ -675,7 +675,7 @@ def get_barrier_semaphore():
675675
delay_p.multiple_results = True
676676

677677

678-
@delay_p.def_abstract_eval
678+
@delay_p.def_effectful_abstract_eval2
679679
def _delay_abstract_eval(nanos):
680680
del nanos
681681
return []
@@ -691,7 +691,7 @@ def delay(nanos):
691691
prng_seed_p.multiple_results = True
692692

693693

694-
@prng_seed_p.def_abstract_eval
694+
@prng_seed_p.def_effectful_abstract_eval2
695695
def _prng_seed_abstract_eval(*_):
696696
return []
697697

jax/_src/pallas/primitives.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def semaphore_read(sem_or_view):
10831083
flat_args, args_tree = tree_util.tree_flatten(args)
10841084
return semaphore_read_p.bind(*flat_args, args_tree=args_tree)
10851085

1086-
@semaphore_read_p.def_abstract_eval
1086+
@semaphore_read_p.def_effectful_abstract_eval2
10871087
def _semaphore_read_abstract_eval(
10881088
*avals,
10891089
args_tree,
@@ -1128,7 +1128,7 @@ def semaphore_signal(
11281128
)
11291129

11301130

1131-
@semaphore_signal_p.def_abstract_eval
1131+
@semaphore_signal_p.def_effectful_abstract_eval2
11321132
def _semaphore_signal_abstract_eval(
11331133
*avals,
11341134
args_tree,
@@ -1218,7 +1218,7 @@ def semaphore_wait(sem_or_view, dec: int | jax.Array = 1):
12181218
flat_args, args_tree = tree_util.tree_flatten(args)
12191219
semaphore_wait_p.bind(*flat_args, args_tree=args_tree)
12201220

1221-
@semaphore_wait_p.def_abstract_eval
1221+
@semaphore_wait_p.def_effectful_abstract_eval2
12221222
def _semaphore_wait_abstract_eval(*avals, args_tree):
12231223
sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten(
12241224
args_tree, avals

0 commit comments

Comments
 (0)