Skip to content

Commit 23b0fa8

Browse files
author
jax authors
committed
Merge pull request #20572 from mattjj:marray-you
PiperOrigin-RevId: 621878367
2 parents 498e81a + 46a5162 commit 23b0fa8

File tree

5 files changed

+49
-16
lines changed

5 files changed

+49
-16
lines changed

jax/_src/core.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1929,6 +1929,14 @@ def mutable_array(init_val):
19291929
return mutable_array_p.bind(init_val)
19301930
mutable_array_p = Primitive('mutable_array')
19311931

1932+
class InternalMutableArray(effects.Effect):
1933+
pass
1934+
1935+
@mutable_array_p.def_effectful_abstract_eval
1936+
def mutable_array_abstract_eval(init_aval):
1937+
from jax._src.state.types import AbstractRef # type: ignore[import]
1938+
return AbstractRef(init_aval), {InternalMutableArray}
1939+
19321940
@mutable_array_p.def_impl
19331941
def _mutable_array_impl(init_val):
19341942
from jax._src.state.types import AbstractRef # type: ignore[import]
@@ -2922,6 +2930,8 @@ def write(v: Var, a: AbstractValue) -> None:
29222930
write(v, v.aval)
29232931

29242932
# Check each eqn.
2933+
sentinel = object()
2934+
in_idx = {v: i for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))}
29252935
for eqn_idx, eqn in enumerate(jaxpr.eqns):
29262936
prim = eqn.primitive
29272937
try:
@@ -2943,18 +2953,19 @@ def write(v: Var, a: AbstractValue) -> None:
29432953

