|
6 | 6 | "id": "PUFGZggH49zp"
|
7 | 7 | },
|
8 | 8 | "source": [
|
9 |  | - "## Introduction to Data Loaders for Multi-Device Training with JAX" |
| 9 | + "# Introduction to Data Loaders for Multi-Device Training with JAX" |
10 | 10 | ]
|
11 | 11 | },
|
12 | 12 | {
|
|
15 | 15 | "id": "3ia4PKEV5Dr8"
|
16 | 16 | },
|
17 | 17 | "source": [
|
| 18 | + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_for_multi_device_setups_with_jax.ipynb)\n", |
| 19 | + "\n", |
18 | 20 | "This tutorial explores various data loading strategies for **JAX** in **multi-device distributed** environments, leveraging [**TPUs**](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#what-is-a-tpu). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n",
|
19 | 21 | "* [**PyTorch DataLoader**](https://github.com/pytorch/data)\n",
|
20 | 22 | "* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n",
|
21 | 23 | "* [**Grain**](https://github.com/google/grain)\n",
|
22 | 24 | "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n",
|
23 | 25 | "\n",
|
24 |  | - "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset." |
| 26 | + "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n", |
| 27 | + "\n", |
| 28 | + "Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide introduces optimizations for distributed training across multiple GPUs or TPUs. It focuses on data sharding with `Mesh` and `NamedSharding` to efficiently partition and synchronize data across devices. By leveraging multi-device setups, you'll maximize resource utilization for large datasets in distributed environments." |
25 | 29 | ]
|
26 | 30 | },
|
27 | 31 | {
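
The updated introduction above highlights data sharding with `Mesh` and `NamedSharding`. As a minimal sketch of that pattern in JAX (the `'batch'` axis name, batch size, and image shape are illustrative assumptions, not code taken from the notebook):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1D mesh over all available devices (TPU cores or GPUs).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# Split the leading (batch) dimension of each array across the mesh.
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# A dummy batch of flattened MNIST images; the batch size must be
# divisible by the number of devices.
x = jnp.ones((128, 784))
x_sharded = jax.device_put(x, sharding)
jax.debug.visualize_array_sharding(x_sharded)
```

Jitted computations that consume `x_sharded` are then partitioned across the devices automatically.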
|
|
53 | 57 | "id": "TsFdlkSZKp9S"
|
54 | 58 | },
|
55 | 59 | "source": [
|
56 |  | - "**Checking TPU Availability for JAX**" |
| 60 | + "### Checking TPU Availability for JAX" |
57 | 61 | ]
|
58 | 62 | },
|
59 | 63 | {
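
A quick way to confirm the runtime actually exposes TPUs to JAX; this is a generic sketch, not necessarily the exact check the notebook runs:

```python
import jax

print(jax.devices())          # e.g. a list of TpuDevice objects on a TPU runtime
print(jax.device_count())     # number of accelerator devices visible to JAX
print(jax.default_backend())  # 'tpu', 'gpu', or 'cpu'
```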
|
|
95 | 99 | "id": "qyJ_WTghDnIc"
|
96 | 100 | },
|
97 | 101 | "source": [
|
98 |  | - "**Setting Hyperparameters and Initializing Parameters**\n", |
| 102 | + "### Setting Hyperparameters and Initializing Parameters\n", |
99 | 103 | "\n",
|
100 | 104 | "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
|
101 | 105 | ]
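
A minimal sketch of that setup; the layer sizes, learning rate, batch size, and `data_dir` value here are placeholder assumptions rather than the notebook's actual values:

```python
from jax import random

# Placeholder hyperparameters -- the notebook's actual values may differ.
layer_sizes = [784, 512, 512, 10]
learning_rate = 0.01
num_epochs = 8
batch_size = 128
data_dir = "/tmp/mnist_dataset"

def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialize the weights and biases of a single dense layer."""
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
    """Initialize all layers of a fully-connected network with the given sizes."""
    keys = random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params(layer_sizes, random.PRNGKey(0))
```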
|
|
137 | 141 | "id": "rHLdqeI7D2WZ"
|
138 | 142 | },
|
139 | 143 | "source": [
|
140 |  | - "**Model Prediction with Auto-Batching**\n", |
| 144 | + "### Model Prediction with Auto-Batching\n", |
141 | 145 | "\n",
|
142 | 146 | "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
|
143 | 147 | "\n",
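
A sketch of what a per-example `predict` function with auto-batching typically looks like; it assumes the `params` list-of-(w, b) structure from the initialization sketch above:

```python
import jax.numpy as jnp
from jax import vmap
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0, x)

def predict(params, image):
    """Forward pass for one flattened image, returning log-probabilities."""
    activations = image
    for w, b in params[:-1]:
        activations = relu(jnp.dot(w, activations) + b)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

# Auto-batching: vectorize predict over the leading axis of a batch of images.
batched_predict = vmap(predict, in_axes=(None, 0))
```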
|
|
206 | 210 | "id": "rLqfeORsERek"
|
207 | 211 | },
|
208 | 212 | "source": [
|
209 |  | - "**Utility and Loss Functions**\n", |
| 213 | + "### Utility and Loss Functions\n", |
210 | 214 | "\n",
|
211 | 215 | "You'll now define utility functions for:\n",
|
212 | 216 | "- One-hot encoding: Converts class indices to binary vectors.\n",
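
A sketch of the utilities this cell describes, assuming the `batched_predict` function from the previous sketch:

```python
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Convert integer class indices to one-hot (binary) vectors of length k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def loss(params, images, targets):
    """Mean negative log-likelihood over a batch of one-hot targets."""
    preds = batched_predict(params, images)
    return -jnp.mean(preds * targets)

def accuracy(params, images, targets):
    """Fraction of examples whose argmax prediction matches the target class."""
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)
```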
|
|
373 | 377 | "id": "ec-MHhv6hYsK"
|
374 | 378 | },
|
375 | 379 | "source": [
|
376 |  | - "**Load Dataset with Transformations**\n", |
| 380 | + "### Load Dataset with Transformations\n", |
377 | 381 | "\n",
|
378 | 382 | "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types."
|
379 | 383 | ]
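
A sketch of a flatten-and-cast transformation applied at load time; it assumes torchvision's MNIST dataset and the `data_dir` placeholder from above, which may not match the notebook exactly:

```python
import numpy as np
from torchvision import datasets

def flatten_and_cast(pic):
    """Convert a PIL image to a flat float32 NumPy vector."""
    return np.ravel(np.array(pic, dtype=np.float32))

mnist_dataset = datasets.MNIST(
    data_dir, train=True, download=True, transform=flatten_and_cast
)
```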
|
|
497 | 501 | "id": "kbdsqvPZGrsa"
|
498 | 502 | },
|
499 | 503 | "source": [
|
500 |  | - "**Full Training Dataset for Accuracy Checks**\n", |
| 504 | + "### Full Training Dataset for Accuracy Checks\n", |
501 | 505 | "\n",
|
502 | 506 | "Convert the entire training dataset to JAX arrays."
|
503 | 507 | ]
|
|
520 | 524 | "id": "WXUh0BwvG8Ko"
|
521 | 525 | },
|
522 | 526 | "source": [
|
523 |  | - "**Get Full Test Dataset**\n", |
| 527 | + "### Get Full Test Dataset\n", |
524 | 528 | "\n",
|
525 | 529 | "Load and process the full test dataset."
|
526 | 530 | ]
|
|
569 | 573 | "id": "mfSnfJND6I8G"
|
570 | 574 | },
|
571 | 575 | "source": [
|
572 |  | - "**Training Data Generator**\n", |
| 576 | + "### Training Data Generator\n", |
573 | 577 | "\n",
|
574 | 578 | "Define a generator function using PyTorch's DataLoader for batch training.\n",
|
575 | 579 | "Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n",
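
A sketch of a DataLoader-based generator that yields NumPy batches (JAX consumes NumPy arrays directly, so a custom `collate_fn` avoids torch tensors); the `num_workers` value is only an example:

```python
import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
    """Collate samples into NumPy arrays instead of torch tensors."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    return np.array(batch)

def pytorch_training_generator(dataset):
    return DataLoader(
        dataset,
        batch_size=batch_size,   # defined with the other hyperparameters above
        shuffle=True,
        collate_fn=numpy_collate,
        num_workers=2,           # >0 enables multi-process loading; tune per machine
        drop_last=True,          # keeps batches evenly divisible across devices
    )
```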
|
|
596 | 600 | "id": "Xzt2x9S1HC3T"
|
597 | 601 | },
|
598 | 602 | "source": [
|
599 |  | - "**Training Loop (PyTorch DataLoader)**\n", |
| 603 | + "### Training Loop (PyTorch DataLoader)\n", |
600 | 604 | "\n",
|
601 | 605 | "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters."
|
602 | 606 | ]
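
A sketch of such a loop, combining the generator above with the `Mesh`/`NamedSharding` setup from the introduction; all names (`loss`, `one_hot`, `sharding`, `pytorch_training_generator`, hyperparameters) refer to the earlier sketches, not to the notebook's exact code:

```python
import jax
from jax import grad, jit

@jit
def update(params, x, y):
    """One SGD step; with sharded inputs the computation is partitioned across devices."""
    grads = grad(loss)(params, x, y)
    return [(w - learning_rate * dw, b - learning_rate * db)
            for (w, b), (dw, db) in zip(params, grads)]

for epoch in range(num_epochs):
    for x, y in pytorch_training_generator(mnist_dataset):
        y = one_hot(y, 10)
        # Split the leading (batch) axis of both arrays across the device mesh.
        x, y = jax.device_put((x, y), sharding)
        params = update(params, x, y)
```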
|
|
801 | 805 | "id": "F6OlzaDqwe4p"
|
802 | 806 | },
|
803 | 807 | "source": [
|
804 |  | - "**Fetch Full Dataset for Evaluation**\n", |
| 808 | + "### Fetch Full Dataset for Evaluation\n", |
805 | 809 | "\n",
|
806 | 810 | "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation."
|
807 | 811 | ]
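
A sketch of loading full splits with TFDS: `batch_size=-1` pulls each split into memory in one shot, and `tfds.as_numpy` converts the result to NumPy. The normalization and the `one_hot` helper are assumptions carried over from the earlier sketches:

```python
import jax.numpy as jnp
import tensorflow_datasets as tfds

mnist_data = tfds.as_numpy(tfds.load("mnist", batch_size=-1, data_dir=data_dir))
train_data, test_data = mnist_data["train"], mnist_data["test"]

train_images = jnp.asarray(train_data["image"].reshape(-1, 784), dtype=jnp.float32) / 255.0
train_labels = one_hot(train_data["label"], 10)
test_images = jnp.asarray(test_data["image"].reshape(-1, 784), dtype=jnp.float32) / 255.0
test_labels = one_hot(test_data["label"], 10)
```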
|
|
908 | 912 | "id": "yy9PunCJdI-G"
|
909 | 913 | },
|
910 | 914 | "source": [
|
911 |  | - "**Define the Training Generator**\n", |
| 915 | + "### Define the Training Generator\n", |
912 | 916 | "\n",
|
913 | 917 | "Create a generator function to yield batches of data for training."
|
914 | 918 | ]
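
A sketch of a TFDS-backed generator yielding flattened NumPy batches; the shuffle buffer size and `drop_remainder` choice are assumptions:

```python
import tensorflow_datasets as tfds

def tfds_training_generator(batch_size):
    """Yield (images, labels) NumPy batches from the MNIST train split."""
    ds = tfds.load("mnist", split="train", as_supervised=True, data_dir=data_dir)
    ds = ds.shuffle(10_000).batch(batch_size, drop_remainder=True).prefetch(1)
    for images, labels in tfds.as_numpy(ds):
        yield images.reshape(len(images), -1).astype("float32"), labels
```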
|
|
936 | 940 | "id": "EAWeUdnuFNBY"
|
937 | 941 | },
|
938 | 942 | "source": [
|
939 |  | - "**Training Loop (TFDS)**\n", |
| 943 | + "### Training Loop (TFDS)\n", |
940 | 944 | "\n",
|
941 | 945 | "Use the training generator in a custom training loop."
|
942 | 946 | ]
|
|
1069 | 1073 | "id": "0h6mwVrspPA-"
|
1070 | 1074 | },
|
1071 | 1075 | "source": [
|
1072 |  | - "**Define Dataset Class**\n", |
| 1076 | + "### Define Dataset Class\n", |
1073 | 1077 | "\n",
|
1074 | 1078 | "Create a custom dataset class to load MNIST data for Grain."
|
1075 | 1079 | ]
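
A sketch of a random-access data source for Grain; Grain essentially only needs `__len__` and `__getitem__`, and using torchvision's MNIST underneath is an assumption here:

```python
import numpy as np
import grain.python as pygrain
from torchvision import datasets

class MnistDataSource(pygrain.RandomAccessDataSource):
    """Random-access MNIST source in the form Grain's DataLoader expects."""

    def __init__(self, train: bool):
        self._dataset = datasets.MNIST(data_dir, train=train, download=True)

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, index):
        image, label = self._dataset[index]
        return np.ravel(np.array(image, dtype=np.float32)), label
```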
|
|
1106 | 1110 | "id": "53mf8bWEsyTr"
|
1107 | 1111 | },
|
1108 | 1112 | "source": [
|
1109 |  | - "**Initialize the Dataset**" |
| 1113 | + "### Initialize the Dataset" |
1110 | 1114 | ]
|
1111 | 1115 | },
|
1112 | 1116 | {
|
|
1126 | 1130 | "id": "GqD-ycgBuwv9"
|
1127 | 1131 | },
|
1128 | 1132 | "source": [
|
1129 |  | - "**Get the full train and test dataset**" |
| 1133 | + "### Get the full train and test dataset" |
1130 | 1134 | ]
|
1131 | 1135 | },
|
1132 | 1136 | {
|
|
1178 | 1182 | "id": "1QPbXt7O0JN-"
|
1179 | 1183 | },
|
1180 | 1184 | "source": [
|
1181 |  | - "**Initialize PyGrain DataLoader**" |
| 1185 | + "### Initialize PyGrain DataLoader" |
1182 | 1186 | ]
|
1183 | 1187 | },
|
1184 | 1188 | {
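
A sketch of wiring a data source into a PyGrain `DataLoader` with an `IndexSampler` and a `Batch` operation; argument names can vary between Grain releases, and `train_source` refers to the `MnistDataSource` sketch above:

```python
import grain.python as pygrain

sampler = pygrain.IndexSampler(
    num_records=len(train_source),
    shard_options=pygrain.NoSharding(),  # array sharding is handled by JAX below
    shuffle=True,
    num_epochs=1,
    seed=0,
)

data_loader = pygrain.DataLoader(
    data_source=train_source,
    sampler=sampler,
    operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    worker_count=2,  # parallel worker processes for loading and preprocessing
)
```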
|
|
1207 | 1211 | "id": "GvpJPHAbeuHW"
|
1208 | 1212 | },
|
1209 | 1213 | "source": [
|
1210 |  | - "**Training Loop (Grain)**\n", |
| 1214 | + "### Training Loop (Grain)\n", |
1211 | 1215 | "\n",
|
1212 | 1216 | "Run the training loop using the Grain DataLoader."
|
1213 | 1217 | ]
|
|
1541 | 1545 | "id": "tgI7dIaX7JzM"
|
1542 | 1546 | },
|
1543 | 1547 | "source": [
|
1544 |  | - "**Extract images and labels**\n", |
| 1548 | + "### Extract images and labels\n", |
1545 | 1549 | "\n",
|
1546 | 1550 | "Get image shape and flatten for model input."
|
1547 | 1551 | ]
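
A sketch of pulling the MNIST splits out of Hugging Face Datasets as flat float32 arrays; the `"mnist"` dataset name and NumPy formatting follow the usual pattern but are stated here as assumptions:

```python
import numpy as np
from datasets import load_dataset

mnist = load_dataset("mnist").with_format("numpy")

train_images = np.asarray(mnist["train"]["image"], dtype=np.float32).reshape(-1, 784)
train_labels = np.asarray(mnist["train"]["label"])
test_images = np.asarray(mnist["test"]["image"], dtype=np.float32).reshape(-1, 784)
test_labels = np.asarray(mnist["test"]["label"])

print("image shape:", train_images.shape, "label shape:", train_labels.shape)
```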
|
|
1603 | 1607 | "id": "kk_4zJlz7T1E"
|
1604 | 1608 | },
|
1605 | 1609 | "source": [
|
1606 |  | - "**Define Training Generator**\n", |
| 1610 | + "### Define Training Generator\n", |
1607 | 1611 | "\n",
|
1608 | 1612 | "Set up a generator to yield batches of images and labels for training."
|
1609 | 1613 | ]
|
|
1629 | 1633 | "id": "HIsGfkLI7dvZ"
|
1630 | 1634 | },
|
1631 | 1635 | "source": [
|
1632 |  | - "**Training Loop (Hugging Face Datasets)**\n", |
| 1636 | + "### Training Loop (Hugging Face Datasets)\n", |
1633 | 1637 | "\n",
|
1634 | 1638 | "Run the training loop using the Hugging Face training generator."
|
1635 | 1639 | ]
|
|
1670 | 1674 | "id": "_JR0V1Aix9Id"
|
1671 | 1675 | },
|
1672 | 1676 | "source": [
|
1673 |  | - "## **Summary**\n", |
| 1677 | + "## Summary\n", |
1674 | 1678 | "\n",
|
1675 |  | - "This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each method has unique benefits, enabling you to choose the most suitable approach based on your project requirements.\n", |
| 1679 | + "This notebook has introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for machine learning tasks. Each library offers distinct advantages, allowing you to select the best approach for your specific project needs.\n", |
1676 | 1680 | "\n",
|
1677 |  | - "For in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)." |
| 1681 | + "For more detailed strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)." |
1678 | 1682 | ]
|
1679 | 1683 | }
|
1680 | 1684 | ],
|
|