
Commit 9785678

Lint JAX AI Stack Getting Started (#106)
1 parent ca6ed37 commit 9785678

1 file changed (+31 / -31 lines)

docs/getting_started_with_jax_for_AI.md

Lines changed: 31 additions & 31 deletions
@@ -34,13 +34,13 @@ JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.h
 - [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
 - [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.

-Once you've worked through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.
+After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.

 +++ {"id": "z7sAr0sderhh"}

-## Example: a simple neural network with flax
+## Example: A simple neural network with Flax

-We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits.
+We'll start with a very quick example of what it looks like to use JAX with the [Flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits.

 +++ {"id": "pOlnhK-EioSk"}

@@ -80,8 +80,8 @@ for i, ax in enumerate(axes.flat):

 +++ {"id": "Z3l45KgtfUUo"}

-Next, split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before you feed them into the model.
-You’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:
+Next, we split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before we feed them into the model.
+We’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:

 ```{code-cell}
 :id: 6jrYisoPh6TL
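
The code cell opened above is truncated in this diff. A minimal sketch of such a split-and-convert step, assuming the scikit-learn digits data and `train_test_split` used elsewhere in the notebook (both are assumptions here, not shown in the hunk):

```python
# Hedged sketch only: the committed cell may differ.
import jax.numpy as jnp
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()  # 8x8 grayscale digit images and integer labels

# Split, then convert each split to a jax.Array via the NumPy-style API.
splits = train_test_split(digits.images, digits.target, random_state=0)
images_train, images_test, label_train, label_test = map(jnp.asarray, splits)

print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=} {label_test.shape=}")
```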
@@ -104,7 +104,7 @@ print(f"{images_test.shape=} {label_test.shape=}")

 ### Defining the Flax model

-You can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers with *scaled exponential linear unit* (SELU) activation function using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):
+We can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers with *scaled exponential linear unit* (SELU) activation function using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):

 ```{code-cell}
 :id: U77VMQwRjTfH
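
The model definition itself is not shown in the hunk. A minimal sketch of such an `nnx.Module` subclass, with the hidden-layer width chosen arbitrarily for illustration:

```python
from flax import nnx

class SimpleNN(nnx.Module):
  """Hypothetical sketch of the feed-forward model described above."""

  def __init__(self, n_features: int = 64, n_hidden: int = 100,
               n_targets: int = 10, *, rngs: nnx.Rngs):
    self.n_features = n_features
    self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
    self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
    self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

  def __call__(self, x):
    x = x.reshape(x.shape[0], self.n_features)  # flatten 8x8 images per sample
    x = nnx.selu(self.layer1(x))
    x = nnx.selu(self.layer2(x))
    return self.layer3(x)  # logits, one per digit class

model = SimpleNN(rngs=nnx.Rngs(0))
nnx.display(model)  # Interactive display if penzai is installed.
```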
@@ -137,10 +137,10 @@ nnx.display(model) # Interactive display if penzai is installed.

 ### Training the model

-With the `SimpleNN` model created and instantiated, you can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function. Use:
+With the `SimpleNN` model created and instantiated, we can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function. Use:
 - [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels) as the loss, as the output layer will have nodes corresponding to a handwritten integer label.
-- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent.
-- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to set the train state.
+- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent optimizer.
+- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to instantiate the optimizer and set the train state.

 ```{code-cell}
 :id: QwRvFPkYl5b2
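
A sketch of how those three pieces could fit together, building on the hypothetical `SimpleNN`/`model` above. The `nnx.Optimizer` calling convention has changed across Flax releases; this assumes a version where `nnx.Optimizer(model, tx)` and `optimizer.update(grads)` are accepted:

```python
import jax
import optax
from flax import nnx

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(model: SimpleNN, images: jax.Array, labels: jax.Array):
  logits = model(images)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels).mean()
  return loss, logits

@nnx.jit  # compile the whole training step for fast repeated execution
def train_step(model: SimpleNN, optimizer: nnx.Optimizer,
               images: jax.Array, labels: jax.Array):
  grad_fn = nnx.value_and_grad(loss_fun, has_aux=True)
  (loss, logits), grads = grad_fn(model, images, labels)
  optimizer.update(grads)  # updates the model parameters in place
```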
@@ -178,9 +178,9 @@ Notice here the use of [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/ap
 - `jax.jit` is a [Just-In-Time compilation transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation), and will cause the function to be passed to the [XLA](https://openxla.org/xla) compiler for fast repeated execution.
 - `jax.grad` is a [gradient transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) that uses JAX's automatic differentiation for fast optimization of large networks.

-You will return to these transformations later in the tutorial.
+We will return to these transformations later in the tutorial.

-Now that you have a training step function, define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:
+Now that we have a training step function, let's define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:

 ```{code-cell}
 :id: l9mukT0eqmsr
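
A sketch of such a loop, reusing the hypothetical `train_step` and `loss_fun` above and feeding the full training split each step (the committed cell may batch differently):

```python
for i in range(301):  # 300 training epochs
  train_step(model, optimizer, images_train, label_train)
  if i % 50 == 0:
    # Periodically report the loss on the held-out test split.
    loss, _ = loss_fun(model, images_test, label_test)
    print(f"epoch {i}: test loss = {loss:.4f}")
```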
@@ -195,7 +195,7 @@ for i in range(301): # 300 training epochs

 +++ {"id": "3sjOKxLDv8SS"}

-After 300 training epochs, your model should have converged to a target loss of around `0.10`. You can check what this implies for the accuracy of the labels for each image:
+After 300 training epochs, our model should have converged to a target loss of around `0.10`. We can check what this implies for the accuracy of the labels for each image:

 ```{code-cell}
 :id: 6OmW0lVlsvJ1
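
A sketch of that accuracy check; the variable names follow the hunk header below, while the body is an assumption:

```python
import jax.numpy as jnp

label_pred = model(images_test).argmax(axis=1)  # predicted digit per test image

num_matches = jnp.count_nonzero(label_pred == label_test)
num_total = len(label_test)
print(f"{num_matches} labels match out of {num_total}:"
      f" {num_matches / num_total * 100:.2f}% accuracy")
```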
@@ -212,7 +212,7 @@ print(f"{num_matches} labels match out of {num_total}:"
 +++ {"id": "vTKF3-CFwY50"}

 The simple feed-forward network has achieved approximately 98% accuracy on the test set.
-You can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):
+We can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):

 ```{code-cell}
 :id: uinijfm-qXsP
@@ -230,7 +230,7 @@ for i, ax in enumerate(axes.flat):

 +++ {"id": "x7IIiVymuTRa"}

-In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Optax here. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are features in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website.
+In this tutorial, we have just scraped the surface with JAX, Flax NNX, and Optax here. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are features in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website.

 +++ {"id": "5ZfGvXAiy2yr"}

@@ -239,18 +239,18 @@ In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Opta
 The Flax NNX neural network API demonstrated above takes advantage of a number of [key JAX features](https://jax.readthedocs.io/en/latest/key-concepts.html), designed into the library from the ground up. In particular:

 - **JAX provides a familiar NumPy-like API for array computing.**
-This means that when processing data and outputs, you can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).
+This means that when processing data and outputs, we can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).

 - **JAX provides just-in-time (JIT) compilation.**
-This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).
+This means that we can implement our code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping the code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).

 - **JAX provides automatic differentiation (autodiff).**
 This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) [transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).

 - **JAX provides automatic vectorization.**
-While you didn't get to use this directly in the code before, but under the hood flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.
+While we didn't get to use this directly in the code before, but under the hood flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.

-You will learn more about these features through brief examples in the following sections.
+We will learn more about these features through brief examples in the following sections.

 +++ {"id": "ZjneGfjy2Ef1"}

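As a tiny illustration of the NumPy-mirroring point in the first bullet above (editor's example, not part of the committed file):

```python
import numpy as np
import jax.numpy as jnp

x = np.array([0, 1, 0, 2, 3])
print(np.count_nonzero(x))                # 3, a plain NumPy integer
print(jnp.count_nonzero(jnp.asarray(x)))  # the same count, returned as a jax.Array
```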
@@ -259,7 +259,7 @@ You will learn more about these features through brief examples in the following
 The foundational array computing package in Python is NumPy, and [JAX provides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jax-vs-numpy) [a matching API](https://jax.readthedocs.io/en/latest/quickstart.html#jax-as-numpy) via the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) subpackage.
 Additionally, [JAX arrays](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) ([`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)) behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics.

-In the previous example, you used Flax's built-in `flax.nnx.selu` implementation. You can also implement SeLU using JAX's NumPy API as follows:
+In the previous example, we used Flax's built-in `flax.nnx.selu` implementation. We can also implement SeLU using JAX's NumPy API as follows:

 ```{code-cell}
 :id: 2u2femxe2EzA
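
The cell opened above presumably holds that implementation. A sketch of SELU written against the `jax.numpy` API, with the usual constants rounded:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, expressed with NumPy-style jnp ops."""
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))
```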
@@ -283,9 +283,9 @@ Despite the broad similarities, be aware that JAX does have some well-motivated
 ### Just-in-time compilation

 As mentioned before, JAX is built on the [XLA](https://openxla.org/xla) compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the [`jax.jit` transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).
-In the neural network example above, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training.
+In the neural network example above, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training.

-Returning to the previously defined `selu` function in JAX, you can create a `jax.jit`-compiled version this way:
+Returning to the previously defined `selu` function in JAX, we can create a `jax.jit`-compiled version this way:

 ```{code-cell}
 :id: -Chp8yCjQaFY
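
A sketch of that step, reusing the `selu` sketch above:

```python
import jax
import jax.numpy as jnp

selu_jit = jax.jit(selu)  # compiled on first call, cached for later calls

x = jnp.arange(1_000_000.0)
print(jnp.allclose(selu(x), selu_jit(x)))  # results match
```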
@@ -308,7 +308,7 @@ jnp.allclose(selu(x), selu_jit(x)) # results match

 +++ {"id": "WWwD0NmzRLP8"}

-You can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which you need to use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):
+We can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which we need to use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):

 ```{code-cell}
 :id: SzU_0NU5Jq_W
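
A sketch of such a timing comparison, reusing `x`, `selu`, and `selu_jit` from the sketches above; `%timeit` is IPython/Jupyter syntax, so this runs in a notebook cell rather than plain Python:

```python
# Block on the result so asynchronous dispatch doesn't skew the timing.
%timeit jax.block_until_ready(selu(x))      # eager, op-by-op execution
%timeit jax.block_until_ready(selu_jit(x))  # JIT-compiled version
```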
@@ -333,7 +333,7 @@ JAX's documentation has more discussion of JIT compilation at [Just-in-time comp

 ### Automatic differentiation (autodiff)

-For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.
+For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.

 Here's how to compute the gradient of a function with `jax.grad`:

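A sketch of that gradient call (the next hunk header shows the call itself), using a scalar input so the output of the `selu` sketch above is scalar, as `jax.grad` requires:

```python
import jax
import jax.numpy as jnp

x = jnp.float32(-1.0)
print(jax.grad(selu)(x))  # derivative of selu at x = -1.0
```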
@@ -347,7 +347,7 @@ jax.grad(selu)(x)

 +++ {"id": "1P-UEh9VO94k"}

-You can briefly check with a finite-difference approximation that this is giving the expected value:
+We can briefly check with a finite-difference approximation that this is giving the expected value:

 ```{code-cell}
 :id: 1gOc4FyzPDUC
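
A sketch of a centered finite-difference check at the same point as the gradient sketch above:

```python
eps = 1e-3
grad_fd = (selu(x + eps) - selu(x - eps)) / (2 * eps)  # centered difference
grad_ad = jax.grad(selu)(x)                            # autodiff value
print(grad_fd, grad_ad)  # the two should agree to several decimal places
```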
@@ -365,7 +365,7 @@ Importantly, the automatic differentiation approach is both more accurate and ef

 ### Automatic vectorization

-In the training loop example earlier, you defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.
+In the training loop example earlier, we defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.

 Consider a simple loss function that looks like this:

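The definition itself sits in the file; only its signature is visible in the next hunk header. A hedged guess at its shape, assuming a simple sum-of-squares penalty against a reference vector `x0` (the committed body may differ):

```python
import jax
import jax.numpy as jnp

def loss(x: jax.Array, x0: jax.Array):
  # Assumed form: total squared deviation of x from the reference point x0.
  return jnp.sum((x - x0) ** 2)
```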
@@ -378,7 +378,7 @@ def loss(x: jax.Array, x0: jax.Array):

 +++ {"id": "lOg9IWlPddfE"}

-You can evaluate it on a single data vector this way:
+We can evaluate it on a single data vector this way:

 ```{code-cell}
 :id: sYlEtbxedngb
@@ -391,7 +391,7 @@ loss(x, x0)

 +++ {"id": "STit-syzk59F"}

-But if you attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:
+But if we attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:

 ```{code-cell}
 :id: 1LFQX3zGlCil
@@ -403,12 +403,12 @@ loss(batched_x, x0) # wrong!

 +++ {"id": "Qc3Kwe2HlhpA"}

-The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this:
+The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways we can address this:

-1. Re-write your loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.
-2. Naively loop over unbatched calls to your original function; however, this is easy to code, but can be slow because it doesn't take advantage of vectorized compute.
+1. Re-write our loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.
+2. Naively loop over unbatched calls to our original function. However, this is easy to code, but can be slow because it doesn't take advantage of vectorized compute.

-The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2:
+The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms our original function into a batch-aware version, so we get the speed of option 1 with the ease of option 2:

 ```{code-cell}
 :id: Y2Sa458OoRVL
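
A sketch of that third option, applied to the hypothetical `loss` above: vectorize over the leading axis of `x` while leaving `x0` un-batched:

```python
import jax
import jax.numpy as jnp

x0 = jnp.ones(4)
batched_x = jnp.arange(16.0).reshape(4, 4)  # a batch of 4 vectors (shapes illustrative)

# in_axes=(0, None): map over axis 0 of the first argument, broadcast the second.
loss_batched = jax.vmap(loss, in_axes=(0, None))
print(loss_batched(batched_x, x0))  # four per-example losses, shape (4,)
```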
