diff --git a/docs/source/conf.py b/docs/source/conf.py index aad2b2f..45b5040 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -67,10 +67,12 @@ 'JAX_transformer_text_classification.md', 'data_loaders_on_cpu_with_jax.md', 'data_loaders_on_gpu_with_jax.md', + 'data_loaders_for_multi_device_setups_with_jax.md', ] suppress_warnings = [ 'misc.highlighting_failure', # Suppress warning in exception in digits_vae + 'mystnb.unknown_mime_type', # Suppress warning for unknown mime type (e.g. colab-display-data+json) ] # -- Options for myst ---------------------------------------------- @@ -104,4 +106,5 @@ 'JAX_transformer_text_classification.ipynb', 'data_loaders_on_cpu_with_jax.ipynb', 'data_loaders_on_gpu_with_jax.ipynb', + 'data_loaders_for_multi_device_setups_with_jax.ipynb', ] diff --git a/docs/source/data_loaders_on_cpu_with_jax.ipynb b/docs/source/data_loaders_for_multi_device_setups_with_jax.ipynb similarity index 71% rename from docs/source/data_loaders_on_cpu_with_jax.ipynb rename to docs/source/data_loaders_for_multi_device_setups_with_jax.ipynb index 34a8445..6c4a7e0 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.ipynb +++ b/docs/source/data_loaders_for_multi_device_setups_with_jax.ipynb @@ -6,7 +6,7 @@ "id": "PUFGZggH49zp" }, "source": [ - "# Introduction to Data Loaders on CPU with JAX" + "# Introduction to Data Loaders for Multi-Device Training with JAX" ] }, { @@ -15,45 +15,21 @@ "id": "3ia4PKEV5Dr8" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_cpu_with_jax.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_for_multi_device_setups_with_jax.ipynb)\n", "\n", - "This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "This tutorial explores various data loading strategies for **JAX** in **multi-device distributed** environments, leveraging [**TPUs**](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#what-is-a-tpu). 
While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "* [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", + "* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", + "* [**Grain**](https://github.com/google/grain)\n", + "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", "\n", - "- [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", - "- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", - "- [**Grain**](https://github.com/google/grain)\n", - "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset.\n", "\n", - "In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n", + "Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques allow you to partition and synchronize data across multiple devices, balancing the complexities of distributed systems while optimizing resource usage for large-scale datasets.\n", "\n", - "Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n", + "If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n", "\n", - "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n", - "\n", - "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pEsb135zE-Jo" - }, - "source": [ - "## Setting JAX to Use CPU Only\n", - "\n", - "First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "vqP6xyObC0_9" - }, - "outputs": [], - "source": [ - "import os\n", - "os.environ['JAX_PLATFORM_NAME'] = 'cpu'" + "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html)." 
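For orientation, the sharding pattern this notebook builds up to can be sketched in a few lines: create a one-dimensional `Mesh` over all visible devices, describe how the batch axis is split with a `NamedSharding`, and place each batch with `jax.device_put`. This is a minimal standalone sketch, not part of the notebook diff; the array shape is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding

# One-dimensional mesh over all visible devices (CPU, GPU or TPU cores).
devices = jax.devices()
mesh = Mesh(devices, ('device',))

# Split the leading (batch) axis across the 'device' mesh axis.
sharding = NamedSharding(mesh, PartitionSpec('device'))

# Illustrative batch; the leading dimension must divide evenly by the device count.
batch = jnp.zeros((len(devices) * 16, 784))
sharded = jax.device_put(batch, sharding)
print(sharded.sharding)  # shows how the array is laid out across devices
```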
] }, { @@ -67,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "id": "tDJNQ6V-Dg5g" }, @@ -75,7 +51,8 @@ "source": [ "import jax\n", "import jax.numpy as jnp\n", - "from jax import random, grad, jit, vmap" + "from jax import grad, jit, vmap, random, device_put\n", + "from jax.sharding import Mesh, PartitionSpec, NamedSharding" ] }, { @@ -84,27 +61,34 @@ "id": "TsFdlkSZKp9S" }, "source": [ - "### CPU Setup Verification" + "## Checking TPU Availability for JAX" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "N3sqvaF3KJw1", - "outputId": "449c83d9-d050-4b15-9a8d-f71e340501f2" + "outputId": "ee3286d0-b75f-46c5-8548-b57e3d895dd7" }, "outputs": [ { "data": { "text/plain": [ - "[CpuDevice(id=0)]" + "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", + " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", + " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", + " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", + " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", + " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", + " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", + " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -126,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "id": "qLNOSloFDka_" }, @@ -144,11 +128,11 @@ " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", "\n", "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", - "step_size = 0.01 # Learning rate for optimization\n", + "step_size = 0.01 # Learning rate\n", "num_epochs = 8 # Number of training epochs\n", "batch_size = 128 # Batch size for training\n", "n_targets = 10 # Number of classes (digits 0-9)\n", - "num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels)\n", + "num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels\n", "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", "\n", "# Initialize network parameters using the defined layer sizes and a random seed\n", @@ -158,7 +142,7 @@ { "cell_type": "markdown", "metadata": { - "id": "6Ci_CqW7q6XM" + "id": "rHLdqeI7D2WZ" }, "source": [ "## Model Prediction with Auto-Batching\n", @@ -170,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "id": "bKIYPSkvD1QV" }, @@ -182,7 +166,7 @@ " return jnp.maximum(0, x)\n", "\n", "def predict(params, image):\n", - " # per-example prediction\n", + " # per-example predictions\n", " activations = image\n", " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", @@ -199,21 +183,49 @@ { "cell_type": "markdown", "metadata": { - "id": "niTSr34_sDZi" + "id": "AMWmxjVEpH2D" + }, + "source": [ + "## Multi-device setup using a Mesh of devices" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "4Jc5YLFnpE-_" + }, + "outputs": [], + "source": [ + "# Get the number of available devices (GPUs/TPUs) for sharding\n", + "num_devices = len(jax.devices())\n", + "\n", + "# Multi-device setup using a Mesh of devices\n", + "devices = jax.devices()\n", + "mesh = Mesh(devices, ('device',))\n", + "\n", + "# Define the sharding 
specification - split the data along the first axis (batch)\n", + "sharding_spec = PartitionSpec('device')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rLqfeORsERek" }, "source": [ "## Utility and Loss Functions\n", "\n", "You'll now define utility functions for:\n", - "\n", "- One-hot encoding: Converts class indices to binary vectors.\n", "- Accuracy calculation: Measures the performance of the model on the dataset.\n", "- Loss computation: Calculates the difference between predictions and targets.\n", "\n", "To optimize performance:\n", - "\n", "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", - "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation." + "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.\n", + "\n", + "- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to distribute the dataset across TPU cores." ] }, { @@ -255,15 +267,16 @@ " return x, y\n", "\n", "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", - " \"\"\"Train the model for a given number of epochs.\"\"\"\n", + " \"\"\"Train the model for a given number of epochs and device_put for TPU transfer.\"\"\"\n", " for epoch in range(num_epochs):\n", " start_time = time.time()\n", " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", " x, y = reshape_and_one_hot(x, y)\n", + " x, y = device_put(x, NamedSharding(mesh, sharding_spec)), device_put(y, NamedSharding(mesh, sharding_spec))\n", " params = update(params, x, y)\n", "\n", " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", - " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", + " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f},\"\n", " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")" ] }, @@ -285,16 +298,16 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "id": "jmsfrWrHxIhC", - "outputId": "33dfeada-a763-4d26-f778-a27966e34d55" + "id": "33Wyf77WzNjA", + "outputId": "a2378431-79f2-4dc4-aa1a-d98704657d26" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cpu)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cpu)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", @@ -335,14 +348,27 @@ "outputs": [], "source": [ "def numpy_collate(batch):\n", - " \"\"\"Convert a batch of PyTorch data to NumPy arrays.\"\"\"\n", + " \"\"\"Collate function to convert a batch of 
PyTorch data into NumPy arrays.\"\"\"\n", " return tree_map(np.asarray, data.default_collate(batch))\n", "\n", "class NumpyLoader(data.DataLoader):\n", " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", - " def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):\n", - " super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs)\n", - "\n", + " def __init__(self, dataset, batch_size=1,\n", + " shuffle=False, sampler=None,\n", + " batch_sampler=None, num_workers=0,\n", + " pin_memory=False, drop_last=False,\n", + " timeout=0, worker_init_fn=None):\n", + " super(self.__class__, self).__init__(dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " sampler=sampler,\n", + " batch_sampler=batch_sampler,\n", + " num_workers=num_workers,\n", + " collate_fn=numpy_collate,\n", + " pin_memory=pin_memory,\n", + " drop_last=drop_last,\n", + " timeout=timeout,\n", + " worker_init_fn=worker_init_fn)\n", "class FlattenAndCast(object):\n", " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", " def __call__(self, pic):\n", @@ -352,7 +378,7 @@ { "cell_type": "markdown", "metadata": { - "id": "mfSnfJND6I8G" + "id": "ec-MHhv6hYsK" }, "source": [ "### Load Dataset with Transformations\n", @@ -367,8 +393,8 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "id": "Kxbl6bcx6crv", - "outputId": "372bbf4c-3ad5-4fd8-cc5d-27b50f5e4f38" + "id": "nSviwX9ohhUh", + "outputId": "0bb3bc04-11ac-4fb6-8854-76a3f5e725a5" }, "outputs": [ { @@ -387,7 +413,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 9.91M/9.91M [00:00<00:00, 49.4MB/s]\n" + "100%|██████████| 9.91M/9.91M [00:00<00:00, 36.1MB/s]\n" ] }, { @@ -408,7 +434,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 28.9k/28.9k [00:00<00:00, 2.09MB/s]" + "100%|██████████| 28.9k/28.9k [00:00<00:00, 1.13MB/s]\n" ] }, { @@ -417,20 +443,7 @@ "text": [ "Extracting /tmp/mnist_dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", "\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 403: Forbidden\n", "\n", @@ -442,7 +455,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 1.65M/1.65M [00:00<00:00, 13.3MB/s]\n" + "100%|██████████| 1.65M/1.65M [00:00<00:00, 10.1MB/s]\n" ] }, { @@ -463,7 +476,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 4.54k/4.54k [00:00<00:00, 8.81MB/s]\n" + "100%|██████████| 4.54k/4.54k [00:00<00:00, 6.34MB/s]" ] }, { @@ -473,6 +486,13 @@ "Extracting /tmp/mnist_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", "\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] } ], "source": [ @@ -534,7 +554,7 @@ "base_uri": "https://localhost:8080/" }, "id": "Oz-UVnCxG5E8", - "outputId": "abbaa26d-491a-4e63-e8c9-d3c571f53a28" + "outputId": "0f44cb63-b12c-47a7-8bd5-ed773e2b2ec5" }, "outputs": [ { @@ -554,21 +574,23 @@ { "cell_type": "markdown", "metadata": { - "id": "m3zfxqnMiCbm" + "id": "mfSnfJND6I8G" }, "source": [ "### Training Data Generator\n", "\n", - "Define a generator function using PyTorch's DataLoader for batch training. 
Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", + "Define a generator function using PyTorch's DataLoader for batch training.\n", + "Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", "\n", - "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes." + "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`\n", + "This warning can be safely ignored since data loaders do not use JAX within the forked processes." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { - "id": "B-fES82EiL6Z" + "id": "Kxbl6bcx6crv" }, "outputs": [], "source": [ @@ -594,22 +616,22 @@ "colab": { "base_uri": "https://localhost:8080/" }, - "id": "vtUjHsh-rJs8", - "outputId": "4766333e-4366-493b-995a-102778d1345a" + "id": "MUrJxpjvUyOm", + "outputId": "629a19b1-acba-418a-f04b-3b78d7909de1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1 in 28.93 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", - "Epoch 2 in 8.33 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", - "Epoch 3 in 6.99 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", - "Epoch 4 in 7.01 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 8.17 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", - "Epoch 6 in 8.27 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", - "Epoch 7 in 8.32 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", - "Epoch 8 in 8.07 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" + "Epoch 1 in 5.65 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197\n", + "Epoch 2 in 4.26 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383\n", + "Epoch 3 in 4.39 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468\n", + "Epoch 4 in 4.16 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536\n", + "Epoch 5 in 4.04 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576\n", + "Epoch 6 in 4.06 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617\n", + "Epoch 7 in 4.06 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649\n", + "Epoch 8 in 4.07 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672\n" ] } ], @@ -620,7 +642,7 @@ { "cell_type": "markdown", "metadata": { - "id": "Nm45ZTo6yrf5" + "id": "ACy1PoSVa3zH" }, "source": [ "## Loading Data with TensorFlow Datasets (TFDS)\n", @@ -628,25 +650,163 @@ "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." 
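When TensorFlow is used purely as an input pipeline next to JAX, it is common to hide accelerators from it so it does not reserve device memory; the CPU and GPU versions of this notebook do this with `tf.config.set_visible_devices`, which the updated import cell below no longer includes. A sketch of that guard, not part of this PR's new cell:

```python
import tensorflow as tf

# Keep TensorFlow on the host CPU so it does not grab accelerator memory;
# TFDS is only used here to feed NumPy batches into JAX.
tf.config.set_visible_devices([], device_type='GPU')
```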
] }, + { + "cell_type": "markdown", + "metadata": { + "id": "tcJRzpyOveWK" + }, + "source": [ + "Ensure you have the latest versions of both TensorFlow and TensorFlow Datasets" + ] + }, { "cell_type": "code", "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "_f55HPGAZu6P", + "outputId": "838c8f76-aa07-49d5-986d-3c88ed516b22" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.15.0)\n", + "Collecting tensorflow\n", + " Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)\n", + "Requirement already satisfied: tensorflow-datasets in /usr/local/lib/python3.10/dist-packages (4.9.7)\n", + "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25)\n", + "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.6.0)\n", + "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n", + "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.4.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.2)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.25.5)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.32.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (75.1.0)\n", + "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.5.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.2)\n", + "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.14.1)\n", + "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.68.0)\n", + "Collecting tensorboard<2.19,>=2.18 (from tensorflow)\n", + " Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)\n", + "Collecting keras>=3.5.0 (from tensorflow)\n", + " Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)\n", + "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.26.4)\n", + "Requirement already satisfied: h5py>=3.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.12.1)\n", + "Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow)\n", + " Downloading 
ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.1)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (8.1.7)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (0.1.8)\n", + "Requirement already satisfied: immutabledict in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (4.2.1)\n", + "Requirement already satisfied: promise in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (2.3)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (5.9.5)\n", + "Requirement already satisfied: pyarrow in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (18.0.0)\n", + "Requirement already satisfied: simple-parsing in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (0.1.6)\n", + "Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (1.13.1)\n", + "Requirement already satisfied: toml in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (0.10.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (4.66.6)\n", + "Requirement already satisfied: array-record>=0.5.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow-datasets) (0.5.1)\n", + "Requirement already satisfied: etils>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < \"3.11\"->tensorflow-datasets) (1.10.0)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.45.0)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < \"3.11\"->tensorflow-datasets) (2024.10.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < \"3.11\"->tensorflow-datasets) (6.4.5)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < \"3.11\"->tensorflow-datasets) (3.21.0)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.5.0->tensorflow) (13.9.4)\n", + "Collecting namex (from keras>=3.5.0->tensorflow)\n", + " Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)\n", + "Collecting optree (from keras>=3.5.0->tensorflow)\n", + " Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.8/47.8 kB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from 
requests<3,>=2.21.0->tensorflow) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2024.8.30)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (3.7)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (3.1.3)\n", + "Requirement already satisfied: docstring-parser<1.0,>=0.15 in /usr/local/lib/python3.10/dist-packages (from simple-parsing->tensorflow-datasets) (0.16)\n", + "Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow-metadata->tensorflow-datasets) (1.66.0)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.19,>=2.18->tensorflow) (3.0.2)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow) (2.18.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.5.0->tensorflow) (0.1.2)\n", + "Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m615.3/615.3 MB\u001b[0m \u001b[31m626.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading keras-3.6.0-py3-none-any.whl (1.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m49.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m77.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m70.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading namex-0.0.8-py3-none-any.whl (5.8 kB)\n", + "Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m381.3/381.3 kB\u001b[0m \u001b[31m27.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: namex, optree, ml-dtypes, tensorboard, keras, tensorflow\n", + " Attempting uninstall: ml-dtypes\n", + " Found existing installation: ml-dtypes 0.2.0\n", + " Uninstalling ml-dtypes-0.2.0:\n", + " Successfully uninstalled ml-dtypes-0.2.0\n", + " Attempting uninstall: tensorboard\n", + " Found existing installation: tensorboard 2.15.2\n", + " Uninstalling tensorboard-2.15.2:\n", + " Successfully uninstalled tensorboard-2.15.2\n", + " Attempting uninstall: keras\n", + " Found existing installation: keras 
2.15.0\n", + " Uninstalling keras-2.15.0:\n", + " Successfully uninstalled keras-2.15.0\n", + " Attempting uninstall: tensorflow\n", + " Found existing installation: tensorflow 2.15.0\n", + " Uninstalling tensorflow-2.15.0:\n", + " Successfully uninstalled tensorflow-2.15.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != \"arm64\" or platform_system != \"Darwin\", but you have tensorflow 2.18.0 which is incompatible.\n", + "tf-keras 2.15.1 requires tensorflow<2.16,>=2.15, but you have tensorflow 2.18.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed keras-3.6.0 ml-dtypes-0.4.1 namex-0.0.8 optree-0.13.1 tensorboard-2.18.0 tensorflow-2.18.0\n" + ] + }, + { + "data": { + "application/vnd.colab-display-data+json": { + "id": "62e7ae5195964acea7f16ab1423ff920", + "pip_warning": { + "packages": [ + "ml_dtypes" + ] + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "!pip install --upgrade tensorflow tensorflow-datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 17, "metadata": { "id": "sGaQAk1DHMUx" }, "outputs": [], "source": [ - "import tensorflow_datasets as tfds\n", - "import tensorflow as tf\n", - "\n", - "# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF\n", - "tf.config.set_visible_devices([], device_type='GPU')" + "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": { - "id": "3xdQY7H6wr3n" + "id": "F6OlzaDqwe4p" }, "source": [ "### Fetch Full Dataset for Evaluation\n", @@ -656,27 +816,27 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 104, "referenced_widgets": [ - "b8cdabf5c05848f38f03850cab08b56f", - "a8b76d5f93004c089676e5a2a9b3336c", - "119ac8428f9441e7a25eb0afef2fbb2a", - "76a9815e5c2b4764a13409cebaf66821", - "45ce8dd5c4b949afa957ec8ffb926060", - "05b7145fd62d4581b2123c7680f11cdd", - "b96267f014814ec5b96ad7e6165104b1", - "bce34bdbfbd64f1f8353a4e8515cee0b", - "93b8206f8c5841a692cdce985ae301d8", - "c95f592620c64da595cc787567b2c4db", - "8a97071f862c4ec3b4b4140d2e34eda2" + "43d95e3e6b704cb5ae941541862e35fe", + "fca543b71352477db00545b3990d44fa", + "d3c971a3507249c9a22cad026e46d739", + "6da776e94f7740b9aae06f298c1e03cd", + "b4aec5e3895e4a19912c74777e9ea835", + "ef4dc5b756d74129bd2d643d99a1ab2e", + "30243b81748e497eb526b25404e95826", + "3bb9b93e595d4a0ca973ded476c0a5d0", + "b770951ecace4b02ad1575fe9eb9e640", + "79009c4ea2bf46b1a3a2c6558fa6ec2f", + "5cb081d3a038482583350d018a768bd4" ] }, "id": "1hOamw_7C8Pb", - "outputId": "ca166490-22db-4732-b29f-866b7593e489" + "outputId": "0e3805dc-1bfd-4222-9052-0b2111ea3091" }, "outputs": [ { @@ -689,7 +849,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b8cdabf5c05848f38f03850cab08b56f", + "model_id": "43d95e3e6b704cb5ae941541862e35fe", "version_major": 2, "version_minor": 0 }, @@ -727,13 +887,13 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Td3PiLdmEf7z", - "outputId": "96403b0f-6079-43ce-df16-d4583f09906b" + "outputId": "464da4f6-f028-4667-889d-a812382739b0" }, "outputs": [ { @@ -753,7 +913,7 @@ { "cell_type": "markdown", "metadata": { - "id": 
"UWRSaalfdyDX" + "id": "yy9PunCJdI-G" }, "source": [ "### Define the Training Generator\n", @@ -763,7 +923,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "metadata": { "id": "vX59u8CqEf4J" }, @@ -791,27 +951,27 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "h2sO13XDGvq1", - "outputId": "a150246e-ceb5-46ac-db71-2a8177a9d04d" + "id": "AsFKboVRaV6r", + "outputId": "9cb33f79-1b17-439d-88d3-61cd984124f6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1 in 8.46 sec: Train Accuracy: 0.9252, Test Accuracy: 0.9270\n", - "Epoch 2 in 7.79 sec: Train Accuracy: 0.9429, Test Accuracy: 0.9412\n", - "Epoch 3 in 9.84 sec: Train Accuracy: 0.9533, Test Accuracy: 0.9514\n", - "Epoch 4 in 9.47 sec: Train Accuracy: 0.9602, Test Accuracy: 0.9551\n", - "Epoch 5 in 9.32 sec: Train Accuracy: 0.9652, Test Accuracy: 0.9602\n", - "Epoch 6 in 9.30 sec: Train Accuracy: 0.9692, Test Accuracy: 0.9630\n", - "Epoch 7 in 9.24 sec: Train Accuracy: 0.9726, Test Accuracy: 0.9655\n", - "Epoch 8 in 8.00 sec: Train Accuracy: 0.9755, Test Accuracy: 0.9667\n" + "Epoch 1 in 4.96 sec: Train Accuracy: 0.9254,Test Accuracy: 0.9271\n", + "Epoch 2 in 3.22 sec: Train Accuracy: 0.9428,Test Accuracy: 0.9418\n", + "Epoch 3 in 3.23 sec: Train Accuracy: 0.9532,Test Accuracy: 0.9517\n", + "Epoch 4 in 3.26 sec: Train Accuracy: 0.9600,Test Accuracy: 0.9557\n", + "Epoch 5 in 3.28 sec: Train Accuracy: 0.9651,Test Accuracy: 0.9605\n", + "Epoch 6 in 3.11 sec: Train Accuracy: 0.9691,Test Accuracy: 0.9628\n", + "Epoch 7 in 3.25 sec: Train Accuracy: 0.9726,Test Accuracy: 0.9648\n", + "Epoch 8 in 3.15 sec: Train Accuracy: 0.9754,Test Accuracy: 0.9665\n" ] } ], @@ -841,13 +1001,13 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "L78o7eeyGvn5", - "outputId": "76d16565-0d9e-4f5f-c6b1-4cf4a683d0e7" + "outputId": "8f32bb0f-9a73-48a9-dbcd-4eb93ba3f606" }, "outputs": [ { @@ -863,18 +1023,25 @@ "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from grain) (1.10.0)\n", "Collecting jaxtyping (from grain)\n", " Downloading jaxtyping-0.2.36-py3-none-any.whl.metadata (6.5 kB)\n", - "Requirement already satisfied: more-itertools>=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", + "Collecting more-itertools>=9.1.0 (from grain)\n", + " Downloading more_itertools-10.5.0-py3-none-any.whl.metadata (36 kB)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.10.0)\n", "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n", "Downloading grain-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (418 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m419.0/419.0 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m419.0/419.0 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading more_itertools-10.5.0-py3-none-any.whl (60 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading jaxtyping-0.2.36-py3-none-any.whl (55 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.8/55.8 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: jaxtyping, grain\n", - "Successfully installed grain-0.2.2 jaxtyping-0.2.36\n" + "\u001b[?25hInstalling collected packages: more-itertools, jaxtyping, grain\n", + " Attempting uninstall: more-itertools\n", + " Found existing installation: more-itertools 8.10.0\n", + " Uninstalling more-itertools-8.10.0:\n", + " Successfully uninstalled more-itertools-8.10.0\n", + "Successfully installed grain-0.2.2 jaxtyping-0.2.36 more-itertools-10.5.0\n" ] } ], @@ -893,7 +1060,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": { "id": "mS62eVL9Ifmz" }, @@ -917,7 +1084,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": { "id": "bnrhac5Hh7y1" }, @@ -930,6 +1097,7 @@ " self.load_data()\n", "\n", " def load_data(self):\n", + " # Load the MNIST dataset using PyGrain\n", " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", "\n", " def __len__(self):\n", @@ -951,7 +1119,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": { "id": "pN3oF7-ostGE" }, @@ -971,31 +1139,31 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 26, "metadata": { "id": "f1VnTuX3u_kL" }, "outputs": [], "source": [ - "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", "\n", - "# Load test dataset and process it\n", "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "\n", + "# Convert test images to JAX arrays and encode test labels as one-hot vectors\n", "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a2NHlp9klrQL", - "outputId": "14be58c0-851e-4a44-dfcc-d02f0718dab5" + "outputId": "cc9e0958-8484-4669-a2d1-abac36a3097f" }, "outputs": [ { @@ -1015,17 +1183,15 @@ { "cell_type": "markdown", "metadata": { - "id": "fETnWRo2crhf" + "id": "1QPbXt7O0JN-" }, "source": [ - "### Initialize PyGrain DataLoader\n", - "\n", - "Set up a PyGrain DataLoader for sequential batch sampling." 
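The sampler defined a little further down switches from `pygrain.NoSharding()` to `pygrain.ShardByJaxProcess()`. As a rough sketch of what that shard option amounts to (an assumption based on Grain's `ShardOptions` dataclass; on a single-host TPU there is only one JAX process, so the shard covers the whole dataset):

```python
import jax
import grain.python as pygrain

# ShardByJaxProcess derives the shard from the JAX process topology, so each
# host/process reads a disjoint slice of the records (assumed equivalent to
# pygrain.ShardByJaxProcess with default settings).
shard_options = pygrain.ShardOptions(
    shard_index=jax.process_index(),
    shard_count=jax.process_count(),
    drop_remainder=False,
)
```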
+ "### Initialize PyGrain DataLoader" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "metadata": { "id": "9RuFTcsCs2Ac" }, @@ -1033,10 +1199,9 @@ "source": [ "sampler = pygrain.SequentialSampler(\n", " num_records=len(mnist_dataset),\n", - " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", + " shard_options=pygrain.ShardByJaxProcess()) # Shard across TPU cores\n", "\n", "def pygrain_training_generator():\n", - " \"\"\"Grain DataLoader generator for training.\"\"\"\n", " return pygrain.DataLoader(\n", " data_source=mnist_dataset,\n", " sampler=sampler,\n", @@ -1057,27 +1222,27 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cjxJRtiTadEI", - "outputId": "3f624366-b683-4d20-9d0a-777d345b0e21" + "outputId": "a620e9f7-7a01-4ba8-fe16-6f988401c7c1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1 in 15.39 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", - "Epoch 2 in 15.27 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", - "Epoch 3 in 12.61 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", - "Epoch 4 in 12.62 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 12.39 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", - "Epoch 6 in 12.19 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", - "Epoch 7 in 12.56 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", - "Epoch 8 in 13.04 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" + "Epoch 1 in 8.05 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197\n", + "Epoch 2 in 8.14 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383\n", + "Epoch 3 in 8.99 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468\n", + "Epoch 4 in 9.00 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536\n", + "Epoch 5 in 8.40 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576\n", + "Epoch 6 in 8.28 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617\n", + "Epoch 7 in 8.20 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649\n", + "Epoch 8 in 8.24 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672\n" ] } ], @@ -1107,13 +1272,13 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "19ipxPhI6oSN", - "outputId": "684e445f-d23e-4924-9e76-2c2c9359f0be" + "outputId": "e0d52dfb-6c60-4539-a043-574d2533a744" }, "outputs": [ { @@ -1124,7 +1289,7 @@ " Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", - "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (18.0.0)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", @@ -1136,45 +1301,65 @@ " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)\n", " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n", - "Requirement already 
satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", + "Collecting aiohttp (from datasets)\n", + " Downloading aiohttp-3.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\n", "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->datasets)\n", + " Downloading aiohappyeyeballs-2.4.3-py3-none-any.whl.metadata (6.1 kB)\n", + "Collecting aiosignal>=1.1.2 (from aiohttp->datasets)\n", + " Downloading aiosignal-1.3.1-py3-none-any.whl.metadata (4.0 kB)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", - "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", - "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", - "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Collecting frozenlist>=1.1.1 (from aiohttp->datasets)\n", + " Downloading frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)\n", + "Collecting multidict<7.0,>=4.5 (from aiohttp->datasets)\n", + " Downloading multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)\n", + "Collecting propcache>=0.2.0 (from aiohttp->datasets)\n", + " Downloading propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)\n", + "Collecting yarl<2.0,>=1.17.0 (from aiohttp->datasets)\n", + " Downloading yarl-1.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (66 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.6/66.6 kB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting async-timeout<6.0,>=4.0 (from aiohttp->datasets)\n", + " Downloading async_timeout-5.0.1-py3-none-any.whl.metadata (5.1 kB)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) 
(2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Downloading datasets-3.1.0-py3-none-any.whl (480 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m15.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading aiohttp-3.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m30.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: xxhash, fsspec, dill, multiprocess, datasets\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading aiohappyeyeballs-2.4.3-py3-none-any.whl (14 kB)\n", + "Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)\n", + "Downloading async_timeout-5.0.1-py3-none-any.whl (6.2 kB)\n", + "Downloading 
frozenlist-1.5.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (241 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m241.9/241.9 kB\u001b[0m \u001b[31m18.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading multidict-6.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (124 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.6/124.6 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading propcache-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (208 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m208.9/208.9 kB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading yarl-1.17.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (319 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m319.2/319.2 kB\u001b[0m \u001b[31m23.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: xxhash, propcache, multidict, fsspec, frozenlist, dill, async-timeout, aiohappyeyeballs, yarl, multiprocess, aiosignal, aiohttp, datasets\n", " Attempting uninstall: fsspec\n", " Found existing installation: fsspec 2024.10.0\n", " Uninstalling fsspec-2024.10.0:\n", " Successfully uninstalled fsspec-2024.10.0\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed datasets-3.1.0 dill-0.3.8 fsspec-2024.9.0 multiprocess-0.70.16 xxhash-3.5.0\n" + "Successfully installed aiohappyeyeballs-2.4.3 aiohttp-3.11.6 aiosignal-1.3.1 async-timeout-5.0.1 datasets-3.1.0 dill-0.3.8 frozenlist-1.5.0 fsspec-2024.9.0 multidict-6.1.0 multiprocess-0.70.16 propcache-0.2.0 xxhash-3.5.0 yarl-1.17.2\n" ] } ], @@ -1182,18 +1367,9 @@ "!pip install datasets" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "be0h_dZv0593" - }, - "source": [ - "Import Library" - ] - }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": { "id": "8v1N59p76zn0" }, @@ -1208,78 +1384,76 @@ "id": "8Gaj11tO7C86" }, "source": [ - "### Load and Format MNIST Dataset\n", - "\n", "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." 
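The cell that follows uses the `numpy` format and shards batches explicitly with `device_put`; the `jax` format mentioned above asks `datasets` to return `jax.Array` objects directly. A brief sketch of the two options (assuming the installed `datasets` version supports the `jax` formatter, as the linked docs describe; `data_dir` mirrors the notebook's cache directory):

```python
from datasets import load_dataset

data_dir = '/tmp/mnist_dataset'  # same cache directory as the notebook

# "numpy" keeps host-side arrays that the training loop shards itself with
# device_put; "jax" returns jax.Array objects from indexing operations.
mnist_np = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
mnist_jax = load_dataset("mnist", cache_dir=data_dir).with_format("jax")

print(type(mnist_np["train"]["image"]))   # numpy.ndarray
print(type(mnist_jax["train"]["image"]))  # jax.Array
```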
] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 301, "referenced_widgets": [ - "32f6132a31aa4c508d3c3c5ef70348bb", - "d7c2ffa6b143463c91cbf8befca6ca01", - "fd964ecd3926419d92927c67f955d5d0", - "60feca3fde7c4447ad8393b0542eb999", - "3354a0baeca94d18bc6b2a8b8b465b58", - "a0d0d052772b46deac7657ad052991a4", - "fb34783b9cba462e9b690e0979c4b07a", - "8d8170c1ed99490589969cd753c40748", - "f1ecb6db00a54e088f1e09164222d637", - "3cf5dd8d29aa4619b39dc2542df7e42e", - "2e5d42ca710441b389895f2d3b611d0a", - "5d8202da24244dc896e9a8cba6a4ed4f", - "a6d64c953631412b8bd8f0ba53ae4d32", - "69240c5cbfbb4e91961f5b49812a26f0", - "865f38532b784a7c971f5d33b87b443e", - "ceb1c004191947cdaa10af9b9c03c80d", - "64c6041037914779b5e8e9cf5a80ad04", - "562fa6a0e7b846a180ac4b423c5511c5", - "b3b922288f9c4df2a4088279ff6d1531", - "75a1a8ffda554318890cf74c345ed9a9", - "3bae06cacf394a5998c2326199da94f5", - "ff6428a3daa5496c81d5e664aba01f97", - "1ba3f86870724f55b94a35cb6b4173af", - "b3e163fd8b8a4f289d5a25611cb66d23", - "abd2daba215e4f7c9ddabde04d6eb382", - "e22ee019049144d5aba573cdf4dbe4fc", - "6ac765dac67841a69218140785f024c6", - "7b057411a54e434fb74804b90daa8d44", - "563f71b3c67d47c3ab1100f5dc1b98f3", - "d81a657361ab4bba8bcc0cf309d2ff64", - "20316312ab88471ba90cbb954be3e964", - "698fda742f834473a23fb7e5e4cf239c", - "289b52c5a38146b8b467a5f4678f6271", - "d07c2f37cf914894b1551a8104e6cb70", - "5b55c73d551d483baaa6a1411c2597b1", - "2308f77723f54ac898588f48d1853b65", - "54d2589714d04b2e928b816258cb0df4", - "f84b795348c04c7a950165301a643671", - "bc853a4a8d3c4dbda23d183f0a3b4f27", - "1012ddc0343842d8b913a7d85df8ab8f", - "771a73a8f5084a57afc5654d72e022f0", - "311a43449f074841b6df4130b0871ac9", - "cd4d29cb01134469b52d6936c35eb943", - "013cf89ee6174d29bb3f4fdff7b36049", - "9237d877d84e4b3ab69698ecf56915bb", - "337ef4d37e6b4ff6bf6e8bd4ca93383f", - "b4096d3837b84ccdb8f1186435c87281", - "7259d3b7e11b4736b4d2aa8e9c55e994", - "1ad1f8e99a864fc4a2bc532d9a4ff110", - "b2b50451eabd40978ef46db5e7dd08c4", - "2dad5c5541e243128e23c3dd3e420ac2", - "a3de458b61e5493081d6bb9cf7e923db", - "37760f8a7b164e6f9c1a23d621e9fe6b", - "745a2aedcfab491fb9cffba19958b0c5", - "2f6c670640d048d2af453638cfde3a1e" + "86617153e14143c6900da3535b74ef07", + "8de57c9ecba14aa5b1f642af5c7e9094", + "515fe154b1b74ed981e877aef503aa99", + "4e291a8b028847328ea1d9a650c20beb", + "87a0c8cdc0ad423daba7082b985cbd2b", + "4764b5b806b94734b760cf6cc2fc224d", + "5307bf3142804235bb688694c517d80c", + "6a2fd6755667443abe7710ad607a79cc", + "91bc1755904e40db8d758db4d09754e3", + "69c38d75960542fb83fa087cae761957", + "dc31cb349c9b4c3580b2b77cbad1325c", + "d451224a0ce540648b0c28d433d85803", + "52f2f12dcffe4507ab92286fd3810db6", + "6ab919475c80413e94afa66304b05338", + "305d05093c6e411cb438a0bbf122d574", + "aa11f21e68994a8d9ddead215f2f4920", + "59a7233abf61461b8b3feeb31b2f544f", + "9d909399be9a4fa48bc3d781905c7f5a", + "5b6172eb4e0541a3b07d4f82de77a303", + "bc3bec617b0040f487f80134537a3068", + "9fe417f8159244f8ac808f2844922cf3", + "c4748e35e8574bb286a527295df98c8e", + "f50572e8058c4864bb8143c364d191f9", + "436955f611674e27b4ddf3e040cc5ce9", + "048231bf788c447091b8ef0174101f42", + "97009f7e20d84c7c9d89f7497efc494c", + "84e2844437884f6c89683e6545a2262e", + "df3019cc6aa44a4cbcb62096444769a7", + "ce17fe81850c49cd924297d21ecda621", + "422117e32e0b4a95bed7925c99fd9f78", + "56ab1fa0212a43a4a70838e440be0e9c", + "1c5483472cea483bbf2a8fe2a9182ce0", + "00034cb6a66143d8a87922befb1da7a6", + 
"368b51d79aed4184854f155e2951da81", + "eb9de18be48d4a0db1034a38a0287ea6", + "dbec1d9b196849a5ad79a5f083dbe64e", + "66db6915d27b4fb49e1b44f70cb61654", + "80f3e3a30dc24d3fa54bb72dc1c60182", + "c320096ba1e74c7bbbd9509cc11c22e9", + "a664dd9c446040e8b175bb91d1c051db", + "66c7826ff9b4455db9f7e9717a432f73", + "74ec8cec0f3c4c04b76f5fb87ea2d9bb", + "ea4537aef1e247378de1935ad50ef76c", + "a9cffb2f5e194dfaba516bb4c8c47e3f", + "4f17b7ab6ae94ce3b122561bcd8d4427", + "3c0bdc06fe07412bacc00daa6f1eec34", + "1ba273ced1484bcf9855366ff0dc3645", + "7413d8bab616446ba6b820a3f874f6a0", + "53c160c26c634b53a914be18ed91016c", + "ebc4ad2fae264e72a5307a0481a97ab3", + "83ab5e7617fb45898c259bc20f71e958", + "21f1138e807e4946953e3074d72d9a27", + "86d7357878634706b9e214103efa262a", + "3713a0e1880a43bc8b23225dbb8b4c45", + "f9f85ce1cbf34a7da27804ce7cc6444e" ] }, "id": "a22kTvgk6_fJ", - "outputId": "35fc38b9-a6ab-4b02-ffa4-ab27fac69df4" + "outputId": "53e1d208-5360-479b-c097-0c03c7fac3e8" }, "outputs": [ { @@ -1297,7 +1471,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "32f6132a31aa4c508d3c3c5ef70348bb", + "model_id": "86617153e14143c6900da3535b74ef07", "version_major": 2, "version_minor": 0 }, @@ -1311,7 +1485,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5d8202da24244dc896e9a8cba6a4ed4f", + "model_id": "d451224a0ce540648b0c28d433d85803", "version_major": 2, "version_minor": 0 }, @@ -1325,7 +1499,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1ba3f86870724f55b94a35cb6b4173af", + "model_id": "f50572e8058c4864bb8143c364d191f9", "version_major": 2, "version_minor": 0 }, @@ -1339,7 +1513,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d07c2f37cf914894b1551a8104e6cb70", + "model_id": "368b51d79aed4184854f155e2951da81", "version_major": 2, "version_minor": 0 }, @@ -1353,7 +1527,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9237d877d84e4b3ab69698ecf56915bb", + "model_id": "4f17b7ab6ae94ce3b122561bcd8d4427", "version_major": 2, "version_minor": 0 }, @@ -1366,29 +1540,55 @@ } ], "source": [ - "mnist_dataset = load_dataset(\"mnist\").with_format(\"numpy\")" + "mnist_dataset = load_dataset(\"mnist\", cache_dir=data_dir).with_format(\"numpy\")" ] }, { "cell_type": "markdown", "metadata": { - "id": "IFjTyGxY19b0" + "id": "tgI7dIaX7JzM" }, "source": [ "### Extract images and labels\n", "\n", - "Get image shape and flatten for model input" + "Get image shape and flatten for model input." 
] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 33, + "metadata": { + "id": "NHrKatD_7HbH" + }, + "outputs": [], + "source": [ + "train_images = mnist_dataset[\"train\"][\"image\"]\n", + "train_labels = mnist_dataset[\"train\"][\"label\"]\n", + "test_images = mnist_dataset[\"test\"][\"image\"]\n", + "test_labels = mnist_dataset[\"test\"][\"label\"]\n", + "\n", + "# Extract image shape\n", + "image_shape = train_images.shape[1:]\n", + "num_features = image_shape[0] * image_shape[1]\n", + "\n", + "# Flatten the images\n", + "train_images = train_images.reshape(-1, num_features)\n", + "test_images = test_images.reshape(-1, num_features)\n", + "\n", + "# One-hot encode the labels\n", + "train_labels = one_hot(train_labels, n_targets)\n", + "test_labels = one_hot(test_labels, n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "NHrKatD_7HbH", - "outputId": "deec1739-2fc0-4e71-8567-f2e0c9db198b" + "id": "dITh435Z7Nwb", + "outputId": "cd77ebf6-7d44-420f-f8d8-4357f915c956" }, "outputs": [ { @@ -1401,21 +1601,6 @@ } ], "source": [ - "train_images = mnist_dataset[\"train\"][\"image\"]\n", - "train_labels = mnist_dataset[\"train\"][\"label\"]\n", - "test_images = mnist_dataset[\"test\"][\"image\"]\n", - "test_labels = mnist_dataset[\"test\"][\"label\"]\n", - "\n", - "# Flatten images and one-hot encode labels\n", - "image_shape = train_images.shape[1:]\n", - "num_features = image_shape[0] * image_shape[1]\n", - "\n", - "train_images = train_images.reshape(-1, num_features)\n", - "test_images = test_images.reshape(-1, num_features)\n", - "\n", - "train_labels = one_hot(train_labels, n_targets)\n", - "test_labels = one_hot(test_labels, n_targets)\n", - "\n", "print('Train:', train_images.shape, train_labels.shape)\n", "print('Test:', test_images.shape, test_labels.shape)" ] @@ -1433,7 +1618,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": { "id": "-zLJhogj7RL-" }, @@ -1459,27 +1644,27 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, - "id": "RhloYGsw6nPf", - "outputId": "d49c1cd2-a546-46a6-84fb-d9507c38f4ca" + "id": "Ui6aLiZP7aLe", + "outputId": "48347baf-30f2-443d-b3bf-b12100d96b8f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1 in 9.77 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", - "Epoch 2 in 9.94 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", - "Epoch 3 in 9.44 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", - "Epoch 4 in 9.48 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 9.41 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", - "Epoch 6 in 9.98 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", - "Epoch 7 in 12.19 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", - "Epoch 8 in 10.91 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" + "Epoch 1 in 6.24 sec: Train Accuracy: 0.9159,Test Accuracy: 0.9197\n", + "Epoch 2 in 5.76 sec: Train Accuracy: 0.9371,Test Accuracy: 0.9383\n", + "Epoch 3 in 5.70 sec: Train Accuracy: 0.9493,Test Accuracy: 0.9468\n", + "Epoch 4 in 6.36 sec: Train Accuracy: 0.9568,Test Accuracy: 0.9536\n", + "Epoch 5 in 5.89 sec: Train Accuracy: 0.9632,Test Accuracy: 0.9576\n", + "Epoch 6 in 5.78 sec: Train Accuracy: 0.9674,Test Accuracy: 0.9617\n", + "Epoch 7 in 5.74 sec: Train Accuracy: 0.9708,Test Accuracy: 0.9649\n", + "Epoch 8 
in 6.21 sec: Train Accuracy: 0.9737,Test Accuracy: 0.9672\n" ] } ], @@ -1490,18 +1675,22 @@ { "cell_type": "markdown", "metadata": { - "id": "qXylIOwidWI3" + "id": "_JR0V1Aix9Id" }, "source": [ "## Summary\n", "\n", - "This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements." + "This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements.\n", + "\n", + "For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html)." ] } ], "metadata": { + "accelerator": "TPU", "colab": { - "name": "data_loaders_on_cpu_with_jax.ipynb", + "gpuType": "V28", + "name": "data_loaders_for_multi_device_setups_with_jax.ipynb", "provenance": [] }, "jupytext": { @@ -1516,7 +1705,7 @@ }, "widgets": { "application/vnd.jupyter.widget-state+json": { - "013cf89ee6174d29bb3f4fdff7b36049": { + "00034cb6a66143d8a87922befb1da7a6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", @@ -1531,7 +1720,55 @@ "description_width": "" } }, - "05b7145fd62d4581b2123c7680f11cdd": { + "048231bf788c447091b8ef0174101f42": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_422117e32e0b4a95bed7925c99fd9f78", + "max": 2595890, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_56ab1fa0212a43a4a70838e440be0e9c", + "value": 2595890 + } + }, + "1ba273ced1484bcf9855366ff0dc3645": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_21f1138e807e4946953e3074d72d9a27", + "max": 10000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_86d7357878634706b9e214103efa262a", + "value": 10000 + } + }, + "1c5483472cea483bbf2a8fe2a9182ce0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": 
"LayoutModel", @@ -1583,53 +1820,14 @@ "width": null } }, - "1012ddc0343842d8b913a7d85df8ab8f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", + "21f1138e807e4946953e3074d72d9a27": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "119ac8428f9441e7a25eb0afef2fbb2a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_bce34bdbfbd64f1f8353a4e8515cee0b", - "max": 5, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_93b8206f8c5841a692cdce985ae301d8", - "value": 5 - } - }, - "1ad1f8e99a864fc4a2bc532d9a4ff110": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", @@ -1674,145 +1872,43 @@ "width": null } }, - "1ba3f86870724f55b94a35cb6b4173af": { + "30243b81748e497eb526b25404e95826": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_b3e163fd8b8a4f289d5a25611cb66d23", - "IPY_MODEL_abd2daba215e4f7c9ddabde04d6eb382", - "IPY_MODEL_e22ee019049144d5aba573cdf4dbe4fc" - ], - "layout": "IPY_MODEL_6ac765dac67841a69218140785f024c6" - } - }, - "20316312ab88471ba90cbb954be3e964": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", + "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", - "bar_color": null, "description_width": "" } }, - "2308f77723f54ac898588f48d1853b65": { + "305d05093c6e411cb438a0bbf122d574": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", + "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": 
"FloatProgressModel", + "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", + "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_771a73a8f5084a57afc5654d72e022f0", - "max": 60000, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_311a43449f074841b6df4130b0871ac9", - "value": 60000 - } - }, - "289b52c5a38146b8b467a5f4678f6271": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "2dad5c5541e243128e23c3dd3e420ac2": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "2e5d42ca710441b389895f2d3b611d0a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "2f6c670640d048d2af453638cfde3a1e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "311a43449f074841b6df4130b0871ac9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" + "layout": "IPY_MODEL_9fe417f8159244f8ac808f2844922cf3", + "placeholder": "​", + "style": "IPY_MODEL_c4748e35e8574bb286a527295df98c8e", + "value": " 15.6M/15.6M [00:00<00:00, 43.3MB/s]" } }, - "32f6132a31aa4c508d3c3c5ef70348bb": { + "368b51d79aed4184854f155e2951da81": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", @@ -1827,14 +1923,14 @@ "_view_name": "HBoxView", "box_style": "", "children": [ - "IPY_MODEL_d7c2ffa6b143463c91cbf8befca6ca01", - "IPY_MODEL_fd964ecd3926419d92927c67f955d5d0", - "IPY_MODEL_60feca3fde7c4447ad8393b0542eb999" + "IPY_MODEL_eb9de18be48d4a0db1034a38a0287ea6", + "IPY_MODEL_dbec1d9b196849a5ad79a5f083dbe64e", + 
"IPY_MODEL_66db6915d27b4fb49e1b44f70cb61654" ], - "layout": "IPY_MODEL_3354a0baeca94d18bc6b2a8b8b465b58" + "layout": "IPY_MODEL_80f3e3a30dc24d3fa54bb72dc1c60182" } }, - "3354a0baeca94d18bc6b2a8b8b465b58": { + "3713a0e1880a43bc8b23225dbb8b4c45": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -1886,44 +1982,7 @@ "width": null } }, - "337ef4d37e6b4ff6bf6e8bd4ca93383f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b2b50451eabd40978ef46db5e7dd08c4", - "placeholder": "​", - "style": "IPY_MODEL_2dad5c5541e243128e23c3dd3e420ac2", - "value": "Generating test split: 100%" - } - }, - "37760f8a7b164e6f9c1a23d621e9fe6b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "3bae06cacf394a5998c2326199da94f5": { + "3bb9b93e595d4a0ca973ded476c0a5d0": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -1975,7 +2034,28 @@ "width": null } }, - "3cf5dd8d29aa4619b39dc2542df7e42e": { + "3c0bdc06fe07412bacc00daa6f1eec34": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ebc4ad2fae264e72a5307a0481a97ab3", + "placeholder": "​", + "style": "IPY_MODEL_83ab5e7617fb45898c259bc20f71e958", + "value": "Generating test split: 100%" + } + }, + "422117e32e0b4a95bed7925c99fd9f78": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2027,7 +2107,50 @@ "width": null } }, - "45ce8dd5c4b949afa957ec8ffb926060": { + "436955f611674e27b4ddf3e040cc5ce9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_df3019cc6aa44a4cbcb62096444769a7", + "placeholder": "​", + "style": "IPY_MODEL_ce17fe81850c49cd924297d21ecda621", + "value": "test-00000-of-00001.parquet: 100%" + } + }, + "43d95e3e6b704cb5ae941541862e35fe": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + 
"_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_fca543b71352477db00545b3990d44fa", + "IPY_MODEL_d3c971a3507249c9a22cad026e46d739", + "IPY_MODEL_6da776e94f7740b9aae06f298c1e03cd" + ], + "layout": "IPY_MODEL_b4aec5e3895e4a19912c74777e9ea835" + } + }, + "4764b5b806b94734b760cf6cc2fc224d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2079,7 +2202,7 @@ "width": null } }, - "54d2589714d04b2e928b816258cb0df4": { + "4e291a8b028847328ea1d9a650c20beb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", @@ -2094,43 +2217,59 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_cd4d29cb01134469b52d6936c35eb943", + "layout": "IPY_MODEL_69c38d75960542fb83fa087cae761957", "placeholder": "​", - "style": "IPY_MODEL_013cf89ee6174d29bb3f4fdff7b36049", - "value": " 60000/60000 [00:01<00:00, 78483.65 examples/s]" + "style": "IPY_MODEL_dc31cb349c9b4c3580b2b77cbad1325c", + "value": " 6.97k/6.97k [00:00<00:00, 617kB/s]" } }, - "562fa6a0e7b846a180ac4b423c5511c5": { + "4f17b7ab6ae94ce3b122561bcd8d4427": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", + "model_name": "HBoxModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "HBoxModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_3c0bdc06fe07412bacc00daa6f1eec34", + "IPY_MODEL_1ba273ced1484bcf9855366ff0dc3645", + "IPY_MODEL_7413d8bab616446ba6b820a3f874f6a0" + ], + "layout": "IPY_MODEL_53c160c26c634b53a914be18ed91016c" } }, - "563f71b3c67d47c3ab1100f5dc1b98f3": { + "515fe154b1b74ed981e877aef503aa99": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", + "model_name": "FloatProgressModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", + "_model_name": "FloatProgressModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6a2fd6755667443abe7710ad607a79cc", + "max": 6971, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_91bc1755904e40db8d758db4d09754e3", + "value": 6971 } }, - "5b55c73d551d483baaa6a1411c2597b1": { + "52f2f12dcffe4507ab92286fd3810db6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", @@ -2145,56 +2284,28 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": 
"IPY_MODEL_bc853a4a8d3c4dbda23d183f0a3b4f27", + "layout": "IPY_MODEL_59a7233abf61461b8b3feeb31b2f544f", "placeholder": "​", - "style": "IPY_MODEL_1012ddc0343842d8b913a7d85df8ab8f", - "value": "Generating train split: 100%" - } - }, - "5d8202da24244dc896e9a8cba6a4ed4f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_a6d64c953631412b8bd8f0ba53ae4d32", - "IPY_MODEL_69240c5cbfbb4e91961f5b49812a26f0", - "IPY_MODEL_865f38532b784a7c971f5d33b87b443e" - ], - "layout": "IPY_MODEL_ceb1c004191947cdaa10af9b9c03c80d" + "style": "IPY_MODEL_9d909399be9a4fa48bc3d781905c7f5a", + "value": "train-00000-of-00001.parquet: 100%" } }, - "60feca3fde7c4447ad8393b0542eb999": { + "5307bf3142804235bb688694c517d80c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HTMLModel", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3cf5dd8d29aa4619b39dc2542df7e42e", - "placeholder": "​", - "style": "IPY_MODEL_2e5d42ca710441b389895f2d3b611d0a", - "value": " 6.97k/6.97k [00:00<00:00, 101kB/s]" + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "64c6041037914779b5e8e9cf5a80ad04": { + "53c160c26c634b53a914be18ed91016c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2246,31 +2357,23 @@ "width": null } }, - "69240c5cbfbb4e91961f5b49812a26f0": { + "56ab1fa0212a43a4a70838e440be0e9c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", + "model_name": "ProgressStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", + "_model_name": "ProgressStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b3b922288f9c4df2a4088279ff6d1531", - "max": 15561616, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_75a1a8ffda554318890cf74c345ed9a9", - "value": 15561616 + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" } }, - "698fda742f834473a23fb7e5e4cf239c": { + "59a7233abf61461b8b3feeb31b2f544f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2322,7 +2425,7 @@ "width": null } }, - "6ac765dac67841a69218140785f024c6": { + "5b6172eb4e0541a3b07d4f82de77a303": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": 
"LayoutModel", @@ -2374,28 +2477,22 @@ "width": null } }, - "7259d3b7e11b4736b4d2aa8e9c55e994": { + "5cb081d3a038482583350d018a768bd4": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HTMLModel", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_745a2aedcfab491fb9cffba19958b0c5", - "placeholder": "​", - "style": "IPY_MODEL_2f6c670640d048d2af453638cfde3a1e", - "value": " 10000/10000 [00:00<00:00, 13598.62 examples/s]" + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "745a2aedcfab491fb9cffba19958b0c5": { + "66c7826ff9b4455db9f7e9717a432f73": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2447,23 +2544,7 @@ "width": null } }, - "75a1a8ffda554318890cf74c345ed9a9": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "76a9815e5c2b4764a13409cebaf66821": { + "66db6915d27b4fb49e1b44f70cb61654": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", @@ -2478,13 +2559,13 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_c95f592620c64da595cc787567b2c4db", + "layout": "IPY_MODEL_ea4537aef1e247378de1935ad50ef76c", "placeholder": "​", - "style": "IPY_MODEL_8a97071f862c4ec3b4b4140d2e34eda2", - "value": " 5/5 [00:00<00:00, 12.97 file/s]" + "style": "IPY_MODEL_a9cffb2f5e194dfaba516bb4c8c47e3f", + "value": " 60000/60000 [00:00<00:00, 199154.39 examples/s]" } }, - "771a73a8f5084a57afc5654d72e022f0": { + "69c38d75960542fb83fa087cae761957": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2536,7 +2617,7 @@ "width": null } }, - "7b057411a54e434fb74804b90daa8d44": { + "6a2fd6755667443abe7710ad607a79cc": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2588,117 +2669,73 @@ "width": null } }, - "865f38532b784a7c971f5d33b87b443e": { + "6ab919475c80413e94afa66304b05338": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HTMLModel", + "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", + "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "HTMLView", + "_view_name": "ProgressView", + "bar_style": "success", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_3bae06cacf394a5998c2326199da94f5", - "placeholder": "​", - "style": 
"IPY_MODEL_ff6428a3daa5496c81d5e664aba01f97", - "value": " 15.6M/15.6M [00:00<00:00, 22.6MB/s]" + "layout": "IPY_MODEL_5b6172eb4e0541a3b07d4f82de77a303", + "max": 15561616, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bc3bec617b0040f487f80134537a3068", + "value": 15561616 } }, - "8a97071f862c4ec3b4b4140d2e34eda2": { + "6da776e94f7740b9aae06f298c1e03cd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", + "model_name": "HTMLModel", "state": { + "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8d8170c1ed99490589969cd753c40748": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", + "_model_name": "HTMLModel", "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_79009c4ea2bf46b1a3a2c6558fa6ec2f", + "placeholder": "​", + "style": "IPY_MODEL_5cb081d3a038482583350d018a768bd4", + "value": " 5/5 [00:00<00:00, 22.68 file/s]" } }, - "9237d877d84e4b3ab69698ecf56915bb": { + "7413d8bab616446ba6b820a3f874f6a0": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HBoxModel", + "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", + "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_337ef4d37e6b4ff6bf6e8bd4ca93383f", - "IPY_MODEL_b4096d3837b84ccdb8f1186435c87281", - "IPY_MODEL_7259d3b7e11b4736b4d2aa8e9c55e994" - ], - "layout": "IPY_MODEL_1ad1f8e99a864fc4a2bc532d9a4ff110" + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3713a0e1880a43bc8b23225dbb8b4c45", + "placeholder": "​", + "style": "IPY_MODEL_f9f85ce1cbf34a7da27804ce7cc6444e", + "value": " 10000/10000 [00:00<00:00, 169495.59 examples/s]" } }, - "93b8206f8c5841a692cdce985ae301d8": { + 
"74ec8cec0f3c4c04b76f5fb87ea2d9bb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", @@ -2714,7 +2751,7 @@ "description_width": "" } }, - "a0d0d052772b46deac7657ad052991a4": { + "79009c4ea2bf46b1a3a2c6558fa6ec2f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2766,7 +2803,7 @@ "width": null } }, - "a3de458b61e5493081d6bb9cf7e923db": { + "80f3e3a30dc24d3fa54bb72dc1c60182": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2818,73 +2855,22 @@ "width": null } }, - "a6d64c953631412b8bd8f0ba53ae4d32": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_64c6041037914779b5e8e9cf5a80ad04", - "placeholder": "​", - "style": "IPY_MODEL_562fa6a0e7b846a180ac4b423c5511c5", - "value": "train-00000-of-00001.parquet: 100%" - } - }, - "a8b76d5f93004c089676e5a2a9b3336c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_05b7145fd62d4581b2123c7680f11cdd", - "placeholder": "​", - "style": "IPY_MODEL_b96267f014814ec5b96ad7e6165104b1", - "value": "Dl Completed...: 100%" - } - }, - "abd2daba215e4f7c9ddabde04d6eb382": { + "83ab5e7617fb45898c259bc20f71e958": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_d81a657361ab4bba8bcc0cf309d2ff64", - "max": 2595890, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_20316312ab88471ba90cbb954be3e964", - "value": 2595890 + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "b2b50451eabd40978ef46db5e7dd08c4": { + "84e2844437884f6c89683e6545a2262e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2936,7 +2922,45 @@ "width": null } }, - "b3b922288f9c4df2a4088279ff6d1531": { + "86617153e14143c6900da3535b74ef07": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_8de57c9ecba14aa5b1f642af5c7e9094", + "IPY_MODEL_515fe154b1b74ed981e877aef503aa99", + "IPY_MODEL_4e291a8b028847328ea1d9a650c20beb" + ], + "layout": "IPY_MODEL_87a0c8cdc0ad423daba7082b985cbd2b" + } + }, + "86d7357878634706b9e214103efa262a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "87a0c8cdc0ad423daba7082b985cbd2b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -2988,7 +3012,7 @@ "width": null } }, - "b3e163fd8b8a4f289d5a25611cb66d23": { + "8de57c9ecba14aa5b1f642af5c7e9094": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", @@ -3003,59 +3027,132 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_7b057411a54e434fb74804b90daa8d44", + "layout": "IPY_MODEL_4764b5b806b94734b760cf6cc2fc224d", "placeholder": "​", - "style": "IPY_MODEL_563f71b3c67d47c3ab1100f5dc1b98f3", - "value": "test-00000-of-00001.parquet: 100%" + "style": "IPY_MODEL_5307bf3142804235bb688694c517d80c", + "value": "README.md: 100%" } }, - "b4096d3837b84ccdb8f1186435c87281": { + "91bc1755904e40db8d758db4d09754e3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "97009f7e20d84c7c9d89f7497efc494c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", + "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", + "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_a3de458b61e5493081d6bb9cf7e923db", - "max": 10000, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_37760f8a7b164e6f9c1a23d621e9fe6b", - "value": 10000 + "layout": "IPY_MODEL_1c5483472cea483bbf2a8fe2a9182ce0", + "placeholder": "​", + "style": "IPY_MODEL_00034cb6a66143d8a87922befb1da7a6", + "value": " 2.60M/2.60M [00:00<00:00, 33.4MB/s]" + } + }, + "9d909399be9a4fa48bc3d781905c7f5a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9fe417f8159244f8ac808f2844922cf3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null } }, - "b8cdabf5c05848f38f03850cab08b56f": { + "a664dd9c446040e8b175bb91d1c051db": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "HBoxModel", + "model_name": "DescriptionStyleModel", "state": { - "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", + "_model_name": "DescriptionStyleModel", "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_a8b76d5f93004c089676e5a2a9b3336c", - "IPY_MODEL_119ac8428f9441e7a25eb0afef2fbb2a", - "IPY_MODEL_76a9815e5c2b4764a13409cebaf66821" - ], - "layout": "IPY_MODEL_45ce8dd5c4b949afa957ec8ffb926060" + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" } }, - "b96267f014814ec5b96ad7e6165104b1": { + "a9cffb2f5e194dfaba516bb4c8c47e3f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", @@ -3070,7 +3167,7 @@ "description_width": "" } }, - "bc853a4a8d3c4dbda23d183f0a3b4f27": { + "aa11f21e68994a8d9ddead215f2f4920": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3122,7 +3219,7 @@ "width": null } }, - "bce34bdbfbd64f1f8353a4e8515cee0b": { + "b4aec5e3895e4a19912c74777e9ea835": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3174,7 +3271,39 @@ "width": null } }, - "c95f592620c64da595cc787567b2c4db": { + "b770951ecace4b02ad1575fe9eb9e640": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "bc3bec617b0040f487f80134537a3068": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "c320096ba1e74c7bbbd9509cc11c22e9": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3226,7 +3355,122 @@ "width": null } }, - "cd4d29cb01134469b52d6936c35eb943": { + "c4748e35e8574bb286a527295df98c8e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ce17fe81850c49cd924297d21ecda621": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "d3c971a3507249c9a22cad026e46d739": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3bb9b93e595d4a0ca973ded476c0a5d0", + "max": 5, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b770951ecace4b02ad1575fe9eb9e640", + "value": 5 + } + }, + "d451224a0ce540648b0c28d433d85803": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_52f2f12dcffe4507ab92286fd3810db6", + "IPY_MODEL_6ab919475c80413e94afa66304b05338", + "IPY_MODEL_305d05093c6e411cb438a0bbf122d574" + ], + "layout": "IPY_MODEL_aa11f21e68994a8d9ddead215f2f4920" + } + }, + "dbec1d9b196849a5ad79a5f083dbe64e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": 
"IPY_MODEL_66c7826ff9b4455db9f7e9717a432f73", + "max": 60000, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_74ec8cec0f3c4c04b76f5fb87ea2d9bb", + "value": 60000 + } + }, + "dc31cb349c9b4c3580b2b77cbad1325c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "df3019cc6aa44a4cbcb62096444769a7": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3278,7 +3522,7 @@ "width": null } }, - "ceb1c004191947cdaa10af9b9c03c80d": { + "ea4537aef1e247378de1935ad50ef76c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3330,29 +3574,7 @@ "width": null } }, - "d07c2f37cf914894b1551a8104e6cb70": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_5b55c73d551d483baaa6a1411c2597b1", - "IPY_MODEL_2308f77723f54ac898588f48d1853b65", - "IPY_MODEL_54d2589714d04b2e928b816258cb0df4" - ], - "layout": "IPY_MODEL_f84b795348c04c7a950165301a643671" - } - }, - "d7c2ffa6b143463c91cbf8befca6ca01": { + "eb9de18be48d4a0db1034a38a0287ea6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", @@ -3367,13 +3589,13 @@ "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_a0d0d052772b46deac7657ad052991a4", + "layout": "IPY_MODEL_c320096ba1e74c7bbbd9509cc11c22e9", "placeholder": "​", - "style": "IPY_MODEL_fb34783b9cba462e9b690e0979c4b07a", - "value": "README.md: 100%" + "style": "IPY_MODEL_a664dd9c446040e8b175bb91d1c051db", + "value": "Generating train split: 100%" } }, - "d81a657361ab4bba8bcc0cf309d2ff64": { + "ebc4ad2fae264e72a5307a0481a97ab3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3425,44 +3647,7 @@ "width": null } }, - "e22ee019049144d5aba573cdf4dbe4fc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_698fda742f834473a23fb7e5e4cf239c", - "placeholder": "​", - "style": "IPY_MODEL_289b52c5a38146b8b467a5f4678f6271", - "value": " 2.60M/2.60M [00:00<00:00, 14.2MB/s]" - } - }, - "f1ecb6db00a54e088f1e09164222d637": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": 
"ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "f84b795348c04c7a950165301a643671": { + "ef4dc5b756d74129bd2d643d99a1ab2e": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", @@ -3514,7 +3699,29 @@ "width": null } }, - "fb34783b9cba462e9b690e0979c4b07a": { + "f50572e8058c4864bb8143c364d191f9": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_436955f611674e27b4ddf3e040cc5ce9", + "IPY_MODEL_048231bf788c447091b8ef0174101f42", + "IPY_MODEL_97009f7e20d84c7c9d89f7497efc494c" + ], + "layout": "IPY_MODEL_84e2844437884f6c89683e6545a2262e" + } + }, + "f9f85ce1cbf34a7da27804ce7cc6444e": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", @@ -3529,43 +3736,25 @@ "description_width": "" } }, - "fd964ecd3926419d92927c67f955d5d0": { + "fca543b71352477db00545b3990d44fa": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", + "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", + "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", + "_view_name": "HTMLView", "description": "", "description_tooltip": null, - "layout": "IPY_MODEL_8d8170c1ed99490589969cd753c40748", - "max": 6971, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f1ecb6db00a54e088f1e09164222d637", - "value": 6971 - } - }, - "ff6428a3daa5496c81d5e664aba01f97": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" + "layout": "IPY_MODEL_ef4dc5b756d74129bd2d643d99a1ab2e", + "placeholder": "​", + "style": "IPY_MODEL_30243b81748e497eb526b25404e95826", + "value": "Dl Completed...: 100%" } } } diff --git a/docs/source/data_loaders_on_cpu_with_jax.md b/docs/source/data_loaders_for_multi_device_setups_with_jax.md similarity index 64% rename from docs/source/data_loaders_on_cpu_with_jax.md rename to docs/source/data_loaders_for_multi_device_setups_with_jax.md index d26c687..4494b37 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.md +++ b/docs/source/data_loaders_for_multi_device_setups_with_jax.md @@ -13,40 +13,26 @@ kernelspec: +++ {"id": "PUFGZggH49zp"} -# Introduction to Data Loaders on CPU with JAX +# Introduction to Data Loaders for Multi-Device Training with JAX +++ {"id": "3ia4PKEV5Dr8"} -[![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_cpu_with_jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_for_multi_device_setups_with_jax.ipynb) -This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including: +This tutorial explores various data loading strategies for **JAX** in **multi-device distributed** environments, leveraging [**TPUs**](https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#what-is-a-tpu). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including: +* [**PyTorch DataLoader**](https://github.com/pytorch/data) +* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets) +* [**Grain**](https://github.com/google/grain) +* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) -- [**PyTorch DataLoader**](https://github.com/pytorch/data) -- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets) -- [**Grain**](https://github.com/google/grain) -- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) +You'll learn how to use each of these libraries to efficiently load data for an image classification task using the MNIST dataset. -In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset. +Building on the [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html) tutorial, this guide covers advanced strategies for multi-device setups, such as data sharding with `Mesh` and `NamedSharding` to partition and synchronize data across devices. These techniques allow you to partition and synchronize data across multiple devices, balancing the complexities of distributed systems while optimizing resource usage for large-scale datasets. -Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU. +If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html). If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html). -If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). - -+++ {"id": "pEsb135zE-Jo"} - -## Setting JAX to Use CPU Only - -First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading. 
- -```{code-cell} -:id: vqP6xyObC0_9 - -import os -os.environ['JAX_PLATFORM_NAME'] = 'cpu' -``` - +++ {"id": "-rsMgVtO6asW"} Import JAX API @@ -56,19 +42,20 @@ Import JAX API import jax import jax.numpy as jnp -from jax import random, grad, jit, vmap +from jax import grad, jit, vmap, random, device_put +from jax.sharding import Mesh, PartitionSpec, NamedSharding ``` +++ {"id": "TsFdlkSZKp9S"} -### CPU Setup Verification +## Checking TPU Availability for JAX ```{code-cell} --- colab: base_uri: https://localhost:8080/ id: N3sqvaF3KJw1 -outputId: 449c83d9-d050-4b15-9a8d-f71e340501f2 +outputId: ee3286d0-b75f-46c5-8548-b57e3d895dd7 --- jax.devices() ``` @@ -94,18 +81,18 @@ def init_network_params(sizes, key): return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)] layer_sizes = [784, 512, 512, 10] # Layers of the network -step_size = 0.01 # Learning rate for optimization +step_size = 0.01 # Learning rate num_epochs = 8 # Number of training epochs batch_size = 128 # Batch size for training n_targets = 10 # Number of classes (digits 0-9) -num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels) +num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset # Initialize network parameters using the defined layer sizes and a random seed params = init_network_params(layer_sizes, random.PRNGKey(0)) ``` -+++ {"id": "6Ci_CqW7q6XM"} ++++ {"id": "rHLdqeI7D2WZ"} ## Model Prediction with Auto-Batching @@ -122,7 +109,7 @@ def relu(x): return jnp.maximum(0, x) def predict(params, image): - # per-example prediction + # per-example predictions activations = image for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b @@ -136,21 +123,39 @@ def predict(params, image): batched_predict = vmap(predict, in_axes=(None, 0)) ``` -+++ {"id": "niTSr34_sDZi"} ++++ {"id": "AMWmxjVEpH2D"} + +## Multi-device setup using a Mesh of devices + +```{code-cell} +:id: 4Jc5YLFnpE-_ + +# Get the number of available devices (GPUs/TPUs) for sharding +num_devices = len(jax.devices()) + +# Multi-device setup using a Mesh of devices +devices = jax.devices() +mesh = Mesh(devices, ('device',)) + +# Define the sharding specification - split the data along the first axis (batch) +sharding_spec = PartitionSpec('device') +``` + ++++ {"id": "rLqfeORsERek"} ## Utility and Loss Functions You'll now define utility functions for: - - One-hot encoding: Converts class indices to binary vectors. - Accuracy calculation: Measures the performance of the model on the dataset. - Loss computation: Calculates the difference between predictions and targets. To optimize performance: - - [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters. - [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation. +- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to distribute the dataset across TPU cores. 
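To make the `Mesh`/`NamedSharding` setup above concrete, here is a minimal, self-contained sketch (illustrative only, not part of the tutorial code) of how a single host-side batch is split along its leading axis across whatever devices are visible to JAX. The array `global_batch` and its shape are assumptions; note that the batch axis should divide evenly across the devices in the mesh.

```python
# Illustrative sketch only: shard a host batch across a 1-D device mesh.
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

devices = jax.devices()                       # CPU, GPU, or TPU devices
mesh = Mesh(np.array(devices), ('device',))   # 1-D mesh over all local devices
spec = PartitionSpec('device')                # shard axis 0 (the batch axis)

global_batch = np.ones((128, 784), dtype=np.float32)  # e.g. 128 flattened MNIST images
sharded = jax.device_put(global_batch, NamedSharding(mesh, spec))

print(sharded.shape)                          # global (logical) shape stays (128, 784)
for shard in sharded.addressable_shards:      # each device holds one slice of the batch
    print(shard.device, shard.data.shape)
```

On a single-device CPU runtime this degenerates to one shard, while on an 8-core TPU each core receives a 16-row slice; the per-batch `device_put` call added to `train_model` later in this tutorial applies the same sharding.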
+ ```{code-cell} :id: sA0a06raEQfS @@ -185,15 +190,16 @@ def reshape_and_one_hot(x, y): return x, y def train_model(num_epochs, params, training_generator, data_loader_type='streamed'): - """Train the model for a given number of epochs.""" + """Train the model for a given number of epochs and device_put for TPU transfer.""" for epoch in range(num_epochs): start_time = time.time() for x, y in training_generator() if data_loader_type == 'streamed' else training_generator: x, y = reshape_and_one_hot(x, y) + x, y = device_put(x, NamedSharding(mesh, sharding_spec)), device_put(y, NamedSharding(mesh, sharding_spec)) params = update(params, x, y) print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: " - f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, " + f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}," f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}") ``` @@ -207,8 +213,8 @@ This section shows how to load the MNIST dataset using PyTorch's DataLoader, con --- colab: base_uri: https://localhost:8080/ -id: jmsfrWrHxIhC -outputId: 33dfeada-a763-4d26-f778-a27966e34d55 +id: 33Wyf77WzNjA +outputId: a2378431-79f2-4dc4-aa1a-d98704657d26 --- !pip install torch torchvision ``` @@ -226,21 +232,34 @@ from torchvision.datasets import MNIST :id: 6f6qU8PCc143 def numpy_collate(batch): - """Convert a batch of PyTorch data to NumPy arrays.""" + """Collate function to convert a batch of PyTorch data into NumPy arrays.""" return tree_map(np.asarray, data.default_collate(batch)) class NumpyLoader(data.DataLoader): """Custom DataLoader to return NumPy arrays from a PyTorch Dataset.""" - def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs): - super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs) - + def __init__(self, dataset, batch_size=1, + shuffle=False, sampler=None, + batch_sampler=None, num_workers=0, + pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None): + super(self.__class__, self).__init__(dataset, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=numpy_collate, + pin_memory=pin_memory, + drop_last=drop_last, + timeout=timeout, + worker_init_fn=worker_init_fn) class FlattenAndCast(object): """Transform class to flatten and cast images to float32.""" def __call__(self, pic): return np.ravel(np.array(pic, dtype=jnp.float32)) ``` -+++ {"id": "mfSnfJND6I8G"} ++++ {"id": "ec-MHhv6hYsK"} ### Load Dataset with Transformations @@ -250,8 +269,8 @@ Standardize the data by flattening the images, casting them to `float32`, and en --- colab: base_uri: https://localhost:8080/ -id: Kxbl6bcx6crv -outputId: 372bbf4c-3ad5-4fd8-cc5d-27b50f5e4f38 +id: nSviwX9ohhUh +outputId: 0bb3bc04-11ac-4fb6-8854-76a3f5e725a5 --- mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast()) ``` @@ -288,22 +307,24 @@ test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets) colab: base_uri: https://localhost:8080/ id: Oz-UVnCxG5E8 -outputId: abbaa26d-491a-4e63-e8c9-d3c571f53a28 +outputId: 0f44cb63-b12c-47a7-8bd5-ed773e2b2ec5 --- print('Train:', train_images.shape, train_labels.shape) print('Test:', test_images.shape, test_labels.shape) ``` -+++ {"id": "m3zfxqnMiCbm"} ++++ {"id": "mfSnfJND6I8G"} ### Training Data Generator -Define a generator function using PyTorch's DataLoader for batch training. 
Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.
+Define a generator function using PyTorch's DataLoader for batch training.
+Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.

-Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes.
+Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`
+This warning can be safely ignored since data loaders do not use JAX within the forked processes.

```{code-cell}
-:id: B-fES82EiL6Z
+:id: Kxbl6bcx6crv

def pytorch_training_generator(mnist_dataset):
  return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
@@ -319,29 +340,40 @@ The training loop uses the PyTorch DataLoader to iterate through batches and upd
---
colab:
  base_uri: https://localhost:8080/
-id: vtUjHsh-rJs8
-outputId: 4766333e-4366-493b-995a-102778d1345a
+id: MUrJxpjvUyOm
+outputId: 629a19b1-acba-418a-f04b-3b78d7909de1
---
train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')
```

-+++ {"id": "Nm45ZTo6yrf5"}
++++ {"id": "ACy1PoSVa3zH"}

## Loading Data with TensorFlow Datasets (TFDS)

-This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow.
+This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing.
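The body of the TFDS training generator is not shown here, so the sketch below is only a rough illustration of the kind of TFDS-to-NumPy pipeline this section builds; the split name, shuffle buffer size, and `prefetch` depth are assumptions rather than the notebook's exact code.

```python
# Rough, assumption-laden sketch of a TFDS pipeline that yields NumPy batches.
import tensorflow_datasets as tfds

def tfds_training_generator(batch_size, data_dir='/tmp/mnist_dataset'):
    # as_supervised=True yields (image, label) tuples instead of feature dicts
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    ds = ds.shuffle(10_000).batch(batch_size).prefetch(1)
    # tfds.as_numpy converts the tf.data pipeline into an iterator of NumPy batches,
    # which the JAX training loop can consume directly.
    return tfds.as_numpy(ds)
```

Each yielded batch is then reshaped, one-hot encoded, and sharded by `train_model`, just like the PyTorch batches above.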
++++ {"id": "tcJRzpyOveWK"} + +Ensure you have the latest versions of both TensorFlow and TensorFlow Datasets + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ + height: 1000 +id: _f55HPGAZu6P +outputId: 838c8f76-aa07-49d5-986d-3c88ed516b22 +--- +!pip install --upgrade tensorflow tensorflow-datasets +``` + ```{code-cell} :id: sGaQAk1DHMUx import tensorflow_datasets as tfds -import tensorflow as tf - -# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF -tf.config.set_visible_devices([], device_type='GPU') ``` -+++ {"id": "3xdQY7H6wr3n"} ++++ {"id": "F6OlzaDqwe4p"} ### Fetch Full Dataset for Evaluation @@ -352,12 +384,12 @@ Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it fo colab: base_uri: https://localhost:8080/ height: 104 - referenced_widgets: [b8cdabf5c05848f38f03850cab08b56f, a8b76d5f93004c089676e5a2a9b3336c, - 119ac8428f9441e7a25eb0afef2fbb2a, 76a9815e5c2b4764a13409cebaf66821, 45ce8dd5c4b949afa957ec8ffb926060, - 05b7145fd62d4581b2123c7680f11cdd, b96267f014814ec5b96ad7e6165104b1, bce34bdbfbd64f1f8353a4e8515cee0b, - 93b8206f8c5841a692cdce985ae301d8, c95f592620c64da595cc787567b2c4db, 8a97071f862c4ec3b4b4140d2e34eda2] + referenced_widgets: [43d95e3e6b704cb5ae941541862e35fe, fca543b71352477db00545b3990d44fa, + d3c971a3507249c9a22cad026e46d739, 6da776e94f7740b9aae06f298c1e03cd, b4aec5e3895e4a19912c74777e9ea835, + ef4dc5b756d74129bd2d643d99a1ab2e, 30243b81748e497eb526b25404e95826, 3bb9b93e595d4a0ca973ded476c0a5d0, + b770951ecace4b02ad1575fe9eb9e640, 79009c4ea2bf46b1a3a2c6558fa6ec2f, 5cb081d3a038482583350d018a768bd4] id: 1hOamw_7C8Pb -outputId: ca166490-22db-4732-b29f-866b7593e489 +outputId: 0e3805dc-1bfd-4222-9052-0b2111ea3091 --- # tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1) mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True) @@ -380,13 +412,13 @@ test_labels = one_hot(test_labels, n_targets) colab: base_uri: https://localhost:8080/ id: Td3PiLdmEf7z -outputId: 96403b0f-6079-43ce-df16-d4583f09906b +outputId: 464da4f6-f028-4667-889d-a812382739b0 --- print('Train:', train_images.shape, train_labels.shape) print('Test:', test_images.shape, test_labels.shape) ``` -+++ {"id": "UWRSaalfdyDX"} ++++ {"id": "yy9PunCJdI-G"} ### Define the Training Generator @@ -414,8 +446,8 @@ Use the training generator in a custom training loop. 
--- colab: base_uri: https://localhost:8080/ -id: h2sO13XDGvq1 -outputId: a150246e-ceb5-46ac-db71-2a8177a9d04d +id: AsFKboVRaV6r +outputId: 9cb33f79-1b17-439d-88d3-61cd984124f6 --- train_model(num_epochs, params, training_generator) ``` @@ -435,7 +467,7 @@ Install Grain colab: base_uri: https://localhost:8080/ id: L78o7eeyGvn5 -outputId: 76d16565-0d9e-4f5f-c6b1-4cf4a683d0e7 +outputId: 8f32bb0f-9a73-48a9-dbcd-4eb93ba3f606 --- !pip install grain ``` @@ -468,6 +500,7 @@ class Dataset: self.load_data() def load_data(self): + # Load the MNIST dataset using PyGrain self.dataset = MNIST(self.data_dir, download=True, train=self.train) def __len__(self): @@ -495,12 +528,12 @@ mnist_dataset = Dataset(data_dir) ```{code-cell} :id: f1VnTuX3u_kL -# Convert training data to JAX arrays and encode labels as one-hot vectors train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32) train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets) -# Load test dataset and process it mnist_dataset_test = MNIST(data_dir, download=True, train=False) + +# Convert test images to JAX arrays and encode test labels as one-hot vectors test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32) test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets) ``` @@ -510,27 +543,24 @@ test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnis colab: base_uri: https://localhost:8080/ id: a2NHlp9klrQL -outputId: 14be58c0-851e-4a44-dfcc-d02f0718dab5 +outputId: cc9e0958-8484-4669-a2d1-abac36a3097f --- print("Train:", train_images.shape, train_labels.shape) print("Test:", test_images.shape, test_labels.shape) ``` -+++ {"id": "fETnWRo2crhf"} ++++ {"id": "1QPbXt7O0JN-"} ### Initialize PyGrain DataLoader -Set up a PyGrain DataLoader for sequential batch sampling. - ```{code-cell} :id: 9RuFTcsCs2Ac sampler = pygrain.SequentialSampler( num_records=len(mnist_dataset), - shard_options=pygrain.NoSharding()) # Single-device, no sharding + shard_options=pygrain.ShardByJaxProcess()) # Shard across TPU cores def pygrain_training_generator(): - """Grain DataLoader generator for training.""" return pygrain.DataLoader( data_source=mnist_dataset, sampler=sampler, @@ -549,7 +579,7 @@ Run the training loop using the Grain DataLoader. colab: base_uri: https://localhost:8080/ id: cjxJRtiTadEI -outputId: 3f624366-b683-4d20-9d0a-777d345b0e21 +outputId: a620e9f7-7a01-4ba8-fe16-6f988401c7c1 --- train_model(num_epochs, params, pygrain_training_generator) ``` @@ -569,15 +599,11 @@ Install the Hugging Face `datasets` library. colab: base_uri: https://localhost:8080/ id: 19ipxPhI6oSN -outputId: 684e445f-d23e-4924-9e76-2c2c9359f0be +outputId: e0d52dfb-6c60-4539-a043-574d2533a744 --- !pip install datasets ``` -+++ {"id": "be0h_dZv0593"} - -Import Library - ```{code-cell} :id: 8v1N59p76zn0 @@ -586,8 +612,6 @@ from datasets import load_dataset +++ {"id": "8Gaj11tO7C86"} -### Load and Format MNIST Dataset - Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays. 
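The cell below requests NumPy-formatted arrays; as noted above, the `datasets` library can also return JAX arrays directly via `with_format("jax")`. A small illustrative comparison (not part of the tutorial code; the `cache_dir` value mirrors the tutorial's `data_dir`):

```python
# Illustrative sketch: "numpy" vs "jax" formatting in Hugging Face `datasets`.
from datasets import load_dataset

ds_np = load_dataset("mnist", cache_dir="/tmp/mnist_dataset").with_format("numpy")
ds_jax = load_dataset("mnist", cache_dir="/tmp/mnist_dataset").with_format("jax")

print(type(ds_np["train"][0]["image"]))   # numpy.ndarray
print(type(ds_jax["train"][0]["image"]))  # jax.Array
```

The tutorial keeps the NumPy format and performs the flattening, one-hot encoding, and sharding itself before handing the arrays to JAX.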
```{code-cell} @@ -595,59 +619,65 @@ Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for qui colab: base_uri: https://localhost:8080/ height: 301 - referenced_widgets: [32f6132a31aa4c508d3c3c5ef70348bb, d7c2ffa6b143463c91cbf8befca6ca01, - fd964ecd3926419d92927c67f955d5d0, 60feca3fde7c4447ad8393b0542eb999, 3354a0baeca94d18bc6b2a8b8b465b58, - a0d0d052772b46deac7657ad052991a4, fb34783b9cba462e9b690e0979c4b07a, 8d8170c1ed99490589969cd753c40748, - f1ecb6db00a54e088f1e09164222d637, 3cf5dd8d29aa4619b39dc2542df7e42e, 2e5d42ca710441b389895f2d3b611d0a, - 5d8202da24244dc896e9a8cba6a4ed4f, a6d64c953631412b8bd8f0ba53ae4d32, 69240c5cbfbb4e91961f5b49812a26f0, - 865f38532b784a7c971f5d33b87b443e, ceb1c004191947cdaa10af9b9c03c80d, 64c6041037914779b5e8e9cf5a80ad04, - 562fa6a0e7b846a180ac4b423c5511c5, b3b922288f9c4df2a4088279ff6d1531, 75a1a8ffda554318890cf74c345ed9a9, - 3bae06cacf394a5998c2326199da94f5, ff6428a3daa5496c81d5e664aba01f97, 1ba3f86870724f55b94a35cb6b4173af, - b3e163fd8b8a4f289d5a25611cb66d23, abd2daba215e4f7c9ddabde04d6eb382, e22ee019049144d5aba573cdf4dbe4fc, - 6ac765dac67841a69218140785f024c6, 7b057411a54e434fb74804b90daa8d44, 563f71b3c67d47c3ab1100f5dc1b98f3, - d81a657361ab4bba8bcc0cf309d2ff64, 20316312ab88471ba90cbb954be3e964, 698fda742f834473a23fb7e5e4cf239c, - 289b52c5a38146b8b467a5f4678f6271, d07c2f37cf914894b1551a8104e6cb70, 5b55c73d551d483baaa6a1411c2597b1, - 2308f77723f54ac898588f48d1853b65, 54d2589714d04b2e928b816258cb0df4, f84b795348c04c7a950165301a643671, - bc853a4a8d3c4dbda23d183f0a3b4f27, 1012ddc0343842d8b913a7d85df8ab8f, 771a73a8f5084a57afc5654d72e022f0, - 311a43449f074841b6df4130b0871ac9, cd4d29cb01134469b52d6936c35eb943, 013cf89ee6174d29bb3f4fdff7b36049, - 9237d877d84e4b3ab69698ecf56915bb, 337ef4d37e6b4ff6bf6e8bd4ca93383f, b4096d3837b84ccdb8f1186435c87281, - 7259d3b7e11b4736b4d2aa8e9c55e994, 1ad1f8e99a864fc4a2bc532d9a4ff110, b2b50451eabd40978ef46db5e7dd08c4, - 2dad5c5541e243128e23c3dd3e420ac2, a3de458b61e5493081d6bb9cf7e923db, 37760f8a7b164e6f9c1a23d621e9fe6b, - 745a2aedcfab491fb9cffba19958b0c5, 2f6c670640d048d2af453638cfde3a1e] + referenced_widgets: [86617153e14143c6900da3535b74ef07, 8de57c9ecba14aa5b1f642af5c7e9094, + 515fe154b1b74ed981e877aef503aa99, 4e291a8b028847328ea1d9a650c20beb, 87a0c8cdc0ad423daba7082b985cbd2b, + 4764b5b806b94734b760cf6cc2fc224d, 5307bf3142804235bb688694c517d80c, 6a2fd6755667443abe7710ad607a79cc, + 91bc1755904e40db8d758db4d09754e3, 69c38d75960542fb83fa087cae761957, dc31cb349c9b4c3580b2b77cbad1325c, + d451224a0ce540648b0c28d433d85803, 52f2f12dcffe4507ab92286fd3810db6, 6ab919475c80413e94afa66304b05338, + 305d05093c6e411cb438a0bbf122d574, aa11f21e68994a8d9ddead215f2f4920, 59a7233abf61461b8b3feeb31b2f544f, + 9d909399be9a4fa48bc3d781905c7f5a, 5b6172eb4e0541a3b07d4f82de77a303, bc3bec617b0040f487f80134537a3068, + 9fe417f8159244f8ac808f2844922cf3, c4748e35e8574bb286a527295df98c8e, f50572e8058c4864bb8143c364d191f9, + 436955f611674e27b4ddf3e040cc5ce9, 048231bf788c447091b8ef0174101f42, 97009f7e20d84c7c9d89f7497efc494c, + 84e2844437884f6c89683e6545a2262e, df3019cc6aa44a4cbcb62096444769a7, ce17fe81850c49cd924297d21ecda621, + 422117e32e0b4a95bed7925c99fd9f78, 56ab1fa0212a43a4a70838e440be0e9c, 1c5483472cea483bbf2a8fe2a9182ce0, + 00034cb6a66143d8a87922befb1da7a6, 368b51d79aed4184854f155e2951da81, eb9de18be48d4a0db1034a38a0287ea6, + dbec1d9b196849a5ad79a5f083dbe64e, 66db6915d27b4fb49e1b44f70cb61654, 80f3e3a30dc24d3fa54bb72dc1c60182, + c320096ba1e74c7bbbd9509cc11c22e9, a664dd9c446040e8b175bb91d1c051db, 66c7826ff9b4455db9f7e9717a432f73, + 
74ec8cec0f3c4c04b76f5fb87ea2d9bb, ea4537aef1e247378de1935ad50ef76c, a9cffb2f5e194dfaba516bb4c8c47e3f, + 4f17b7ab6ae94ce3b122561bcd8d4427, 3c0bdc06fe07412bacc00daa6f1eec34, 1ba273ced1484bcf9855366ff0dc3645, + 7413d8bab616446ba6b820a3f874f6a0, 53c160c26c634b53a914be18ed91016c, ebc4ad2fae264e72a5307a0481a97ab3, + 83ab5e7617fb45898c259bc20f71e958, 21f1138e807e4946953e3074d72d9a27, 86d7357878634706b9e214103efa262a, + 3713a0e1880a43bc8b23225dbb8b4c45, f9f85ce1cbf34a7da27804ce7cc6444e] id: a22kTvgk6_fJ -outputId: 35fc38b9-a6ab-4b02-ffa4-ab27fac69df4 +outputId: 53e1d208-5360-479b-c097-0c03c7fac3e8 --- -mnist_dataset = load_dataset("mnist").with_format("numpy") +mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy") ``` -+++ {"id": "IFjTyGxY19b0"} ++++ {"id": "tgI7dIaX7JzM"} ### Extract images and labels -Get image shape and flatten for model input +Get image shape and flatten for model input. ```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: NHrKatD_7HbH -outputId: deec1739-2fc0-4e71-8567-f2e0c9db198b ---- +:id: NHrKatD_7HbH + train_images = mnist_dataset["train"]["image"] train_labels = mnist_dataset["train"]["label"] test_images = mnist_dataset["test"]["image"] test_labels = mnist_dataset["test"]["label"] -# Flatten images and one-hot encode labels +# Extract image shape image_shape = train_images.shape[1:] num_features = image_shape[0] * image_shape[1] +# Flatten the images train_images = train_images.reshape(-1, num_features) test_images = test_images.reshape(-1, num_features) +# One-hot encode the labels train_labels = one_hot(train_labels, n_targets) test_labels = one_hot(test_labels, n_targets) +``` +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: dITh435Z7Nwb +outputId: cd77ebf6-7d44-420f-f8d8-4357f915c956 +--- print('Train:', train_images.shape, train_labels.shape) print('Test:', test_images.shape, test_labels.shape) ``` @@ -678,14 +708,16 @@ Run the training loop using the Hugging Face training generator. --- colab: base_uri: https://localhost:8080/ -id: RhloYGsw6nPf -outputId: d49c1cd2-a546-46a6-84fb-d9507c38f4ca +id: Ui6aLiZP7aLe +outputId: 48347baf-30f2-443d-b3bf-b12100d96b8f --- train_model(num_epochs, params, hf_training_generator) ``` -+++ {"id": "qXylIOwidWI3"} ++++ {"id": "_JR0V1Aix9Id"} ## Summary -This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements. +This notebook introduced efficient methods for multi-device distributed data loading on TPUs with JAX. You explored how to leverage popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to optimize the data loading process for machine learning tasks. Each library offers unique advantages, enabling you to choose the best approach based on your project’s requirements. + +For more in-depth strategies on distributed data loading with JAX, including global data pipelines and per-device processing, refer to the [Distributed Data Loading Guide](https://jax.readthedocs.io/en/latest/distributed_data_loading.html). 
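As a pointer toward the "per-device processing" strategies covered in the linked Distributed Data Loading Guide, here is a minimal single-host sketch using `jax.make_array_from_callback`, where each device's shard of a global batch is produced on demand. The shapes and the in-memory `host_batch` stand in for a real data source and are assumptions, not part of this tutorial.

```python
# Sketch of per-device loading: each device's shard is materialized from the slice
# of the global batch it owns. `host_batch` is a placeholder for a real data source.
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding

devices = np.array(jax.devices())
mesh = Mesh(devices, ('device',))
sharding = NamedSharding(mesh, PartitionSpec('device'))

global_shape = (128, 784)                 # batch axis should divide across devices
host_batch = np.random.rand(*global_shape).astype(np.float32)

def load_shard(index):
    # `index` is a tuple of slices selecting this device's piece of the global array;
    # a real loader would read only those rows from disk or a remote store.
    return host_batch[index]

global_batch = jax.make_array_from_callback(global_shape, sharding, load_shard)
print(global_batch.shape, len(global_batch.addressable_shards))
```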
diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md index 23cdd58..071343b 100644 --- a/docs/source/tutorials.md +++ b/docs/source/tutorials.md @@ -25,6 +25,7 @@ JAX_time_series_classification JAX_transformer_text_classification data_loaders_on_cpu_with_jax data_loaders_on_gpu_with_jax +data_loaders_for_multi_device_setups_with_jax ``` Once you've gone through this content, you can refer to package-specific