diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 189d810..d08e4de 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -172,7 +172,9 @@ a, a_new The designers of JAX chose to make arrays immutable because JAX uses a functional programming style. More on this below. -Note that, while mutation is discouraged, it is in fact possible with `at`, as in +However, JAX provides a functionally pure equivalent of in-place array modification +using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html). + ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -183,11 +185,14 @@ id(a) a ``` +Applying `at[0].set(1)` returns a new copy of `a` with the first element set to 1 + ```{code-cell} ipython3 -a.at[0].set(1) +a = a.at[0].set(1) +a ``` -We can check that the array is mutated by verifying its identity is unchanged: +Inspecting the identifier of `a` shows that it has been reassigned ```{code-cell} ipython3 id(a)