From a1ba1d7f836b3d807be07280ba686fb49a5642cc Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 28 May 2025 19:45:03 +1000 Subject: [PATCH 1/4] update jax intro lecture --- lectures/jax_intro.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 189d810..8ae49d3 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -172,7 +172,9 @@ a, a_new The designers of JAX chose to make arrays immutable because JAX uses a functional programming style. More on this below. -Note that, while mutation is discouraged, it is in fact possible with `at`, as in +However, JAX provides a functionally pure equivalent of in-place array modifications. + +To assign a new value to an element of a JAX array, we can use the `at` method ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -183,11 +185,15 @@ id(a) a ``` +We can see that the array `a` is changed by using the +[`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html), which returns +a new array + ```{code-cell} ipython3 -a.at[0].set(1) +a = a.at[0].set(1) ``` -We can check that the array is mutated by verifying its identity is unchanged: +Inspecting the identifier of `a` shows that it has changed ```{code-cell} ipython3 id(a) From cb78ad42b92cf19fe961e644979cf72c75bbed3d Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 28 May 2025 19:53:29 +1000 Subject: [PATCH 2/4] minor update --- lectures/jax_intro.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 8ae49d3..472d120 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -186,8 +186,9 @@ a ``` We can see that the array `a` is changed by using the -[`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html), which returns -a new array +[`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html). + +It returns a new copy of `a` with the specified element changed. ```{code-cell} ipython3 a = a.at[0].set(1) From 2fd5aa7afa2eda30433f679896595ac399529cf9 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 28 May 2025 21:20:24 +1000 Subject: [PATCH 3/4] minor update --- lectures/jax_intro.md | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 472d120..2c00375 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -172,9 +172,9 @@ a, a_new The designers of JAX chose to make arrays immutable because JAX uses a functional programming style. More on this below. -However, JAX provides a functionally pure equivalent of in-place array modifications. +However, JAX provides a functionally pure equivalent of in-place array modification +using the [`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html). -To assign a new value to an element of a JAX array, we can use the `at` method ```{code-cell} ipython3 a = jnp.linspace(0, 1, 3) @@ -185,16 +185,15 @@ id(a) a ``` -We can see that the array `a` is changed by using the -[`at` method](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html). - -It returns a new copy of `a` with the specified element changed. +Applying `at[0].set(1)`, we can see that a new copy of `a` with the first element +set to 1 is returned ```{code-cell} ipython3 a = a.at[0].set(1) +a ``` -Inspecting the identifier of `a` shows that it has changed +Inspecting the identifier of `a` shows that it has been reassigned ```{code-cell} ipython3 id(a) From e47a4665e56147afbe0d4fd8084ccef47bcddce1 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Wed, 28 May 2025 21:46:53 +1000 Subject: [PATCH 4/4] address comment --- lectures/jax_intro.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 2c00375..d08e4de 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -185,8 +185,7 @@ id(a) a ``` -Applying `at[0].set(1)`, we can see that a new copy of `a` with the first element -set to 1 is returned +Applying `at[0].set(1)` returns a new copy of `a` with the first element set to 1 ```{code-cell} ipython3 a = a.at[0].set(1)