
Commit 1998729

Adding tensorboard tracking/visualization and nnx display example (#118)
* adding tensorboard tracking/visualization and nnx display example
* Address Victor's initial feedback for information and links
* Addressing feedback - correcting links and final code cell
1 parent 6364815 commit 1998729

10 files changed: +721 -0 lines changed

docs/JAX_visualizing_models_metrics.ipynb

Lines changed: 472 additions & 0 deletions

Large diffs are not rendered by default.

docs/JAX_visualizing_models_metrics.md

Lines changed: 246 additions & 0 deletions

@@ -0,0 +1,246 @@

---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.15.2
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

# JAX and TensorBoard / NNX Display

+++

To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that tutorial yet and want the primer, start there before returning.

All of the modeling and training code is the same here; what we have added are the TensorBoard connections and the discussion around them.

```{code-cell} ipython3
import io
from datetime import datetime

import tensorflow as tf
```

```{code-cell} ipython3
:id: hKhPLnNxfOHU
:outputId: ac3508f0-ccc6-409b-c719-99a4b8f94bd6

from sklearn.datasets import load_digits
digits = load_digits()
```

Here we set the location of the TensorBoard summary writer. The organization is somewhat arbitrary, though keeping a separate folder for each training run can make later navigation more straightforward.

```{code-cell} ipython3
file_path = "runs/test/" + datetime.now().strftime("%Y%m%d-%H%M%S")
test_summary_writer = tf.summary.create_file_writer(file_path)
```

Pulled from the official TensorBoard examples, this conversion function makes it simple to drop matplotlib figures directly into TensorBoard:

```{code-cell} ipython3
def plot_to_image(figure):
    """Sourced from https://www.tensorflow.org/tensorboard/image_summaries

    Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call."""
    # Save the plot to a PNG in memory.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert the PNG buffer to a TF image.
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension.
    image = tf.expand_dims(image, 0)
    return image
```

Whereas the original example displays the training-data snapshot directly in the notebook, here we stash it in TensorBoard's images. If a given training will be repeated many, many times, it can save space to store the training-data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager: we can step into and out of this kind of context throughout the run without losing track of the same file/folder experiment.

```{code-cell} ipython3
:id: Y8cMntSdfyyT
:outputId: 9343a558-cd8c-473c-c109-aa8015c7ae7e

import matplotlib.pyplot as plt

fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                         subplot_kw={'xticks': [], 'yticks': []},
                         gridspec_kw=dict(hspace=0.1, wspace=0.1))

for i, ax in enumerate(axes.flat):
    ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')
    ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')

with test_summary_writer.as_default():
    tf.summary.image("Training Data", plot_to_image(fig), step=0)
```

After running all of the above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following at the supplied URL:

![image.png](./_static/training_data_example.png)
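
If you'd rather not leave the notebook, TensorBoard also ships Jupyter magics that embed the same UI inline. This is a supplementary sketch, not part of the original notebook, assuming the `tensorboard` package is installed in the notebook environment:

```{code-cell} ipython3
# Load the TensorBoard notebook extension and embed the UI inline,
# pointed at the same run folder used above.
%load_ext tensorboard
%tensorboard --logdir runs/test
```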

```{code-cell} ipython3
:id: 6jrYisoPh6TL

from sklearn.model_selection import train_test_split
splits = train_test_split(digits.images, digits.target, random_state=0)
```

```{code-cell} ipython3
:id: oMRcwKd4hqOo
:outputId: 0ad36290-397b-431d-eba2-ef114daf5ea6

import jax.numpy as jnp
images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
print(f"{images_train.shape=} {label_train.shape=}")
print(f"{images_test.shape=} {label_test.shape=}")
```

```{code-cell} ipython3
:id: U77VMQwRjTfH
:outputId: 345fed7a-4455-4036-85ed-57e673a4de01

from flax import nnx

class SimpleNN(nnx.Module):

    def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,
                 *, rngs: nnx.Rngs):
        self.n_features = n_features
        self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
        self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
        self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)

    def __call__(self, x):
        x = x.reshape(x.shape[0], self.n_features)  # Flatten images.
        x = nnx.selu(self.layer1(x))
        x = nnx.selu(self.layer2(x))
        x = self.layer3(x)
        return x

model = SimpleNN(rngs=nnx.Rngs(0))

nnx.display(model)  # Interactive display if penzai is installed.
```

We've now created the basic model. The above cell renders an interactive view of the model which, when fully expanded, should look something like this:

![image.png](./_static/nnx_display_example.png)

+++
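
As a quick sanity check on the structure shown above, you can flatten the model's parameters and count them. This snippet is not part of the original notebook; it's a minimal sketch that assumes the state returned by `nnx.state` flattens to its underlying arrays via JAX's pytree utilities:

```{code-cell} ipython3
import jax

# Collect only the trainable parameters from the model and sum the
# element counts of every leaf array.
params = nnx.state(model, nnx.Param)
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"SimpleNN has {n_params:,} parameters")
```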

In order to track loss across our training run, we've moved the loss-function call inside the training step:

```{code-cell} ipython3
:id: QwRvFPkYl5b2

import jax
import optax

optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(
    model: nnx.Module,
    data: jax.Array,
    labels: jax.Array):
    logits = model(data)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits

@nnx.jit  # JIT-compile the function.
def train_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    data: jax.Array,
    labels: jax.Array):
    loss_gradient = nnx.grad(loss_fun, has_aux=True)  # Gradient transform!
    grads, logits = loss_gradient(model, data, labels)
    optimizer.update(grads)  # In-place update.

    # Calculate the loss on the test set so it can be tracked each step.
    loss, _ = loss_fun(model, images_test, label_test)
    return loss
```

Now the metrics that were previously computed once at the end of training are collected throughout the `for` loop, as you would in an eval stage.

With the summary-writer context in place, we write out the `Loss` scalar every epoch, test the model accuracy every 10 epochs, and stash an accuracy test sheet every 500 epochs. Any custom metric can be added this way through the `tf.summary` API.
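
For instance, any other tensor statistic can go into the same run. Here's an illustrative sketch, not part of the original notebook, that logs a histogram of the first layer's weights via `tf.summary.histogram`:

```{code-cell} ipython3
# Log the distribution of layer1's weight matrix as a histogram.
# nnx.Linear stores its weights in the `kernel` parameter.
with test_summary_writer.as_default():
    tf.summary.histogram('layer1_kernel', model.layer1.kernel.value, step=0)
```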

```{code-cell} ipython3
:id: l9mukT0eqmsr
:outputId: c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56

max_epoch = 3000
with test_summary_writer.as_default():
    for i in range(max_epoch):
        loss = train_step(model, optimizer, images_train, label_train)
        # Store the training loss each epoch; .item() because the loss
        # returned by train_step() is a JAX array.
        tf.summary.scalar('Loss', loss.item(), step=i+1)
        if ((i+1) % 10 == 0) or i == 0:
            label_pred = model(images_test).argmax(axis=1)
            num_matches = jnp.count_nonzero(label_pred == label_test)
            num_total = len(label_test)
            accuracy = num_matches / num_total
            # Store the evaluated accuracy every 10 epochs.
            tf.summary.scalar('Accuracy', accuracy.item(), step=i+1)
        if ((i+1) % 500 == 0) or i == 0:
            fig, axes = plt.subplots(10, 10, figsize=(6, 6),
                                     subplot_kw={'xticks': [], 'yticks': []},
                                     gridspec_kw=dict(hspace=0.1, wspace=0.1))

            label_pred = model(images_test).argmax(axis=1)

            for j, ax in enumerate(axes.flat):
                ax.imshow(images_test[j], cmap='binary', interpolation='gaussian')
                color = 'green' if label_pred[j] == label_test[j] else 'red'
                ax.text(0.05, 0.05, str(label_pred[j]), transform=ax.transAxes, color=color)
            # Store the accuracy test sheet every 500 epochs - be sure to give each
            # a different name, or they will overwrite the previous output.
            tf.summary.image(f"Step {i+1} Accuracy Testsheet", plot_to_image(fig), step=i+1)
```
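
One small housekeeping step, not in the original notebook but generally worth doing: flush the writer so any buffered events are committed to disk before you go looking for them in the UI.

```{code-cell} ipython3
# Make sure all pending summaries have been written to the run folder.
test_summary_writer.flush()
```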

While training runs, and afterward, the added `Loss` and `Accuracy` scalars are available in the TensorBoard UI under the run folder we dynamically created from the datetime.

The output there should look something like the following:

![image.png](./_static/loss_acc_example.png)

+++

Since we've stored the example test sheet every 500 epochs, it's easy to go back and step through the progress. Because each training step uses all of the training data, steps and epochs are essentially the same here.

At step 1 we see poor accuracy, as you would expect:

![image.png](./_static/testsheet_start_example.png)

By step 500 the model is essentially done, but we see the bottom-row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful once training routines become automated and a human is potentially only looking when something has gone wrong.

![image.png](./_static/testsheets_500_3000.png)

+++

Finally, it can be useful to use `nnx.display`'s ability to visualize networks and model output. Here we feed the first 35 test images into the model and display the final output vector for each. In the top plot, each row is an individual image's prediction result and each column corresponds to a class, in this case the digits 0-9. Since we take the highest value in a given row as the class prediction (`.argmax(axis=1)`), the final image predictions (bottom plot) simply match the largest value in each row of the upper plot.

```{code-cell} ipython3
nnx.display(model(images_test[:35]))
nnx.display(model(images_test[:35]).argmax(axis=1))
```
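
The rows above are raw logits. If you'd rather read them as normalized class probabilities, a softmax over the class axis gives values that sum to 1 per row; this is a supplementary sketch, not part of the original notebook:

```{code-cell} ipython3
# Convert the logits into per-class probabilities for the same 35 images.
probs = jax.nn.softmax(model(images_test[:35]), axis=1)
nnx.display(probs)
```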

The above cell output gives you an interactive plot that looks like the image below, where we've 'clicked' entry `7` in the bottom plot and hovered over the corresponding value in the top plot.

![image.png](./_static/model_display_example.png)

+++

## Extra Resources

For further information about TensorBoard, see [https://www.tensorflow.org/tensorboard/get_started](https://www.tensorflow.org/tensorboard/get_started).

For more about `nnx.display()`, which calls Treescope under the hood, see [https://treescope.readthedocs.io/en/stable/](https://treescope.readthedocs.io/en/stable/).
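
If you want to try Treescope directly, the call below is a minimal sketch, assuming the `treescope` package is installed and that `treescope.display` is its top-level rendering entry point:

```{code-cell} ipython3
# Hypothetical direct use of Treescope, bypassing nnx.display.
import treescope

treescope.display(model)
```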

docs/_static/loss_acc_example.png (79 KB)

docs/_static/model_display_example.png (49.4 KB)

docs/_static/nnx_display_example.png (183 KB)

docs/_static/testsheet_start_example.png (221 KB)

docs/_static/testsheets_500_3000.png (355 KB)

docs/_static/training_data_example.png (135 KB)

docs/conf.py

Lines changed: 2 additions & 0 deletions

@@ -57,6 +57,7 @@
     'JAX_examples_image_segmentation.md',
     'JAX_Vision_transformer.md',
     'JAX_machine_translation.md',
+    'JAX_visualizing_models_metrics.md',
     'JAX_image_captioning.md',
     'JAX_time_series_classification.md',
     'JAX_transformer_text_classification.md',
@@ -91,6 +92,7 @@
     'JAX_examples_image_segmentation.ipynb',
     'JAX_Vision_transformer.ipynb',
     'JAX_machine_translation.ipynb',
+    'JAX_visualizing_models_metrics.ipynb',
     'JAX_image_captioning.ipynb',
     'JAX_time_series_classification.ipynb',
     'JAX_transformer_text_classification.ipynb',

docs/tutorials.md

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ JAX_basic_text_classification
 JAX_examples_image_segmentation
 JAX_Vision_transformer
 JAX_machine_translation
+JAX_visualizing_models_metrics
 JAX_image_captioning
 JAX_time_series_classification
 JAX_transformer_text_classification
