From 619aa20d040d589750b2a672a95373fe58c79c8f Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Sun, 8 Dec 2024 21:52:47 +0530 Subject: [PATCH 01/16] Add MNIST example with SPMD for JAX Illustrate how to use JAX's `pmap` to express and execute single-program multiple-data (SPMD) programs for data parallelism along a batch dimension Signed-off-by: Sandipan Panda --- .github/workflows/publish-example-images.yaml | 4 + examples/jax/jax-dist-spmd-mnist/Dockerfile | 27 +++ examples/jax/jax-dist-spmd-mnist/README.md | 30 ++++ examples/jax/jax-dist-spmd-mnist/datasets.py | 97 +++++++++++ .../jaxjob_dist_spmd_mnist_gloo.yaml | 19 +++ .../spmd_mnist_classifier_fromscratch.py | 155 ++++++++++++++++++ 6 files changed, 332 insertions(+) create mode 100644 examples/jax/jax-dist-spmd-mnist/Dockerfile create mode 100644 examples/jax/jax-dist-spmd-mnist/README.md create mode 100644 examples/jax/jax-dist-spmd-mnist/datasets.py create mode 100644 examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml create mode 100644 examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index 74dc242551..5df38f1f37 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -74,3 +74,7 @@ jobs: platforms: linux/amd64 dockerfile: examples/pytorch/deepspeed-demo/Dockerfile context: examples/pytorch/deepspeed-demo + - component-name: jaxjob-mnist + platforms: linux/amd64,linux/arm64 + dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile + context: examples/jax/jax-dist-spmd-mnist/ diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile new file mode 100644 index 0000000000..805f222a35 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/Dockerfile @@ -0,0 +1,27 @@ +FROM python:3.12 + +RUN pip install --upgrade pip +RUN pip install --upgrade jax absl-py + +RUN apt-get update && apt-get install -y \ + build-essential \ + cmake \ + git \ + libgoogle-glog-dev \ + libgflags-dev \ + libprotobuf-dev \ + protobuf-compiler \ + && rm -rf /var/lib/apt/lists/* + +RUN git clone https://github.com/facebookincubator/gloo.git \ + && cd gloo \ + && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \ + && mkdir build \ + && cd build \ + && cmake ../ \ + && make \ + && make install + +WORKDIR /app + +ADD datasets.py spmd_mnist_classifier_fromscratch.py /app diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md new file mode 100644 index 0000000000..3f44afb615 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/README.md @@ -0,0 +1,30 @@ +## An MNIST example with single-program multiple-data (SPMD) data parallelism. + +The aim here is to illustrate how to use JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) to express and execute +[SPMD](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) programs for data parallelism along a batch dimension, while also +minimizing dependencies by avoiding the use of higher-level layers and +optimizers libraries. + +Adapted from https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py. + +```bash +$ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml +``` + +--- + +```bash +$ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist +``` +--- +```bash +$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o +name -n kubeflow) +$ kubectl logs -f ${PODNAME} -n kubeflow +``` + +--- + +```bash +$ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow +``` diff --git a/examples/jax/jax-dist-spmd-mnist/datasets.py b/examples/jax/jax-dist-spmd-mnist/datasets.py new file mode 100644 index 0000000000..60fb8ce25b --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/datasets.py @@ -0,0 +1,97 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Datasets used in examples.""" + + +import array +import gzip +import os +import struct +import urllib.request +from os import path + +import numpy as np + +_DATA = "/tmp/jax_example_data/" + + +def _download(url, filename): + """Download a url to a file in the JAX data temp directory.""" + if not path.exists(_DATA): + os.makedirs(_DATA) + out_file = path.join(_DATA, filename) + if not path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + print(f"downloaded {url} to {_DATA}") + + +def _partial_flatten(x): + """Flatten all but the first dimension of an ndarray.""" + return np.reshape(x, (x.shape[0], -1)) + + +def _one_hot(x, k, dtype=np.float32): + """Create a one-hot encoding of x of size k.""" + return np.array(x[:, None] == np.arange(k), dtype) + + +def mnist_raw(): + """Download and parse the raw MNIST dataset.""" + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + + def parse_labels(filename): + with gzip.open(filename, "rb") as fh: + _ = struct.unpack(">II", fh.read(8)) + return np.array(array.array("B", fh.read()), dtype=np.uint8) + + def parse_images(filename): + with gzip.open(filename, "rb") as fh: + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) + return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape( + num_data, rows, cols + ) + + for filename in [ + "train-images-idx3-ubyte.gz", + "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + "t10k-labels-idx1-ubyte.gz", + ]: + _download(base_url + filename, filename) + + train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) + train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) + test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) + + return train_images, train_labels, test_images, test_labels + + +def mnist(permute_train=False): + """Download, parse and process MNIST data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = mnist_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.0) + test_images = _partial_flatten(test_images) / np.float32(255.0) + train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml new file mode 100644 index 0000000000..912ebde719 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml @@ -0,0 +1,19 @@ +apiVersion: "kubeflow.org/v1" +kind: JAXJob +metadata: + name: jaxjob-mnist + namespace: kubeflow +spec: + jaxReplicaSpecs: + Worker: + replicas: 2 + restartPolicy: OnFailure + template: + spec: + containers: + - name: jax + image: docker.io/sandipanify/jaxjob-spmd-mnist:latest + command: + - "python3" + - "spmd_mnist_classifier_fromscratch.py" + imagePullPolicy: Always diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py new file mode 100644 index 0000000000..5982963ba7 --- /dev/null +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -0,0 +1,155 @@ +"""An MNIST example with single-program multiple-data (SPMD) data parallelism. + +The aim here is to illustrate how to use JAX's `pmap` to express and execute +SPMD programs for data parallelism along a batch dimension, while also +minimizing dependencies by avoiding the use of higher-level layers and +optimizers libraries. +""" + +import multiprocessing +import os +import time +from functools import partial + +import numpy as np +import numpy.random as npr + +# JAX will treat your CPU as a single device by default, regardless of the number +# of cores available. Unfortunately, this means that using `pmap` is not possible out +# of the box – we’ll first need to instruct JAX to split the CPU into multiple devices. +# This variable has to be set before JAX or any library that imports it is imported + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format( + multiprocessing.cpu_count() +) + +import datasets # noqa +import jax # noqa +import jax.numpy as jnp # noqa +from jax import grad, jit, lax, pmap # noqa +from jax.scipy.special import logsumexp # noqa +from jax.tree_util import tree_map # noqa + +jax.config.update("jax_cpu_collectives_implementation", "gloo") + +process_id = int(os.getenv("PROCESS_ID")) +num_processes = int(os.getenv("NUM_PROCESSES")) +coordinator_address = os.getenv("COORDINATOR_ADDRESS") +coordinator_port = int(os.getenv("COORDINATOR_PORT")) +coordinator_address = f"{coordinator_address}:{coordinator_port}" + +jax.distributed.initialize( + coordinator_address=coordinator_address, + num_processes=num_processes, + process_id=process_id, +) + + +def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): + return [ + (scale * rng.randn(m, n), scale * rng.randn(n)) + for m, n, in zip(layer_sizes[:-1], layer_sizes[1:]) + ] + + +def predict(params, inputs): + activations = inputs + for w, b in params[:-1]: + outputs = jnp.dot(activations, w) + b + activations = jnp.tanh(outputs) + + final_w, final_b = params[-1] + logits = jnp.dot(activations, final_w) + final_b + return logits - logsumexp(logits, axis=1, keepdims=True) + + +def loss(params, batch): + inputs, targets = batch + preds = predict(params, inputs) + return -jnp.mean(jnp.sum(preds * targets, axis=1)) + + +@jit +def accuracy(params, batch): + inputs, targets = batch + target_class = jnp.argmax(targets, axis=1) + predicted_class = jnp.argmax(predict(params, inputs), axis=1) + return jnp.mean(predicted_class == target_class) + + +if __name__ == "__main__": + layer_sizes = [784, 1024, 1024, 10] + param_scale = 0.1 + step_size = 0.001 + num_epochs = 10 + batch_size = 128 + + train_images, train_labels, test_images, test_labels = datasets.mnist() + num_train = train_images.shape[0] + num_complete_batches, leftover = divmod(num_train, batch_size) + num_batches = num_complete_batches + bool(leftover) + + # For this manual SPMD example, we get the number of devices (e.g. CPU, + # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. + num_devices = jax.local_device_count() + + def data_stream(): + rng = npr.RandomState(0) + while True: + perm = rng.permutation(num_train) + for i in range(num_batches): + batch_idx = perm[i * batch_size : (i + 1) * batch_size] # noqa + images, labels = train_images[batch_idx], train_labels[batch_idx] + # For this SPMD example, we reshape the data batch dimension into two + # batch dimensions, one of which is mapped over parallel devices. + batch_size_per_device, ragged = divmod(images.shape[0], num_devices) + if ragged: + msg = "batch size must be divisible by device count, got {} and {}." + raise ValueError(msg.format(batch_size, num_devices)) + shape_prefix = (num_devices, batch_size_per_device) + images = images.reshape(shape_prefix + images.shape[1:]) + labels = labels.reshape(shape_prefix + labels.shape[1:]) + yield images, labels + + batches = data_stream() + + @partial(pmap, axis_name="batch") + def spmd_update(params, batch): + grads = grad(loss)(params, batch) + # We compute the total gradients, summing across the device-mapped axis, + # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum. + grads = [(lax.psum(dw, "batch"), lax.psum(db, "batch")) for dw, db in grads] + return [ + (w - step_size * dw, b - step_size * db) + for (w, b), (dw, db) in zip(params, grads) + ] + + # We replicate the parameters so that the constituent arrays have a leading + # dimension of size equal to the number of devices we're pmapping over. + init_params = init_random_params(param_scale, layer_sizes) + + def replicate_array(x): + return np.broadcast_to(x, (num_devices,) + x.shape) + + replicated_params = tree_map(replicate_array, init_params) + + print(f"JAX global devices:{jax.devices()}") + print(f"JAX local devices:{jax.local_devices()}") + + print(f"JAX device count:{jax.device_count()}") + print(f"JAX local device count:{jax.local_device_count()}") + + for epoch in range(num_epochs): + start_time = time.time() + for _ in range(num_batches): + replicated_params = spmd_update(replicated_params, next(batches)) + epoch_time = time.time() - start_time + + # We evaluate using the jitted `accuracy` function (not using pmap) by + # grabbing just one of the replicated parameter values. + params = tree_map(lambda x: x[0], replicated_params) + train_acc = accuracy(params, (train_images, train_labels)) + test_acc = accuracy(params, (test_images, test_labels)) + print(f"Epoch {epoch} in {epoch_time:0.2f} sec") + print(f"Training set accuracy {train_acc}") + print(f"Test set accuracy {test_acc}") From ca78118cc41a5be176b1e3169ca300d6d071d3e5 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Mon, 9 Dec 2024 21:53:43 +0530 Subject: [PATCH 02/16] Update CONTRIBUTING.md Use -- server-side to install the latest local changes of Training Operator control plane Signed-off-by: Sandipan Panda --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eca6af84d7..a7bd8ef76e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -66,7 +66,7 @@ Note, that for the example job below, the PyTorchJob uses the `kubeflow` namespa From here we can apply the manifests to the cluster. ```sh -kubectl apply -k "github.com/kubeflow/training-operator/manifests/overlays/standalone" +kubectl apply --server-side -k "github.com/kubeflow/training-operator/manifests/overlays/standalone" ``` Then we can patch it with the latest operator image. From 7f523e449dc4bb76eae61536d0a1ed5c978f2293 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Thu, 19 Dec 2024 00:36:04 +0530 Subject: [PATCH 03/16] Add JAXJob output Signed-off-by: Sandipan Panda --- .github/workflows/publish-example-images.yaml | 2 +- examples/jax/jax-dist-spmd-mnist/Dockerfile | 4 +- examples/jax/jax-dist-spmd-mnist/README.md | 103 ++++++++++++++++++ .../jaxjob_dist_spmd_mnist_gloo.yaml | 3 - .../spmd_mnist_classifier_fromscratch.py | 21 +++- sdk/python/test/e2e/test_e2e_jaxjob.py | 3 +- 6 files changed, 126 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish-example-images.yaml b/.github/workflows/publish-example-images.yaml index 5df38f1f37..5012714b57 100644 --- a/.github/workflows/publish-example-images.yaml +++ b/.github/workflows/publish-example-images.yaml @@ -74,7 +74,7 @@ jobs: platforms: linux/amd64 dockerfile: examples/pytorch/deepspeed-demo/Dockerfile context: examples/pytorch/deepspeed-demo - - component-name: jaxjob-mnist + - component-name: jaxjob-dist-spmd-mnist platforms: linux/amd64,linux/arm64 dockerfile: examples/jax/jax-dist-spmd-mnist/Dockerfile context: examples/jax/jax-dist-spmd-mnist/ diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile index 805f222a35..92b406f117 100644 --- a/examples/jax/jax-dist-spmd-mnist/Dockerfile +++ b/examples/jax/jax-dist-spmd-mnist/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.12 +FROM python:3.13 RUN pip install --upgrade pip RUN pip install --upgrade jax absl-py @@ -25,3 +25,5 @@ RUN git clone https://github.com/facebookincubator/gloo.git \ WORKDIR /app ADD datasets.py spmd_mnist_classifier_fromscratch.py /app + +ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"] diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md index 3f44afb615..6194ea9eda 100644 --- a/examples/jax/jax-dist-spmd-mnist/README.md +++ b/examples/jax/jax-dist-spmd-mnist/README.md @@ -16,6 +16,13 @@ $ kubectl apply -f examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo. ```bash $ kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-mnist ``` + +``` +NAME READY STATUS RESTARTS AGE +jaxjob-mnist-worker-0 0/1 Completed 0 108m +jaxjob-mnist-worker-1 0/1 Completed 0 108m +``` + --- ```bash $ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o @@ -23,8 +30,104 @@ name -n kubeflow) $ kubectl logs -f ${PODNAME} -n kubeflow ``` +``` +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/ +downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/ +JAX global devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=131072), CpuDevice(id=131073), CpuDevice(id=131074), CpuDevice(id=131075), CpuDevice(id=131076), CpuDevice(id=131077), CpuDevice(id=131078), CpuDevice(id=131079)] +JAX local devices:[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)] +JAX device count:16 +JAX local device count:8 +JAX process count:2 +Epoch 0 in 1809.25 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 1 in 0.51 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 2 in 0.69 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 3 in 0.81 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 4 in 0.91 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 5 in 0.97 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 6 in 1.12 sec +Training set accuracy 0.09035000205039978 +Test set accuracy 0.08919999748468399 +Epoch 7 in 1.11 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 8 in 1.21 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 +Epoch 9 in 1.29 sec +Training set accuracy 0.09871666878461838 +Test set accuracy 0.09799999743700027 + +``` + --- ```bash $ kubectl get -o yaml jaxjobs jaxjob-mnist -n kubeflow ``` + +``` +apiVersion: kubeflow.org/v1 +kind: JAXJob +metadata: + annotations: + kubectl.kubernetes.io/last-applied-configuration: | + {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-mnist","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"image":"docker.io/sandipanify/jaxjob-spmd-mnist:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}} + creationTimestamp: "2024-12-18T16:47:28Z" + generation: 1 + name: jaxjob-mnist + namespace: kubeflow + resourceVersion: "3620" + uid: 15f1db77-3326-405d-95e6-3d9a0d581611 +spec: + jaxReplicaSpecs: + Worker: + replicas: 2 + restartPolicy: OnFailure + template: + spec: + containers: + - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest + imagePullPolicy: Always + name: jax +status: + completionTime: "2024-12-18T17:22:11Z" + conditions: + - lastTransitionTime: "2024-12-18T16:47:28Z" + lastUpdateTime: "2024-12-18T16:47:28Z" + message: JAXJob jaxjob-mnist is created. + reason: JAXJobCreated + status: "True" + type: Created + - lastTransitionTime: "2024-12-18T16:50:57Z" + lastUpdateTime: "2024-12-18T16:50:57Z" + message: JAXJob kubeflow/jaxjob-mnist is running. + reason: JAXJobRunning + status: "False" + type: Running + - lastTransitionTime: "2024-12-18T17:22:11Z" + lastUpdateTime: "2024-12-18T17:22:11Z" + message: JAXJob kubeflow/jaxjob-mnist successfully completed. + reason: JAXJobSucceeded + status: "True" + type: Succeeded + replicaStatuses: + Worker: + selector: training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker + succeeded: 2 + startTime: "2024-12-18T16:47:28Z" + +``` diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml index 912ebde719..50bd66f583 100644 --- a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml +++ b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml @@ -13,7 +13,4 @@ spec: containers: - name: jax image: docker.io/sandipanify/jaxjob-spmd-mnist:latest - command: - - "python3" - - "spmd_mnist_classifier_fromscratch.py" imagePullPolicy: Always diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index 5982963ba7..41f55d745c 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -1,3 +1,17 @@ +# Copyright 2024 kubeflow.org. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """An MNIST example with single-program multiple-data (SPMD) data parallelism. The aim here is to illustrate how to use JAX's `pmap` to express and execute @@ -34,9 +48,9 @@ process_id = int(os.getenv("PROCESS_ID")) num_processes = int(os.getenv("NUM_PROCESSES")) -coordinator_address = os.getenv("COORDINATOR_ADDRESS") -coordinator_port = int(os.getenv("COORDINATOR_PORT")) -coordinator_address = f"{coordinator_address}:{coordinator_port}" +coordinator_address = ( + f"{os.getenv('COORDINATOR_ADDRESS')}:{int(os.getenv('COORDINATOR_PORT'))}" +) jax.distributed.initialize( coordinator_address=coordinator_address, @@ -138,6 +152,7 @@ def replicate_array(x): print(f"JAX device count:{jax.device_count()}") print(f"JAX local device count:{jax.local_device_count()}") + print(f"JAX process count:{jax.process_count()}") for epoch in range(num_epochs): start_time = time.time() diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 6223c8a988..98cc5fff49 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -155,7 +155,6 @@ def generate_jaxjob( def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, - image="docker.io/kubeflow/jaxjob-simple:latest", - command=["python", "train.py"], + image="docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest", resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), ) From 1ed9b7d2c9af367bff839a476ff2c18f42afea25 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Thu, 19 Dec 2024 23:11:37 +0530 Subject: [PATCH 04/16] Update JAXJob CI images Signed-off-by: Sandipan Panda --- .../jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml | 2 +- pkg/controller.v1/jax/envvar_test.go | 2 +- pkg/webhooks/jax/jaxjob_webhook_test.go | 2 +- sdk/python/kubeflow/training/constants/constants.py | 2 +- sdk/python/test/e2e/test_e2e_jaxjob.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml index 50bd66f583..e124b2efef 100644 --- a/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml +++ b/examples/jax/jax-dist-spmd-mnist/jaxjob_dist_spmd_mnist_gloo.yaml @@ -12,5 +12,5 @@ spec: spec: containers: - name: jax - image: docker.io/sandipanify/jaxjob-spmd-mnist:latest + image: docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest imagePullPolicy: Always diff --git a/pkg/controller.v1/jax/envvar_test.go b/pkg/controller.v1/jax/envvar_test.go index 9920e89bbb..3b0f0b5691 100644 --- a/pkg/controller.v1/jax/envvar_test.go +++ b/pkg/controller.v1/jax/envvar_test.go @@ -30,7 +30,7 @@ func TestSetPodEnv(t *testing.T) { Spec: corev1.PodSpec{ Containers: []corev1.Container{{ Name: "jax", - Image: "docker.io/kubeflow/jaxjob-simple:latest", + Image: "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest", Ports: []corev1.ContainerPort{{ Name: kubeflowv1.JAXJobDefaultPortName, ContainerPort: validPort, diff --git a/pkg/webhooks/jax/jaxjob_webhook_test.go b/pkg/webhooks/jax/jaxjob_webhook_test.go index bfbc0eb29c..a6463fb3aa 100644 --- a/pkg/webhooks/jax/jaxjob_webhook_test.go +++ b/pkg/webhooks/jax/jaxjob_webhook_test.go @@ -156,7 +156,7 @@ func TestValidateV1JAXJob(t *testing.T) { Containers: []corev1.Container{ { Name: "", - Image: "gcr.io/kubeflow-ci/jaxjob-simple_test:1.0", + Image: "gcr.io/kubeflow-ci/jaxjob-dist-spmd-mnist_test:1.0", }, }, }, diff --git a/sdk/python/kubeflow/training/constants/constants.py b/sdk/python/kubeflow/training/constants/constants.py index dba4d49681..2a5415ea26 100644 --- a/sdk/python/kubeflow/training/constants/constants.py +++ b/sdk/python/kubeflow/training/constants/constants.py @@ -153,7 +153,7 @@ JAXJOB_PLURAL = "jaxjobs" JAXJOB_CONTAINER = "jax" JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower() -JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-simple:latest" +JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest" # Dictionary to get plural, model, and container for each Job kind. JOB_PARAMETERS = { diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 98cc5fff49..2012e7bdb2 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -88,7 +88,7 @@ def test_sdk_e2e_with_gang_scheduling(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=6000) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) From 904dbd69bb3cf0c4e4dabb53cd37dc64cb4770c3 Mon Sep 17 00:00:00 2001 From: Sandipan Panda Date: Fri, 27 Dec 2024 23:33:08 +0530 Subject: [PATCH 05/16] Adjust jaxjob spmd example batch size Signed-off-by: Sandipan Panda --- examples/jax/jax-dist-spmd-mnist/README.md | 3 +-- .../spmd_mnist_classifier_fromscratch.py | 9 ++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/jax/jax-dist-spmd-mnist/README.md b/examples/jax/jax-dist-spmd-mnist/README.md index 6194ea9eda..d57a4d80fc 100644 --- a/examples/jax/jax-dist-spmd-mnist/README.md +++ b/examples/jax/jax-dist-spmd-mnist/README.md @@ -25,8 +25,7 @@ jaxjob-mnist-worker-1 0/1 Completed 0 108m --- ```bash -$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-mnist,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o -name -n kubeflow) +$ PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow) $ kubectl logs -f ${PODNAME} -n kubeflow ``` diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index 41f55d745c..d3e5f2cad1 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -96,17 +96,16 @@ def accuracy(params, batch): param_scale = 0.1 step_size = 0.001 num_epochs = 10 - batch_size = 128 + # For this manual SPMD example, we get the number of devices (e.g. CPU, + # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. + num_devices = jax.local_device_count() + batch_size = num_devices * 5 train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) - # For this manual SPMD example, we get the number of devices (e.g. CPU, - # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. - num_devices = jax.local_device_count() - def data_stream(): rng = npr.RandomState(0) while True: From ef997a17f0dc2363807a36dd4dc81767a1d5fa5f Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 9 Jan 2025 18:58:11 +0000 Subject: [PATCH 06/16] Add JAX Example Docker Image Build in CI Signed-off-by: sailesh duddupudi --- .github/workflows/integration-tests.yaml | 13 ++++++++++++ scripts/gha/build-jax-mnist-image.sh | 25 ++++++++++++++++++++++++ sdk/python/test/e2e/test_e2e_jaxjob.py | 2 +- 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 scripts/gha/build-jax-mnist-image.sh diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index d6fdd6389a..17ada298aa 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -65,12 +65,25 @@ jobs: python-version: ${{ matrix.python-version }} gang-scheduler-name: ${{ matrix.gang-scheduler-name }} + - name: Build JAX Job Example Image + run: | + ./scripts/gha/build-e2e-test-images.sh + env: + JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test + + - name: Load JAX Job Example Image + run: | + kind load docker-image ${{ env.JAX_JOB_CI_IMAGE }} --name ${{ env.KIND_CLUSTER }} + env: + JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test + - name: Run tests run: | pip install pytest python3 -m pip install -e sdk/python; pytest -s sdk/python/test/e2e --log-cli-level=debug --namespace=default env: GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }} + JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test - name: Collect volcano logs if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }} diff --git a/scripts/gha/build-jax-mnist-image.sh b/scripts/gha/build-jax-mnist-image.sh new file mode 100644 index 0000000000..b9a30fa18f --- /dev/null +++ b/scripts/gha/build-jax-mnist-image.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The script is used to build images needed to run JAX Job E2E test. + + +set -o errexit +set -o nounset +set -o pipefail + +# Build Image for MNIST example with SPMD for JAX +docker build examples/jax/jax-dist-spmd-mnist -t ${JAX_JOB_CI_IMAGE} -f examples/jax/jax-dist-spmd-mnist/Dockerfile diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 2012e7bdb2..eba4fa9842 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -155,6 +155,6 @@ def generate_jaxjob( def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, - image="docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest", + image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"), resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), ) From 9a88072c48fd1e8fb24be04ee5d1ec412ac08c12 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 9 Jan 2025 19:36:19 +0000 Subject: [PATCH 07/16] Fix script name typo Signed-off-by: sailesh duddupudi --- .github/workflows/integration-tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 17ada298aa..1b9f23c13e 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -67,7 +67,7 @@ jobs: - name: Build JAX Job Example Image run: | - ./scripts/gha/build-e2e-test-images.sh + ./scripts/gha/build-jax-mnist-image.sh env: JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test From a2ca3a3e12ce6ec8f2a48ef3cc9846f813a6a0f8 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 9 Jan 2025 19:53:03 +0000 Subject: [PATCH 08/16] Update script permissions Signed-off-by: sailesh duddupudi --- scripts/gha/build-jax-mnist-image.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/gha/build-jax-mnist-image.sh diff --git a/scripts/gha/build-jax-mnist-image.sh b/scripts/gha/build-jax-mnist-image.sh old mode 100644 new mode 100755 From 15e6205a2e321a9362e11f99229615daa74443a9 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 9 Jan 2025 19:59:34 +0000 Subject: [PATCH 09/16] Add KIND_CLUSTER env var Signed-off-by: sailesh duddupudi --- .github/workflows/integration-tests.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index 1b9f23c13e..a450a76b16 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -75,6 +75,7 @@ jobs: run: | kind load docker-image ${{ env.JAX_JOB_CI_IMAGE }} --name ${{ env.KIND_CLUSTER }} env: + KIND_CLUSTER: training-operator-cluster JAX_JOB_CI_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test - name: Run tests From 07688ec1f7e72f53956a8f7d9d0fc3c602afe98d Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 9 Jan 2025 20:54:36 +0000 Subject: [PATCH 10/16] Increase timeouts Signed-off-by: sailesh duddupudi --- sdk/python/test/e2e/test_e2e_jaxjob.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index eba4fa9842..c60a87554d 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -88,7 +88,7 @@ def test_sdk_e2e_with_gang_scheduling(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=6000) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=3000) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) @@ -123,7 +123,7 @@ def test_sdk_e2e(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=3000) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) From b876b7b121a0a4c5460b3039e8937684e07f828f Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Fri, 10 Jan 2025 16:59:49 +0000 Subject: [PATCH 11/16] Test higher resources Signed-off-by: sailesh duddupudi --- sdk/python/test/e2e/test_e2e_jaxjob.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index c60a87554d..82585c234a 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -156,5 +156,5 @@ def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"), - resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), + resources=V1ResourceRequirements(limits={"memory": "4Gi", "cpu": "1.6"}), ) From 063adfa3312ac54191ae517678178721cf41a671 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Fri, 10 Jan 2025 18:50:00 +0000 Subject: [PATCH 12/16] Increase Timeout Signed-off-by: sailesh duddupudi --- examples/jax/jax-dist-spmd-mnist/Dockerfile | 4 ++-- sdk/python/test/e2e/test_e2e_jaxjob.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/jax/jax-dist-spmd-mnist/Dockerfile b/examples/jax/jax-dist-spmd-mnist/Dockerfile index 92b406f117..1538d26507 100644 --- a/examples/jax/jax-dist-spmd-mnist/Dockerfile +++ b/examples/jax/jax-dist-spmd-mnist/Dockerfile @@ -1,7 +1,7 @@ FROM python:3.13 RUN pip install --upgrade pip -RUN pip install --upgrade jax absl-py +RUN pip install --upgrade jax[k8s] absl-py RUN apt-get update && apt-get install -y \ build-essential \ @@ -24,6 +24,6 @@ RUN git clone https://github.com/facebookincubator/gloo.git \ WORKDIR /app -ADD datasets.py spmd_mnist_classifier_fromscratch.py /app +ADD datasets.py spmd_mnist_classifier_fromscratch.py /app/ ENTRYPOINT ["python3", "spmd_mnist_classifier_fromscratch.py"] diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 82585c234a..97ad517887 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -88,7 +88,7 @@ def test_sdk_e2e_with_gang_scheduling(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=3000) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=9000) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) @@ -123,7 +123,7 @@ def test_sdk_e2e(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=3000) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=9000) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) From 76f2e80ce8483069b03b257cdc61ee49a3a0fe28 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Fri, 10 Jan 2025 22:32:58 +0000 Subject: [PATCH 13/16] remove resource reqs Signed-off-by: sailesh duddupudi --- sdk/python/test/e2e/test_e2e_jaxjob.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 97ad517887..ed17442477 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -156,5 +156,5 @@ def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"), - resources=V1ResourceRequirements(limits={"memory": "4Gi", "cpu": "1.6"}), + # resources=V1ResourceRequirements(limits={"memory": "4Gi", "cpu": "1.6"}), ) From 3353f7a589c7bc9a5387e81e96428f0eb6537f8b Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 16 Jan 2025 07:20:48 +0000 Subject: [PATCH 14/16] test low batch size Signed-off-by: sailesh duddupudi --- .../jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index d3e5f2cad1..6a8a89db3d 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -99,7 +99,8 @@ def accuracy(params, batch): # For this manual SPMD example, we get the number of devices (e.g. CPU, # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. num_devices = jax.local_device_count() - batch_size = num_devices * 5 + # batch_size = num_devices * 5 + batch_size = 5 # testing train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] From c80db23dd4df8cf7cc0a3eb15e8ffc3e58060ae3 Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 16 Jan 2025 10:06:02 +0000 Subject: [PATCH 15/16] test small batch size Signed-off-by: sailesh duddupudi --- .../jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index 6a8a89db3d..262b7ad07b 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -99,8 +99,7 @@ def accuracy(params, batch): # For this manual SPMD example, we get the number of devices (e.g. CPU, # GPUs or TPU cores) that we're using, and use it to reshape data minibatches. num_devices = jax.local_device_count() - # batch_size = num_devices * 5 - batch_size = 5 # testing + batch_size = num_devices * 5 train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] @@ -156,6 +155,7 @@ def replicate_array(x): for epoch in range(num_epochs): start_time = time.time() + num_batches = 5 for _ in range(num_batches): replicated_params = spmd_update(replicated_params, next(batches)) epoch_time = time.time() - start_time From 343090cac1d54c6fe0bc5eb2c62149ac80e7adcb Mon Sep 17 00:00:00 2001 From: sailesh duddupudi Date: Thu, 16 Jan 2025 18:39:01 +0000 Subject: [PATCH 16/16] Hardcode number of batches Signed-off-by: sailesh duddupudi --- .../spmd_mnist_classifier_fromscratch.py | 5 +++-- sdk/python/test/e2e/test_e2e_jaxjob.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py index 262b7ad07b..ca0e9f5165 100644 --- a/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py +++ b/examples/jax/jax-dist-spmd-mnist/spmd_mnist_classifier_fromscratch.py @@ -104,7 +104,9 @@ def accuracy(params, batch): train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) - num_batches = num_complete_batches + bool(leftover) + + # Increasing number of batches requires more resources. + num_batches = 10 def data_stream(): rng = npr.RandomState(0) @@ -155,7 +157,6 @@ def replicate_array(x): for epoch in range(num_epochs): start_time = time.time() - num_batches = 5 for _ in range(num_batches): replicated_params = spmd_update(replicated_params, next(batches)) epoch_time = time.time() - start_time diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index ed17442477..7471f67338 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -88,7 +88,7 @@ def test_sdk_e2e_with_gang_scheduling(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=9000) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) @@ -123,7 +123,7 @@ def test_sdk_e2e(job_namespace): logging.info(TRAINING_CLIENT.list_jobs(job_namespace)) try: - utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=9000) + utils.verify_job_e2e(TRAINING_CLIENT, JOB_NAME, job_namespace, wait_timeout=900) except Exception as e: utils.print_job_results(TRAINING_CLIENT, JOB_NAME, job_namespace) TRAINING_CLIENT.delete_job(JOB_NAME, job_namespace) @@ -156,5 +156,5 @@ def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, image=os.getenv("JAX_JOB_IMAGE", "docker.io/kubeflow/jaxjob-dist-spmd-mnist:latest"), - # resources=V1ResourceRequirements(limits={"memory": "4Gi", "cpu": "1.6"}), + resources=V1ResourceRequirements(limits={"memory": "3Gi", "cpu": "1.2"}), )