Commit 47125fe

NLP: JAX time series classification (#105)

1 parent 4ec7510 commit 47125fe

File tree

4 files changed: +1931 -0 lines changed

docs/JAX_time_series_classification.ipynb
Lines changed: 1539 additions & 0 deletions (large diffs are not rendered by default)

Lines changed: 389 additions & 0 deletions

@@ -0,0 +1,389 @@
---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.15.2
kernelspec:
  display_name: jax-env
  language: python
  name: python3
---

# Time series classification with JAX

In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.
We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),
which contains measurements of engine noise captured by a motor sensor.

We need to assess if an engine is malfunctioning based on the recorded noises it generates.
Each sample consists of noise measurements across time, together with a "yes/no" label,
so this is a binary classification problem.

Although convolutional models are mainly associated with image processing, they are also useful
for time series data because they can extract temporal structure.

+++

## Tools overview and setup

Here's a list of the key packages from the JAX AI stack required for this tutorial:

- [JAX](https://github.com/jax-ml/jax) for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.

We'll start by installing and importing these packages.

```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```

```{code-cell} ipython3
import jax
import jax.numpy as jnp
from flax import nnx
import optax

import numpy as np
import matplotlib.pyplot as plt
import grain.python as grain
import tqdm
```

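Optionally, we can confirm the JAX version and the devices (CPU, GPU, or TPU) that JAX will run on. This quick check is not part of the original notebook; it only prints information and has no effect on the rest of the tutorial.

```{code-cell} ipython3
# Optional: report the JAX version and the available accelerator devices.
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())
```
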
## Load the dataset

We load the dataset files into NumPy arrays, add a singleton feature dimension so the arrays match
the input expected by the convolution layers, and change the `-1` labels to `0` (so that the expected
values are `0` and `1`):

```{code-cell} ipython3
def prepare_ucr_dataset() -> tuple:
    root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"

    train_data = np.loadtxt(root_url + "FordA_TRAIN.tsv", delimiter="\t")
    x_train, y_train = train_data[:, 1:], train_data[:, 0].astype(int)

    test_data = np.loadtxt(root_url + "FordA_TEST.tsv", delimiter="\t")
    x_test, y_test = test_data[:, 1:], test_data[:, 0].astype(int)

    x_train = x_train.reshape((*x_train.shape, 1))
    x_test = x_test.reshape((*x_test.shape, 1))

    rng = np.random.RandomState(113)
    indices = rng.permutation(len(x_train))
    x_train = x_train[indices]
    y_train = y_train[indices]

    y_train[y_train == -1] = 0
    y_test[y_test == -1] = 0

    return (x_train, y_train), (x_test, y_test)
```

```{code-cell} ipython3
(x_train, y_train), (x_test, y_test) = prepare_ucr_dataset()
```

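As an optional sanity check (not part of the original notebook), we can inspect the array shapes and confirm that the labels are now `0` and `1`:

```{code-cell} ipython3
# Optional check: print the array shapes and the label distribution.
print("x_train:", x_train.shape, "x_test:", x_test.shape)
print("train label counts:", np.unique(y_train, return_counts=True))
print("test label counts:", np.unique(y_test, return_counts=True))
```
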
Let's visualize one example from each class.

```{code-cell} ipython3
classes = np.unique(np.concatenate((y_train, y_test), axis=0))
for c in classes:
    c_x_train = x_train[y_train == c]
    plt.plot(c_x_train[0], label="class " + str(c))
plt.legend()
plt.show()
```

### Create a Data Loader using Grain

For handling input data we're going to use Grain, a pure Python package developed for JAX and
Flax models.

Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources
might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`
interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)
for more details.

Our dataset consists of relatively small NumPy arrays, so our `DataSource` is uncomplicated:

```{code-cell} ipython3
class DataSource(grain.RandomAccessDataSource):
    def __init__(self, x, y):
        self._x = x
        self._y = y

    def __getitem__(self, idx):
        return {"measurement": self._x[idx], "label": self._y[idx]}

    def __len__(self):
        return len(self._x)
```

```{code-cell} ipython3
train_source = DataSource(x_train, y_train)
test_source = DataSource(x_test, y_test)
```

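Because the source implements random access, we can index it directly. The following quick check is illustrative and not part of the original notebook:

```{code-cell} ipython3
# Optional check: the source supports len() and integer indexing.
print(len(train_source), len(test_source))
sample = train_source[0]
print(sample["measurement"].shape, sample["label"])
```
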
Samplers determine the order in which records are processed, and we'll use the
[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)
recommended by Grain.

Finally, we'll create `DataLoader`s that handle the orchestration of loading.
We'll leverage Grain's multiprocessing capabilities to scale processing up to 4 workers.

```{code-cell} ipython3
seed = 12
train_batch_size = 128
test_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_source),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,                      # Iterate over the dataset for one epoch
)

test_sampler = grain.IndexSampler(
    len(test_source),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,                      # Iterate over the dataset for one epoch
)


train_loader = grain.DataLoader(
    data_source=train_source,
    sampler=train_sampler,  # Sampler to determine how to access the data
    worker_count=4,         # Number of child processes used to parallelize the transformations
    worker_buffer_size=2,   # Count of output batches to produce in advance per worker
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

test_loader = grain.DataLoader(
    data_source=test_source,
    sampler=test_sampler,  # Sampler to determine how to access the data
    worker_count=4,        # Number of child processes used to parallelize the transformations
    worker_buffer_size=2,  # Count of output batches to produce in advance per worker
    operations=[
        grain.Batch(test_batch_size),
    ]
)
```

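If you want to verify the loader output before training, you can pull a single batch. Each iteration over a Grain `DataLoader` starts a fresh pass over the sampler (the training loop below relies on this), so this peek, which is not part of the original notebook, shouldn't interfere with training later on:

```{code-cell} ipython3
# Optional check: fetch one batch and inspect the batched shapes.
peek_batch = next(iter(train_loader))
print(peek_batch["measurement"].shape, peek_batch["label"].shape)
```
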
## Define the Model

Let's now construct the Convolutional Neural Network with Flax by subclassing `nnx.Module`.
You can learn more about the [Flax NNX module system in the Flax documentation](https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-nnx-module-system).

The model has three convolutional layers followed by a dense layer, with a ReLU activation after
each convolution and global average pooling before the dense layer. The final dense layer produces
two class scores (logits); the softmax that turns them into probabilities is applied inside the
loss function during training.

```{code-cell} ipython3
class MyModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.conv_1 = nnx.Conv(
            in_features=1, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_1 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.conv_2 = nnx.Conv(
            in_features=64, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_2 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.conv_3 = nnx.Conv(
            in_features=64, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_3 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.dense_1 = nnx.Linear(in_features=64, out_features=2, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.conv_1(x)
        x = self.layer_norm_1(x)
        x = jax.nn.relu(x)

        x = self.conv_2(x)
        x = self.layer_norm_2(x)
        x = jax.nn.relu(x)

        x = self.conv_3(x)
        x = self.layer_norm_3(x)
        x = jax.nn.relu(x)

        x = jnp.mean(x, axis=(1,))  # global average pooling over the time dimension
        x = self.dense_1(x)
        # Return logits; the softmax is applied inside the cross-entropy loss,
        # so applying it here as well would amount to a double softmax.
        return x
```

```{code-cell} ipython3
model = MyModel(rngs=nnx.Rngs(0))
nnx.display(model)
```

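As an optional shape check (not in the original notebook), we can run a couple of training samples through the freshly initialized model and confirm that it returns two class scores per sample:

```{code-cell} ipython3
# Optional check: forward pass on two samples; expected output shape is (2, 2).
dummy_out = model(jnp.asarray(x_train[:2]))
print(dummy_out.shape)
```
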
## Train the Model

To train our Flax model we need to construct an `nnx.Optimizer` object with our model and
a selected optimization algorithm. The optimizer object manages the model’s parameters and
applies gradients during training.

We're going to use the [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#adam),
a popular choice for Deep Learning models. We'll use it through
[Optax](https://optax.readthedocs.io/en/latest/index.html), an optimization library developed for JAX.

```{code-cell} ipython3
num_epochs = 300
learning_rate = 0.0005
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))  # momentum is passed as Adam's b1
```

We'll define a loss and logits computation function using Optax's
[`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels).

```{code-cell} ipython3
def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array):
    logits = model(batch_tokens)

    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits
```

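To see what this loss computes, here is a small illustrative example (not part of the original notebook): the function applies a log-softmax to each row of logits and returns the negative log-probability assigned to that row's integer label.

```{code-cell} ipython3
# Toy example: two samples, two classes each; the first is confidently correct,
# the second is mildly correct, so it gets the larger loss.
toy_logits = jnp.array([[2.0, -1.0], [0.5, 1.5]])
toy_labels = jnp.array([0, 1])  # correct class indices
print(optax.softmax_cross_entropy_with_integer_labels(logits=toy_logits, labels=toy_labels))
```
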
We'll now define the training and evaluation step functions. We'll track the loss returned by the
training step, and use the loss and logits from the evaluation step to compute the accuracy metrics.

For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update
the model’s parameters using our optimizer.

Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla)
for performant execution across different hardware accelerators like GPUs and TPUs.

```{code-cell} ipython3
@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, jax.Array]
):
    batch_tokens = jnp.array(batch["measurement"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch_tokens, labels)

    optimizer.update(grads)  # In-place updates.

    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, jax.Array], eval_metrics: nnx.MultiMetric
):
    batch_tokens = jnp.array(batch["measurement"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, batch_tokens, labels)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )
```

```{code-cell} ipython3
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "test_loss": [],
    "test_accuracy": [],
}
```

We can now train the CNN model. We'll evaluate the model’s performance on the test set
after each epoch, and print the metrics (total loss and accuracy) every ten epochs.

```{code-cell} ipython3
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
train_total_steps = len(x_train) // train_batch_size

def train_one_epoch(epoch: int):
    model.train()
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch}/{num_epochs}, ",
        total=train_total_steps,
        bar_format=bar_format,
        miniters=10,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            train_metrics_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)

def evaluate_model(epoch: int):
    # Compute the metrics on the test set after each training epoch.
    model.eval()

    eval_metrics.reset()  # Reset the eval metrics
    for test_batch in test_loader:
        eval_step(model, test_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'test_{metric}'].append(value)

    if epoch % 10 == 0:
        print(f"[test] epoch: {epoch + 1}/{num_epochs}")
        print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
        print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")
```

```{code-cell} ipython3
%%time
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
```

Finally, let's visualize the loss and accuracy with Matplotlib.

```{code-cell} ipython3
plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
```

```{code-cell} ipython3
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].set_title("Loss value on test set")
axs[0].plot(eval_metrics_history["test_loss"])
axs[1].set_title("Accuracy on test set")
axs[1].plot(eval_metrics_history["test_accuracy"])
```

Our model reached almost 90% accuracy on the test set after 300 epochs, but it's worth noting
that the loss curve isn't completely flat yet. We could continue training until it flattens,
but we would also need to pay attention to the validation (here, test) accuracy to spot when
the model starts overfitting.

For early stopping and selecting the best model, you can check out [Orbax](https://github.com/google/orbax),
a library which provides checkpointing and persistence utilities.

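As a rough illustration of that idea, here is a minimal sketch, not part of the original tutorial, of how a training-from-scratch loop with early stopping might look: it keeps the parameters with the best test accuracy, restores them at the end, and saves them with Orbax. It assumes the `orbax-checkpoint` package is installed; the `patience` value and the checkpoint directory are arbitrary, illustrative choices.

```{code-cell} ipython3
# Illustrative early-stopping sketch, reusing the model, loaders, and helper
# functions defined above (it assumes a freshly initialized model).
import os
import orbax.checkpoint as ocp

best_accuracy, best_state = 0.0, nnx.state(model)
patience, epochs_without_improvement = 20, 0  # arbitrary patience for the example

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
    accuracy = eval_metrics_history["test_accuracy"][-1]
    if accuracy > best_accuracy:
        # New best model: snapshot its state and reset the patience counter.
        best_accuracy, best_state = accuracy, nnx.state(model)
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

# Restore the best parameters and persist them with Orbax (path is illustrative).
nnx.update(model, best_state)
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(os.path.join(os.getcwd(), "checkpoints", "best_model"), best_state)
```
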
