---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.15.2
kernelspec:
  display_name: jax-env
  language: python
  name: python3
---

# Time series classification with JAX

In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.
We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),
which contains measurements of engine noise captured by a motor sensor.

We need to assess whether an engine is malfunctioning based on the noise it generates.
Each sample consists of noise measurements over time, together with a "yes/no" label,
so this is a binary classification problem.

Although convolutional models are mainly associated with image processing, they are also useful
for time series data because they can extract temporal structure.
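
As a toy illustration of that idea (a standalone sketch, separate from the pipeline built below), convolving a signal with a small difference kernel produces a strong response exactly where the signal changes over time, which is the kind of temporal feature our network will learn from data:

```python
import jax.numpy as jnp

signal = jnp.concatenate([jnp.zeros(10), jnp.ones(10)])  # a step change at t = 10
kernel = jnp.array([1.0, -1.0])                          # a simple difference filter

# The response is (near) zero everywhere except around the step.
response = jnp.convolve(signal, kernel, mode="same")
print(response)
```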

+++

## Tools overview and setup

Here's a list of the key packages required for this tutorial, most of which belong to the JAX AI stack:

- [JAX](https://github.com/jax-ml/jax) for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.

We'll start by installing and importing these packages.

```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```

```{code-cell} ipython3
import jax
import jax.numpy as jnp
from flax import nnx
import optax

import numpy as np
import matplotlib.pyplot as plt
import grain.python as grain
import tqdm
```

## Load the dataset

We load the dataset files into NumPy arrays, add a singleton channel dimension so the data matches
the input expected by the convolution layers, and remap the `-1` labels to `0` (so the labels are `0` and `1`):

```{code-cell} ipython3
def prepare_ucr_dataset() -> tuple:
    root_url = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA/"

    train_data = np.loadtxt(root_url + "FordA_TRAIN.tsv", delimiter="\t")
    x_train, y_train = train_data[:, 1:], train_data[:, 0].astype(int)

    test_data = np.loadtxt(root_url + "FordA_TEST.tsv", delimiter="\t")
    x_test, y_test = test_data[:, 1:], test_data[:, 0].astype(int)

    # Add a trailing channel dimension: (num_samples, num_timesteps, 1).
    x_train = x_train.reshape((*x_train.shape, 1))
    x_test = x_test.reshape((*x_test.shape, 1))

    # Shuffle the training set with a fixed seed for reproducibility.
    rng = np.random.RandomState(113)
    indices = rng.permutation(len(x_train))
    x_train = x_train[indices]
    y_train = y_train[indices]

    # Remap the labels from {-1, 1} to {0, 1}.
    y_train[y_train == -1] = 0
    y_test[y_test == -1] = 0

    return (x_train, y_train), (x_test, y_test)
```

```{code-cell} ipython3
(x_train, y_train), (x_test, y_test) = prepare_ucr_dataset()
```
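
Before going further, let's run a quick sanity check on what we loaded (this cell is purely informational): the array shapes and the label distribution in each split.

```{code-cell} ipython3
print("x_train:", x_train.shape, "y_train:", y_train.shape)
print("x_test: ", x_test.shape, "y_test: ", y_test.shape)
print("Train label counts:", dict(zip(*np.unique(y_train, return_counts=True))))
print("Test label counts: ", dict(zip(*np.unique(y_test, return_counts=True))))
```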

Let's visualize example samples from each class.

```{code-cell} ipython3
classes = np.unique(np.concatenate((y_train, y_test), axis=0))
for c in classes:
    c_x_train = x_train[y_train == c]
    plt.plot(c_x_train[0], label="class " + str(c))
plt.legend()
plt.show()
```

### Create a Data Loader using Grain

For handling input data, we're going to use Grain, a pure Python package developed for JAX and
Flax models.

Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources
might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`
interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)
for more details.

Our dataset consists of relatively small NumPy arrays, so our `DataSource` is uncomplicated:

```{code-cell} ipython3
class DataSource(grain.RandomAccessDataSource):
    def __init__(self, x, y):
        self._x = x
        self._y = y

    def __getitem__(self, idx):
        return {"measurement": self._x[idx], "label": self._y[idx]}

    def __len__(self):
        return len(self._x)
```

```{code-cell} ipython3
train_source = DataSource(x_train, y_train)
test_source = DataSource(x_test, y_test)
```
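
As a quick check (purely illustrative), indexing a source returns a single record as a dictionary:

```{code-cell} ipython3
record = train_source[0]
print(record["measurement"].shape, record["label"])
```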

Samplers determine the order in which records are processed, and we'll use the
[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)
recommended by Grain.

Finally, we'll create `DataLoader`s that handle the orchestration of loading.
We'll leverage Grain's multiprocessing capabilities to scale processing up to 4 workers.

```{code-cell} ipython3
seed = 12
train_batch_size = 128
test_batch_size = 2 * train_batch_size

train_sampler = grain.IndexSampler(
    len(train_source),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,  # Iterate over the dataset for one epoch
)

test_sampler = grain.IndexSampler(
    len(test_source),
    shuffle=False,
    seed=seed,
    shard_options=grain.NoSharding(),  # No sharding since this is a single-device setup
    num_epochs=1,  # Iterate over the dataset for one epoch
)


train_loader = grain.DataLoader(
    data_source=train_source,
    sampler=train_sampler,  # Sampler to determine how to access the data
    worker_count=4,  # Number of child processes the transformations are parallelized across
    worker_buffer_size=2,  # Count of output batches to produce in advance per worker
    operations=[
        grain.Batch(train_batch_size, drop_remainder=True),
    ]
)

test_loader = grain.DataLoader(
    data_source=test_source,
    sampler=test_sampler,  # Sampler to determine how to access the data
    worker_count=4,  # Number of child processes the transformations are parallelized across
    worker_buffer_size=2,  # Count of output batches to produce in advance per worker
    operations=[
        grain.Batch(test_batch_size),
    ]
)
```
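
To confirm the loaders produce what we expect, we can pull a single batch (this spins up the worker processes, so it's only a sanity check and can be skipped):

```{code-cell} ipython3
batch = next(iter(test_loader))
print(batch["measurement"].shape, batch["label"].shape)
```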

## Define the Model

Let's now construct the Convolutional Neural Network with Flax by subclassing `nnx.Module`.
You can learn more about the [Flax NNX module system in the Flax documentation](https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-nnx-module-system).

The model has three convolutional layers, each followed by layer normalization and a ReLU activation,
then global average pooling over time and a final dense layer that produces one logit per class.
We return raw logits rather than probabilities, because the cross-entropy loss we use below expects logits.

```{code-cell} ipython3
class MyModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.conv_1 = nnx.Conv(
            in_features=1, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_1 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.conv_2 = nnx.Conv(
            in_features=64, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_2 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.conv_3 = nnx.Conv(
            in_features=64, out_features=64, kernel_size=3, padding="SAME", rngs=rngs
        )
        self.layer_norm_3 = nnx.LayerNorm(num_features=64, epsilon=0.001, rngs=rngs)

        self.dense_1 = nnx.Linear(in_features=64, out_features=2, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.conv_1(x)
        x = self.layer_norm_1(x)
        x = jax.nn.relu(x)

        x = self.conv_2(x)
        x = self.layer_norm_2(x)
        x = jax.nn.relu(x)

        x = self.conv_3(x)
        x = self.layer_norm_3(x)
        x = jax.nn.relu(x)

        x = jnp.mean(x, axis=(1,))  # Global average pooling over the time dimension.
        x = self.dense_1(x)  # Raw logits; the loss function below applies the softmax.
        return x
```

```{code-cell} ipython3
model = MyModel(rngs=nnx.Rngs(0))
nnx.display(model)
```
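
Before training, we can run a dummy batch through the model to check that the shapes line up (optional): the output should have one row per sample and two logits per row.

```{code-cell} ipython3
dummy_batch = jnp.ones((8, x_train.shape[1], 1))
print(model(dummy_batch).shape)  # Expected: (8, 2)
```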

## Train the Model

To train our Flax model we need to construct an `nnx.Optimizer` object with our model and
a selected optimization algorithm. The optimizer object manages the model’s parameters and
applies gradients during training.

We're going to use the [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#adam),
a popular choice for Deep Learning models. We'll use it through
[Optax](https://optax.readthedocs.io/en/latest/index.html), an optimization library developed for JAX.

```{code-cell} ipython3
num_epochs = 300
learning_rate = 0.0005
momentum = 0.9

# Note: the second positional argument of `optax.adam` is `b1`, the exponential decay
# rate of the first moment estimate (Adam's momentum-like term).
optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```
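
The constant learning rate above is a reasonable default. If you want to experiment, Optax also provides learning-rate schedules; for example, a cosine-decay schedule could be swapped in as sketched here (a hypothetical variant, not used in the rest of this tutorial):

```{code-cell} ipython3
total_train_steps = num_epochs * (len(x_train) // train_batch_size)
lr_schedule = optax.cosine_decay_schedule(init_value=learning_rate, decay_steps=total_train_steps)
# To use it, you would construct the optimizer as:
# optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))
```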

We'll define a loss and logits computation function using Optax's
[`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels),
which takes the model's raw logits together with integer class labels.

```{code-cell} ipython3
def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array):
    logits = model(batch_tokens)

    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits
```
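
To make the loss concrete, here's a toy example (the numbers are made up): for each sample the loss is the negative log of the softmax probability assigned to the true class, so a confident correct prediction gives a small loss and an uncertain one a larger loss.

```{code-cell} ipython3
toy_logits = jnp.array([[3.0, -1.0],   # confidently predicts class 0
                        [0.2, 0.4]])   # nearly undecided
toy_labels = jnp.array([0, 1])
print(optax.softmax_cross_entropy_with_integer_labels(toy_logits, toy_labels))
```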

We'll now define the training and evaluation step functions. The loss and logits from both
functions will be used for calculating accuracy metrics.

For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update
the model’s parameters using our optimizer.

Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla)
for performant execution across different hardware accelerators like GPUs and TPUs.

```{code-cell} ipython3
@nnx.jit
def train_step(
    model: nnx.Module, optimizer: nnx.Optimizer, batch: dict[str, jax.Array]
):
    batch_tokens = jnp.array(batch["measurement"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)

    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch_tokens, labels)

    optimizer.update(grads)  # In-place updates.

    return loss

@nnx.jit
def eval_step(
    model: nnx.Module, batch: dict[str, jax.Array], eval_metrics: nnx.MultiMetric
):
    batch_tokens = jnp.array(batch["measurement"])
    labels = jnp.array(batch["label"], dtype=jnp.int32)
    loss, logits = compute_losses_and_logits(model, batch_tokens, labels)

    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels,
    )
```

```{code-cell} ipython3
eval_metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average('loss'),
    accuracy=nnx.metrics.Accuracy(),
)

train_metrics_history = {
    "train_loss": [],
}

eval_metrics_history = {
    "test_loss": [],
    "test_accuracy": [],
}
```

We can now train the CNN model. We'll evaluate the model’s performance on the test set
after each epoch, and print the metrics: total loss and accuracy.

```{code-cell} ipython3
bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
train_total_steps = len(x_train) // train_batch_size

def train_one_epoch(epoch: int):
    model.train()  # Switch the model to training mode.
    with tqdm.tqdm(
        desc=f"[train] epoch: {epoch}/{num_epochs}, ",
        total=train_total_steps,
        bar_format=bar_format,
        miniters=10,
        leave=True,
    ) as pbar:
        for batch in train_loader:
            loss = train_step(model, optimizer, batch)
            train_metrics_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)

def evaluate_model(epoch: int):
    # Compute the metrics on the test set after each training epoch.
    model.eval()  # Switch the model to evaluation mode.

    eval_metrics.reset()  # Reset the eval metrics
    for test_batch in test_loader:
        eval_step(model, test_batch, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        eval_metrics_history[f'test_{metric}'].append(value)

    if epoch % 10 == 0:
        print(f"[test] epoch: {epoch + 1}/{num_epochs}")
        print(f"- total loss: {eval_metrics_history['test_loss'][-1]:0.4f}")
        print(f"- Accuracy: {eval_metrics_history['test_accuracy'][-1]:0.4f}")
```

```{code-cell} ipython3
%%time
for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
```

Finally, let's visualize the loss and accuracy with Matplotlib.

```{code-cell} ipython3
plt.plot(train_metrics_history["train_loss"], label="Loss value during the training")
plt.legend()
```

```{code-cell} ipython3
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].set_title("Loss value on test set")
axs[0].plot(eval_metrics_history["test_loss"])
axs[1].set_title("Accuracy on test set")
axs[1].plot(eval_metrics_history["test_accuracy"])
```

Our model reached almost 90% accuracy on the test set after 300 epochs, but it's worth noting
that the loss curve hasn't completely flattened yet. We could keep training until it does,
but we also need to watch the validation accuracy to spot when the model starts overfitting.

For early stopping and selecting the best model, you can check out [Orbax](https://github.com/google/orbax),
a library which provides checkpointing and persistence utilities. A rough sketch of the
early-stopping idea is shown below.
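
For illustration, here is one minimal way to structure early stopping using only the pieces defined above: track the best test accuracy and stop once it hasn't improved for `patience` epochs (the value of `patience` is an arbitrary choice, and this block is a sketch that is intentionally not executed as part of the notebook).

```python
patience = 20          # epochs to wait for an improvement before stopping
best_accuracy = 0.0
epochs_since_best = 0

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)

    accuracy = eval_metrics_history["test_accuracy"][-1]
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        epochs_since_best = 0
        # This is where you would checkpoint the current (best) model, e.g. with Orbax.
    else:
        epochs_since_best += 1
        if epochs_since_best >= patience:
            print(f"Stopping early at epoch {epoch}: no improvement in {patience} epochs.")
            break
```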