Skip to content

Commit 28f84eb

Browse files
author
jax authors
committed
Merge pull request #20044 from mattjj:mutable-arrays
PiperOrigin-RevId: 611866507
2 parents 04f6bfa + 3a403f2 commit 28f84eb

File tree

6 files changed

+86
-75
lines changed

6 files changed

+86
-75
lines changed

jax/_src/core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1912,6 +1912,30 @@ def __str__(self) -> str:
19121912
AxisSize = Union[int, DArray, Tracer, Var, DBIdx, InDBIdx, OutDBIdx]
19131913

19141914

1915+
class MutableArray:
  """A concrete buffer paired with the abstract value that describes it.

  Instances are produced by the ``mutable_array_p`` implementation rule.
  Indexing reads and writes are not implemented here; they are forwarded to
  the ``_getitem`` / ``_setitem`` hooks of the abstract value returned by
  ``get_aval(self)``.
  """
  # Quoted so the annotations are never evaluated when the class is created.
  # NOTE(review): `_mutable_array_impl` constructs this with
  # `AbstractRef(aval)`, so the `ShapedArray` annotation looks too narrow --
  # confirm the intended type.
  _aval: "ShapedArray"
  _buf: "Array"

  def __init__(self, aval, buf):
    """Store the abstract value ``aval`` and the backing buffer ``buf``."""
    self._aval = aval
    self._buf = buf

  @property
  def aval(self):
    """The abstract value this array was constructed with."""
    return self._aval

  @property
  def shape(self):
    """Shape, as reported by the abstract value."""
    return self._aval.shape

  @property
  def dtype(self):
    """Dtype, as reported by the abstract value."""
    return self._aval.dtype

  def __repr__(self) -> str:
    # Without this, instances print as an opaque `<... object at 0x...>`.
    return f"{type(self).__name__}(aval={self._aval!r}, buf={self._buf!r})"

  def __getitem__(self, idx):
    """Read ``self[idx]`` through the abstract value's ``_getitem`` hook."""
    return get_aval(self)._getitem(self, idx)

  def __setitem__(self, idx, x):
    """Write ``self[idx] = x`` through the abstract value's ``_setitem`` hook."""
    return get_aval(self)._setitem(self, idx, x)
1926+
pytype_aval_mappings[MutableArray] = lambda x: x._aval
1927+
1928+
def mutable_array(init_val):
  """Bind ``mutable_array_p`` to ``init_val`` and return the result."""
  result = mutable_array_p.bind(init_val)
  return result
1930+
mutable_array_p = Primitive('mutable_array')
1931+
1932+
@mutable_array_p.def_impl
def _mutable_array_impl(init_val):
  """Implementation rule for ``mutable_array_p``.

  Wraps ``init_val`` in a ``MutableArray`` whose abstract value is an
  ``AbstractRef`` around the shaped abstract value of ``init_val``.
  """
  # Function-scope import; presumably avoids a circular import between core
  # and the state package -- confirm before hoisting to module level.
  from jax._src.state.types import AbstractRef  # type: ignore[import]
  shaped_aval = raise_to_shaped(get_aval(init_val))
  ref_aval = AbstractRef(shaped_aval)
  return MutableArray(ref_aval, init_val)
1937+
1938+
19151939
class AbstractToken(AbstractValue):
19161940
def join(self, other):
19171941
if isinstance(other, AbstractToken):

jax/_src/interpreters/partial_eval.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2730,6 +2730,13 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
27302730
return prim.bind(*subfuns, *args, **bind_params)
27312731

27322732

2733+
def _error_staging_mutable_array_p(trace, x):
2734+
raise Exception(
2735+
"mutable_array constructor can't be staged out, and in particular can't "
2736+
"be used under a jax.jit or jax.lax.scan")
2737+
custom_staging_rules[core.mutable_array_p] = _error_staging_mutable_array_p
2738+
2739+
27332740
# TODO(mattjj): the following are deprecated; update callers to _nounits version
27342741
# See https://github.com/google/jax/pull/9498
27352742
@lu.transformation

jax/_src/interpreters/pxla.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ def _shard_darray(x, sharding):
160160
return shard_arg(x._data, sharding)
161161
shard_arg_handlers[core.DArray] = _shard_darray
162162

163+
def _shard_mutable_array(x, sharding):
  """Shard ``x`` by handing its ``_buf`` attribute to ``shard_arg``."""
  buf = x._buf
  return shard_arg(buf, sharding)