29442954
# Check the computed effect type matches the eqn's annotation, and is
29452955
# included in the jaxpr's annotation.
2956+
if prim is mutable_array_p:
2957+
outvar, = eqn.outvars
2958+
in_idx[outvar] = None # type: ignore
29462959
if eqn.effects != eqn_effects:
29472960
raise JaxprTypeError("Inferred effects do not match equation effects. "
29482961
f"Equation effects: {eqn.effects}. "
29492962
f"Inferred effects: {eqn_effects}")
29502963
for eff in eqn.effects:
29512964
if isinstance(eff, effects.JaxprInputEffect):
29522965
eqn_invar = eqn.invars[eff.input_index]
2953-
all_vars = [*jaxpr.constvars, *jaxpr.invars]
2954-
if eqn_invar not in all_vars:
2966+
if (jaxpr_index := in_idx.get(eqn_invar, sentinel)) is sentinel:
29552967
raise JaxprTypeError(
29562968
"Invalid `JaxprInputEffect`: must correspond to a jaxpr invar")
2957-
jaxpr_index = all_vars.index(eqn_invar)
29582969
jaxpr_effect = eff.replace(input_index=jaxpr_index)
29592970
if jaxpr_effect not in jaxpr.effects:
29602971
raise JaxprTypeError(

jax/_src/interpreters/partial_eval.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,9 +1726,13 @@ def get_referent(self):
17261726
api_util._shaped_abstractify_handlers[DynamicJaxprTracer] = op.attrgetter("aval")
17271727

17281728
def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
1729+
sentinel = object()
17291730
jaxpr_effects = set()
17301731
all_vars = {v: i for i, v in enumerate(it.chain(constvars, invars))}
17311732
for eqn in eqns:
1733+
if eqn.primitive is core.mutable_array_p:
1734+
outvar, = eqn.outvars
1735+
all_vars[outvar] = None # type: ignore
17321736
for eff in eqn.effects:
17331737
if isinstance(eff, effects.JaxprInputEffect):
17341738
if eff.input_index >= len(eqn.invars):
@@ -1738,7 +1742,7 @@ def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
17381742
"\n Jaxpr: "
17391743
f"{core.Jaxpr(constvars, invars, outvars, eqns, set())}")
17401744
invar = eqn.invars[eff.input_index]
1741-
if (input_index := all_vars.get(invar)) is None:
1745+
if (input_index := all_vars.get(invar, sentinel)) is sentinel:
17421746
raise ValueError(
17431747
f"`JaxprInputEffect` {eff} does not have "
17441748
f"corresponding input: {invar}."
@@ -2735,13 +2739,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
27352739
return prim.bind(*subfuns, *args, **bind_params)
27362740

27372741

2738-
def _error_staging_mutable_array_p(trace, x):
2739-
raise Exception(
2740-
"mutable_array constructor can't be staged out, and in particular can't "
2741-
"be used under a jax.jit or jax.lax.scan")
2742-
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p
2743-
2744-
27452742
# TODO(mattjj): the following are deprecated; update callers to _nounits version
27462743
# See https://github.com/google/jax/pull/9498
27472744
@lu.transformation

jax/_src/interpreters/pxla.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,13 @@ def _move_mutable_consts(
18171817
effects, None)
18181818
return core.ClosedJaxpr(jaxpr, consts), in_mut
18191819

1820+
@weakref_lru_cache
1821+
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
1822+
from jax._src.state.discharge import discharge_state
1823+
jaxpr_, consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
1824+
jaxpr_._debug_info = jaxpr.jaxpr.debug_info
1825+
return core.ClosedJaxpr(jaxpr_, consts)
1826+
18201827

18211828
class SemanticallyEqualShardings:
18221829

@@ -2074,6 +2081,8 @@ def lower_sharding_computation(
20742081
global_out_avals = closed_jaxpr.out_avals
20752082
else:
20762083
inout_aliases = mut = None
2084+
if any(isinstance(e, core.InternalMutableArray) for e in closed_jaxpr.effects):
2085+
closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
20772086

20782087
jaxpr = closed_jaxpr.jaxpr
20792088
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (

jax/_src/state/discharge.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ def register(f: DischargeRule):
9696
_discharge_rules[prim] = f
9797
return register
9898

99-
def _has_refs(eqn: core.JaxprEqn):
100-
return any(isinstance(v.aval, AbstractRef) for v in eqn.invars)
101-
10299
def _eval_jaxpr_discharge_state(
103100
jaxpr: core.Jaxpr, should_discharge: Sequence[bool], consts: Sequence[Any],
104101
*args: Any):
@@ -113,8 +110,12 @@ def _eval_jaxpr_discharge_state(
113110
if d and isinstance(v.aval, AbstractRef)}
114111

115112
for eqn in jaxpr.eqns:
116-
if _has_refs(eqn) and any(id(v.aval) in refs_to_discharge
117-
for v in eqn.invars):
113+
if eqn.primitive is core.mutable_array_p:
114+
[invar], [outvar] = eqn.invars, eqn.outvars
115+
init_val = env.read(invar)
116+
env.write(outvar, init_val)
117+
refs_to_discharge.add(id(outvar.aval))
118+
elif any(id(v.aval) in refs_to_discharge for v in eqn.invars):
118119
if eqn.primitive not in _discharge_rules:
119120
raise NotImplementedError("No state discharge rule implemented for "
120121
f"primitive: {eqn.primitive}")

tests/state_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,21 @@ def f(y_mut, z):
15881588
check_dtypes=False)
15891589
self.assertAllClose(w, 10, check_dtypes=False)
15901590

1591+
@parameterized.parameters([True, False])
1592+
def test_internal_mutarray_basic(self, jit):
1593+
def f():
1594+
x_mut = core.mutable_array(jnp.zeros(3))
1595+
x_mut[0] += 1
1596+
x_mut[0] += 1
1597+
x_mut[2] += 1
1598+
return x_mut[...]
1599+
1600+
if jit:
1601+
f = jax.jit(f)
1602+
1603+
out = f()
1604+
self.assertAllClose(out, jnp.array([2., 0., 1.]), check_dtypes=False)
1605+
15911606

15921607
if CAN_USE_HYPOTHESIS:
15931608

0 commit comments

Comments
 (0)