docs/getting_started_with_jax_for_AI.md

31 additions & 31 deletions
@@ -34,13 +34,13 @@ JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.h
- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.

-Once you've worked through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.
+After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.

+++ {"id": "z7sAr0sderhh"}

-## Example: a simple neural network with flax
+## Example: A simple neural network with Flax

-We'll start with a very quick example of what it looks like to use JAX with the [flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits.
+We'll start with a very quick example of what it looks like to use JAX with the [Flax](https://flax.readthedocs.io) framework to define and train a very simple neural network to recognize hand-written digits.

+++ {"id": "pOlnhK-EioSk"}
@@ -80,8 +80,8 @@ for i, ax in enumerate(axes.flat):
+++ {"id": "Z3l45KgtfUUo"}

-Next, split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before you feed them into the model.
-You’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:
+Next, we split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before we feed them into the model.
+We’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:
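For orientation, here is a minimal sketch of this split-and-convert step using scikit-learn's digits dataset; the variable names and the flattening to 64 features are illustrative rather than the exact code in the tutorial's cell:

```python
import jax.numpy as jnp
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
# Flatten the 8x8 images to 64 features and split into train/test sets.
splits = train_test_split(digits.images.reshape(-1, 64), digits.target, random_state=0)
# Convert each NumPy split into a jax.Array.
images_train, images_test, labels_train, labels_test = map(jnp.asarray, splits)
print(images_train.shape, labels_train.shape)  # e.g. (1347, 64) (1347,)
```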
-You can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers with *scaled exponential linear unit* (SELU) activation function using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):
+We can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers and a *scaled exponential linear unit* (SELU) activation function, using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):
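A minimal sketch of such a model, assuming 64 input features and 10 target classes; the layer sizes and attribute names are illustrative:

```python
from flax import nnx

class SimpleNN(nnx.Module):
    def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):
        # Three Linear layers; the hidden width of 100 is an arbitrary choice here.
        self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
        self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
        self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

    def __call__(self, x):
        x = nnx.selu(self.layer1(x))
        x = nnx.selu(self.layer2(x))
        return self.layer3(x)

model = SimpleNN(rngs=nnx.Rngs(0))
```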
```{code-cell}
:id: U77VMQwRjTfH
@@ -137,10 +137,10 @@ nnx.display(model) # Interactive display if penzai is installed.
### Training the model

-With the `SimpleNN` model created and instantiated, you can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function. Use:
+With the `SimpleNN` model created and instantiated, we can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function (a sketch follows the list below). Use:

- [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels) as the loss, as the output layer will have nodes corresponding to a handwritten integer label.
-- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent.
-- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to set the train state.
+- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent optimizer.
+- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to instantiate the optimizer and set the train state.
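A condensed sketch of how these pieces could fit together; the learning rate is illustrative, and the exact `nnx.Optimizer` and `update` signatures vary slightly between Flax versions:

```python
import optax
from flax import nnx

# Wrap the model and the SGD transformation into a single train state.
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(model, images, labels):
    logits = model(images)
    # Integer labels, so use the integer-label cross-entropy loss.
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

@nnx.jit  # JIT-compile the whole training step
def train_step(model, optimizer, images, labels):
    grads = nnx.grad(loss_fun)(model, images, labels)
    optimizer.update(grads)  # apply the SGD update in place
```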
```{code-cell}
:id: QwRvFPkYl5b2
@@ -178,9 +178,9 @@ Notice here the use of [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/ap
- `jax.jit` is a [Just-In-Time compilation transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation), and will cause the function to be passed to the [XLA](https://openxla.org/xla) compiler for fast repeated execution.
- `jax.grad` is a [gradient transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) that uses JAX's automatic differentiation for fast optimization of large networks.

-You will return to these transformations later in the tutorial.
+We will return to these transformations later in the tutorial.

-Now that you have a training step function, define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:
+Now that we have a training step function, let's define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:
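Assuming the `train_step`, `loss_fun`, and data arrays from the sketches above, such a loop might look like this (the reporting interval is arbitrary):

```python
for i in range(301):  # 300 training epochs
    train_step(model, optimizer, images_train, labels_train)
    if i % 50 == 0:  # periodically report the loss on the test set
        print(f"epoch {i}: test loss = {loss_fun(model, images_test, labels_test):.3f}")
```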
```{code-cell}
:id: l9mukT0eqmsr
@@ -195,7 +195,7 @@ for i in range(301): # 300 training epochs
+++ {"id": "3sjOKxLDv8SS"}

-After 300 training epochs, your model should have converged to a target loss of around `0.10`. You can check what this implies for the accuracy of the labels for each image:
+After 300 training epochs, our model should have converged to a target loss of around `0.10`. We can check what this implies for the accuracy of the labels for each image:
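One way to perform this check, assuming the `model` and test arrays from the sketches above:

```python
import jax.numpy as jnp

label_pred = model(images_test).argmax(axis=1)          # predicted digit per image
num_matches = jnp.count_nonzero(label_pred == labels_test)
num_total = len(labels_test)
print(f"{num_matches} labels match out of {num_total}:"
      f" {num_matches / num_total * 100:.2f}% accuracy")
```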
```{code-cell}
:id: 6OmW0lVlsvJ1
@@ -212,7 +212,7 @@ print(f"{num_matches} labels match out of {num_total}:"
+++ {"id": "vTKF3-CFwY50"}

The simple feed-forward network has achieved approximately 98% accuracy on the test set.
-You can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):
+We can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):
```{code-cell}
:id: uinijfm-qXsP
@@ -230,7 +230,7 @@ for i, ax in enumerate(axes.flat):
+++ {"id": "x7IIiVymuTRa"}

-In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Optax here. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are features in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website.
+In this tutorial, we have only scratched the surface of JAX, Flax NNX, and Optax. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are featured in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website.

+++ {"id": "5ZfGvXAiy2yr"}
@@ -239,18 +239,18 @@ In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Opta
The Flax NNX neural network API demonstrated above takes advantage of a number of [key JAX features](https://jax.readthedocs.io/en/latest/key-concepts.html), designed into the library from the ground up. In particular:

- **JAX provides a familiar NumPy-like API for array computing.**
-This means that when processing data and outputs, you can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).
+This means that when processing data and outputs, we can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).

- **JAX provides just-in-time (JIT) compilation.**
-This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).
+This means that we can implement our code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping the code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).
- **JAX provides automatic differentiation.**
This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) [transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).
- **JAX provides automatic vectorization.**
-While you didn't get to use this directly in the code before, but under the hood flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.
+While we didn't use this directly in the code above, under the hood Flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.

-You will learn more about these features through brief examples in the following sections.
+We will learn more about these features through brief examples in the following sections.

+++ {"id": "ZjneGfjy2Ef1"}
@@ -259,7 +259,7 @@ You will learn more about these features through brief examples in the following
The foundational array computing package in Python is NumPy, and [JAX provides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jax-vs-numpy) [a matching API](https://jax.readthedocs.io/en/latest/quickstart.html#jax-as-numpy) via the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) subpackage.

Additionally, [JAX arrays](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) ([`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)) behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics.

-In the previous example, you used Flax's built-in `flax.nnx.selu` implementation. You can also implement SeLU using JAX's NumPy API as follows:
+In the previous example, we used Flax's built-in `flax.nnx.selu` implementation. We can also implement SELU using JAX's NumPy API as follows:
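For example, a jax.numpy version along these lines (the constants are the usual SELU parameters, rounded):

```python
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Scaled exponential linear unit, written with jax.numpy primitives.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
```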
```{code-cell}
:id: 2u2femxe2EzA
@@ -283,9 +283,9 @@ Despite the broad similarities, be aware that JAX does have some well-motivated
### Just-in-time compilation

As mentioned before, JAX is built on the [XLA](https://openxla.org/xla) compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the [`jax.jit` transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).
-In the neural network example above, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training.
+In the neural network example above, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training.

-Returning to the previously defined `selu` function in JAX, you can create a `jax.jit`-compiled version this way:
+Returning to the previously defined `selu` function in JAX, we can create a `jax.jit`-compiled version this way:
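A sketch of this, assuming the `selu` defined above and an illustrative input array `x`:

```python
import jax
import jax.numpy as jnp

selu_jit = jax.jit(selu)  # compiled on first call, cached afterwards

x = jnp.arange(1_000_000, dtype=jnp.float32)
print(jnp.allclose(selu(x), selu_jit(x)))  # results match
```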
```{code-cell}
:id: -Chp8yCjQaFY
@@ -308,7 +308,7 @@ jnp.allclose(selu(x), selu_jit(x)) # results match
+++ {"id": "WWwD0NmzRLP8"}

-You can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which you need to use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):
+We can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which is needed to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):
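In an IPython or Jupyter cell this might look like the following, reusing `x`, `selu`, and `selu_jit` from the sketch above:

```python
%timeit jax.block_until_ready(selu(x))      # eager version
%timeit jax.block_until_ready(selu_jit(x))  # JIT-compiled version
```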
```{code-cell}
:id: SzU_0NU5Jq_W
@@ -333,7 +333,7 @@ JAX's documentation has more discussion of JIT compilation at [Just-in-time comp
### Automatic differentiation (autodiff)

-For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.
+For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, we used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.

Here's how to compute the gradient of a function with `jax.grad`:
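As a sketch, with the `selu` from above: `jax.grad` expects a scalar-output function, so we evaluate it at a scalar input here (the input value is illustrative):

```python
import jax
import jax.numpy as jnp

x_scalar = jnp.float32(2.0)
dselu_dx = jax.grad(selu)(x_scalar)
print(dselu_dx)  # for x > 0, d/dx selu(x) = lmbda, i.e. about 1.05
```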
@@ -347,7 +347,7 @@ jax.grad(selu)(x)
+++ {"id": "1P-UEh9VO94k"}

-You can briefly check with a finite-difference approximation that this is giving the expected value:
+We can briefly check with a finite-difference approximation that this is giving the expected value:
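For instance, a centered finite difference at the same point as the sketch above:

```python
eps = 1e-3
fd_grad = (selu(x_scalar + eps) - selu(x_scalar - eps)) / (2 * eps)
print(fd_grad)  # should closely match jax.grad(selu)(x_scalar)
```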
```{code-cell}
:id: 1gOc4FyzPDUC
@@ -365,7 +365,7 @@ Importantly, the automatic differentiation approach is both more accurate and ef
### Automatic vectorization

-In the training loop example earlier, you defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.
+In the training loop example earlier, we defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.

Consider a simple loss function that looks like this:
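As a stand-in for that definition (purely illustrative, not the tutorial's own cell), a single-example loss might be:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # x has shape [n_features]; this computes the loss for ONE example.
    w, b = params
    pred = jnp.dot(x, w) + b
    return (pred - y) ** 2
```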
-The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this:
+The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways we can address this:

-1. Re-write your loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.
-2. Naively loop over unbatched calls to your original function; however, this is easy to code, but can be slow because it doesn't take advantage of vectorized compute.
+1. Re-write our loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.
+2. Naively loop over unbatched calls to our original function. This is easy to code, but can be slow because it doesn't take advantage of vectorized compute.
-The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2:
+The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms our original function into a batch-aware version, so we get the speed of option 1 with the ease of option 2:
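Applied to the illustrative single-example `loss` above, a vectorized version could be built like this:

```python
# in_axes=(None, 0, 0): keep params un-mapped, map over the leading (batch)
# axis of x and y; the result is one loss value per example.
batched_loss = jax.vmap(loss, in_axes=(None, 0, 0))

params = (jnp.ones(4), 0.0)   # w of shape [4], scalar b
x_batch = jnp.ones((8, 4))    # [n_samples, n_features]
y_batch = jnp.zeros(8)
print(batched_loss(params, x_batch, y_batch).shape)  # (8,)
```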