165+
shard_arg_handlers[core.MutableArray] = _shard_mutable_array
166+
163167
def batched_device_put(aval: core.ShapedArray,
164168
sharding: jax.sharding.Sharding, xs: Sequence[Any],
165169
devices: Sequence[jax.Device], committed: bool = True):
@@ -1778,17 +1782,16 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
17781782
@weakref_lru_cache
17791783
def _discharge_refs(
17801784
jaxpr: core.ClosedJaxpr
1781-
) -> tuple[core.ClosedJaxpr, None | Sequence[int | None], None | Sequence[int | None]]:
1785+
) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
17821786
from jax._src.state.discharge import discharge_state
1783-
out_mut = [None] * len(jaxpr.out_avals) + [
1784-
i for i, a in enumerate(jaxpr.in_avals) if isinstance(a, AbstractRef)]
1785-
count = it.count()
1786-
inout_aliases = tuple(next(count) if isinstance(a, AbstractRef) else None
1787-
for a in jaxpr.in_avals)
1788-
jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
1789-
assert len(inout_aliases) == len(jaxpr.in_avals)
1790-
assert len(out_mut) == len(jaxpr.out_avals)
1791-
return jaxpr, inout_aliases, out_mut
1787+
new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
1788+
count = it.count(len(jaxpr.out_avals)) # new outputs are appended to the end
1789+
inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
1790+
if isinstance(a, AbstractRef)}
1791+
outin_map = {j: i for i, j in inout_map.items()}
1792+
inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
1793+
out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
1794+
return new_jaxpr, inout_aliases, out_mut
17921795

17931796

17941797
@dataclasses.dataclass(frozen=True)

jax/_src/interpreters/xla.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import functools
2323
from functools import partial
2424
import itertools as it
25-
import operator
2625
from typing import Any, Callable, Protocol, Union
2726

2827
import numpy as np
@@ -166,6 +165,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
166165
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
167166
canonicalize_dtype_handlers[core.Token] = identity
168167
canonicalize_dtype_handlers[core.DArray] = identity
168+
canonicalize_dtype_handlers[core.MutableArray] = identity
169169

170170
def abstractify(x) -> Any:
171171
typ = type(x)
@@ -196,7 +196,8 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
196196

197197

198198
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
199-
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
199+
pytype_aval_mappings[core.DArray] = lambda x: x._aval
200+
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
200201
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
201202
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
202203
for t in numpy_scalar_types)

jax/_src/state/primitives.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from jax._src import ad_util
2424
from jax._src import core
25+
from jax._src import dispatch
2526
from jax._src import pretty_printer as pp
2627
from jax._src import tree_util
2728
from jax._src.interpreters import ad
@@ -53,11 +54,7 @@
5354
# `Ref((3,), np.dtype('float32'))` leads to a jaxpr eqn printed like
5455
# a:f32[3] <- x[]
5556
get_p = core.Primitive("get")
56-
57-
def _get_impl(ref: AbstractRef, *args: Any, tree):
58-
del ref, args, tree
59-
raise ValueError("Cannot run stateful primitive.")
60-
get_p.def_impl(_get_impl)
57+
get_p.def_impl(partial(dispatch.apply_primitive, get_p))
6158

6259
Indexer = tuple[Union[int, slice, Array], ...]
6360
# or Ellipsis, but that can't be annotated until Python 3.10? (types.EllipsisType)
@@ -113,11 +110,7 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array:
113110
# are `ShapedArray((), np.dtype('int32'))` leads to a jaxpr eqn printed like
114111
# x:Ref{f32[3]}[i, j] <- a
115112
swap_p = core.Primitive("swap")
116-
117-
def _swap_impl(ref: AbstractRef, value: Array, *idx: Any, tree):
118-
del ref, value, idx, tree
119-
raise ValueError("Cannot run stateful primitive.")
120-
swap_p.def_impl(_swap_impl)
113+
swap_p.def_impl(partial(dispatch.apply_primitive, swap_p))
121114

122115
def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array,
123116
_function_name: str = "ref_swap") -> Array:
@@ -143,11 +136,7 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra
143136
# ```
144137
addupdate_p = core.Primitive('addupdate')
145138
addupdate_p.multiple_results = True
146-
147-
def _addupdate_impl(ref: AbstractRef, value: Array, *args: Any, tree):
148-
del ref, value, args, tree
149-
raise ValueError("Can't evaluate `addupdate` outside a stateful context.")
150-
addupdate_p.def_impl(_addupdate_impl)
139+
addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p))
151140

152141
def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None:
153142
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""

tests/state_test.py

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,6 @@
5555

5656
class StatePrimitivesTest(jtu.JaxTestCase):
5757

58-
def test_cant_eval_get_primitive(self):
59-
with self.assertRaises(ValueError):
60-
get_p.bind(jnp.ones(5), tree=None)
61-
62-
def test_cant_eval_swap_primitive(self):
63-
with self.assertRaises(ValueError):
64-
swap_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
65-
66-
def test_cant_eval_addupdate_primitive(self):
67-
with self.assertRaises(ValueError):
68-
addupdate_p.bind(jnp.ones(5), jnp.zeros(5), tree=None)
69-
7058
def test_get_abstract_aval_must_take_in_refs(self):
7159
ref_aval = core.ShapedArray((), jnp.float32)
7260
def f(x_ref):
@@ -1508,55 +1496,54 @@ def _body(ref):
15081496
jtu.check_grads(f, (0.5,), order=3)
15091497

15101498

1511-
class MutableArray:
1512-
_aval: core.ShapedArray
1513-
_buf: jax.Array
1514-
def __init__(self, aval, buf):
1515-
self._aval = aval
1516-
self._buf = buf
1517-
aval = property(lambda self: self._aval)
1518-
shape = property(lambda self: self._aval.shape)
1519-
dtype = property(lambda self: self._aval.dtype)
1520-
1521-
def mutable_array(init_val):
1522-
return mutable_array_p.bind(init_val)
1523-
mutable_array_p = core.Primitive('mutable_array')
1524-
1525-
@mutable_array_p.def_impl
1526-
def _mutable_array_impl(init_val):
1527-
aval = core.raise_to_shaped(core.get_aval(init_val))
1528-
return MutableArray(AbstractRef(aval), init_val)
1529-
1530-
def _error_on_staging(trace, x):
1531-
raise Exception
1532-
pe.custom_staging_rules[mutable_array_p] = _error_on_staging
1533-
1534-
from jax._src.interpreters import xla
1535-
from jax._src.interpreters import pxla
1536-
xla.canonicalize_dtype_handlers[MutableArray] = lambda x: x
1537-
xla.pytype_aval_mappings[MutableArray] = lambda x: x._aval
1538-
pxla.shard_arg_handlers[MutableArray] = lambda x, s: pxla.shard_arg(x._buf, s)
1539-
core.pytype_aval_mappings[MutableArray] = lambda x: x._aval
1540-
15411499
class MutableArrayTest(jtu.JaxTestCase):
15421500

1543-
def test_basic(self):
1544-
read = jax.jit(lambda x_ref: x_ref[...])
1545-
1546-
@jax.jit
1501+
@parameterized.parameters([True, False])
1502+
def test_basic(self, jit):
15471503
def f(x_mut):
15481504
x_mut[...] += 1.
15491505
x_mut[0] += 1
15501506
x_mut[1] += 5
15511507

1552-
x_mut = mutable_array(jnp.zeros(3))
1508+
if jit:
1509+
f = jax.jit(f)
1510+
1511+
x_mut = core.mutable_array(jnp.zeros(3))
15531512
f(x_mut)
15541513

1555-
self.assertAllClose(read(x_mut), jnp.array([2., 6., 1.]), check_dtypes=False)
1514+
self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
1515+
check_dtypes=False)
15561516

15571517
jaxpr = jax.make_jaxpr(f)(x_mut)
15581518
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
15591519

1520+
def test_staging_error(self):
  # Constructing a mutable array while tracing under jit must be rejected.
  # Match on the staging rule's message, not just `Exception`, so that an
  # unrelated failure (e.g. an AttributeError) cannot make this test pass.
  x = jnp.zeros(3)
  with self.assertRaisesRegex(Exception, "can't be staged out"):
    jax.jit(core.mutable_array)(x)
1524+
1525+
@parameterized.parameters([True, False])
def test_multiple_inputs_and_outputs(self, jit):
  # Mutable and ordinary arguments are interleaved, and the function also
  # returns two ordinary outputs.
  def f(x_mut, y, z_mut, w):
    x_mut[...] += 1
    z_mut[...] += 1
    total = x_mut[...] + y + z_mut[...] + w
    return total, y + w

  if jit:
    f = jax.jit(f)

  # Shapes differ on purpose so the sum relies on broadcasting.
  x_mut = core.mutable_array(jnp.zeros((1, 3)))
  y = jnp.ones((2, 3))
  z_mut = core.mutable_array(jnp.zeros((2, 3)))
  w = jnp.ones((2, 1))

  out1, out2 = f(x_mut, y, z_mut, w)

  # Each mutable input was incremented in place exactly once ...
  self.assertAllClose(x_mut[...], jnp.ones((1, 3)), check_dtypes=False)
  self.assertAllClose(z_mut[...], jnp.ones((2, 3)), check_dtypes=False)
  # ... and the ordinary outputs see the post-update values (1 + 1 + 1 + 1).
  self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
  self.assertAllClose(out2, y + w, check_dtypes=False)
1546+
15601547

15611548
if CAN_USE_HYPOTHESIS:
15621549

0 commit comments

Comments
 (0)