From 3edd6dbe4469344c90404a2d59ac0adbf549c8f8 Mon Sep 17 00:00:00 2001 From: mahdikhashan Date: Sat, 10 May 2025 15:21:17 +0200 Subject: [PATCH 01/38] proposal Signed-off-by: mahdikhashan fix Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> clean table of contents simplify JAX traning workflow improve design details update table of contents fix fix add Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 339 ++++++++++++++++++ .../drawing.drawio.svg | 4 + 2 files changed, 343 insertions(+) create mode 100644 docs/proposals/2442-jax-runtime-trainer-v2/README.md create mode 100644 docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md new file mode 100644 index 0000000000..903bf8b404 --- /dev/null +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -0,0 +1,339 @@ +# KEP-2442: JAX Runtime for Trainer V2 + +- [Summary](#summary) +- [Motivation](#motivation) + - [Goals](#goals) + - [Non-Goals](#non-goals) +- [Proposal](#proposal) + - [User Stories](#user-stories-optional) + - [Story 1](#story-1) + - [Story 2](#story-2) +- [Design Details](#design-details) + - [Key Concepts in JAX Distributed Training](#key-concepts-in-jax-distributed-training) + - [JAX Training Workflow](#jax-training-workflow-flow) + - [Defining JAX Processes with MLPolicy](#defining-jax-processes-with-mlpolicy) + - [Communication Backend](#communication-backend) + - [OpenMPI](#openmpi) + - [Gloo](#gloo) +- [Test Plan](#test-plan) + - [End-to-End (E2E) Tests](#end-to-end-e2e-tests) + - [Working Examples](#working-examples) + - [Unit and Integration Tests](#unit-and-integration-tests) +- [Implementation History](#implementation-history) + +## Summary + +This document outlines a proposal to support the JAX Runtime in Kubeflow Trainer V2. Built upon the Kubernetes JobSet API, the JAX Runtime enables training and fine-tuning workloads using the JAX framework on Kubernetes. Instead of relying on framework-specific CRDs, Trainer V2 introduces a unified abstraction through TrainingRuntime and TrainJob. The JAX Runtime implements this abstraction to serve as a reusable blueprint for model training tasks, including large language models (LLMs). Thanks to Kubeflow Trainer Pipeline Framework, we can seamlessly integrate the JAX runtime into Kubeflow Trainer V2 as a runtime plugin. + +## Motivation + +JAX is a powerful ML framework created by Google. It is widely used in the machine learning research and ranks as the third most widely used deep learning frameworks. JAX is not only a deep learning framework but suggests its potential in differential programming, large-scale physics simulations and many more. + +These usecases added on top of the new Runtime API for distributed training or calculation of objectives enables new users on top of Kubeflow Trainer, like distributed simulation or training of LLM prototypes developed with JAX, like vast models from Google DeepMind. 
+ +In general the motivation is to enable users to use Single-Program Multi-Data (SPMD) pattern with JAX Framework. + +With this design, Platform Engineers can define standardized training runtimes, while Data Scientists can easily customize them, through a simple SDK interface, without needing to understand Kubernetes internals. + +**Benefits** + +1. Leverage JAX for differential programming and large-scale simulations +2. Enable distributed training or objective computation using the new Runtime API +3. Support prototyping and training of large JAX-based LLMs within Kubeflow Trainer + +### Goals + +- Implement ClusterTrainingRuntime for JAX, supporting multi-controller JAX +- Build the necessary Docker images for JAX worker nodes used by the runtime +- Implement the solution to work on CPU and GPU +- Document user guides for utilizing JAX ClusterTrainingRuntimes +- Test the implementation thoroughly using unit tests and end-to-end (E2E) tests + +### Non-Goals + +- No TPU support (duo to lack of available TPU testing infrastructure) +- No GPU testing, tests will use CPUs + +## Proposal + +### User Stories + +#### Story 1 + +As a Platform Engineer, I want to manage JAX distributed training jobs using the Kubeflow Trainer V2, so then I can provide blueprints for training of machine learning models on a kubernetes cluster to engineering teams. + +#### Story 2 + +As a Data Scientist, I want to use the Trainer V2 SDK to run a distributed training job from notebook, in this way I can incorporate multiple devices for my training task. + +The Python SDK with JAXRuntime may look as follows: + +```python +from kubeflow.trainer import TrainerClient, CustomTrainer + +def jax_train_mnist(): + # TODO: Add training logic using JAX + pass + +# Select the JAX runtime +client = TrainerClient() +jax_runtime = next(r for r in client.list_runtimes() if r.name == "jax-distributed") + +# Launch training job +job_id = client.train( + trainer=CustomTrainer(func=jax_train_mnist, func_args=args, num_nodes=3), + runtime=jax_runtime, +) +``` + +## Design Details + +In order to address this functionality, we propose the following design: + +### Key Concepts in JAX Distributed Training + +To understand the **JAX runtime** in Kubeflow Trainer V2, it's important to clarify the terminology used in JAX for distributed training: + +#### Host + +In JAX, a **host** refers to a physical or virtual machine that participates in distributed training. Each host runs a **single JAX process**, which manages all the local devices (e.g., GPUs or CPUs). JAX automatically detects and utilizes all available devices on a host. + +>In Kubernetes, a host maps to a Node, and in the runtime implementation, one Pod is typically scheduled per host. + +**JAX Process** + +A **JAX process** (or sometimes called a **controller**) is a Python process running the JAX program. There is exactly one JAX process per host. This process is responsible for: + +- Executing the training loop + +- Managing all local devices + +- Communicating with other JAX processes over the network for synchronization + +JAX uses **Single Program Multiple Data (SPMD)** across these processes. + +**Devices** + +Devices refer to individual compute units on a host (like CPU cores, GPUs, or TPUs). JAX handles device detection automatically and runs parallel computations using `jax.pmap`, `jax.shard_map`, or `pjit`. + +Each JAX process has access to **all devices on its host**. There is **no need to spawn multiple processes per GPU**, unlike PyTorch. 
+ +**Controller** + +The **controller** is conceptually the same as the JAX process. In distributed setups, each host runs one controller that communicates with peers (other controllers) to coordinate the training. + +**Pod** + +A **Pod** in Kubernetes runs a JAX process. It is scheduled on a node and may use one or more GPUs depending on the resource specification. + +**Node** + +In Kubernetes, a **Node** is a worker machine in the cluster. When we run multi-node JAX jobs, each node typically runs one pod, which maps to one JAX host. + +### JAX Training Workflow + +This section explains the architecture and flow of executing a distributed JAX training job using Kubeflow, as depicted in the diagram. + +![user-roles](./drawing.drawio.svg) + +#### 1. Platform Engineer Prepares the Training Environment +- A **Platform Engineer** sets up the **Cluster Training Runtime** with details like: + - Container image + - Entrypoint + - Framework (e.g., JAX) + - Resource needs +- This setup can be reused by others to run training jobs. + +#### 2. Training Runtime is Retrieved +- When a user requests a training job, the system **fetches the runtime spec** to know how to run the job. + +#### 3. Data Scientist or ML Researcher Creates the Training Job +- A **Data Scientist** or **ML Researcher** creates a training job using: + - The **Kubeflow Python SDK**, or + - A `kubectl` command. +- They provide the training function (e.g., `jax_train_mnist`), any needed arguments, and settings like how many nodes to use. + +#### 4. JobSet is Created and Submitted +- The training job uses the runtime spec to create a **JobSet**, a group of jobs working together to train the model. + +#### 5. Distributed Jobs Start Running +- The **JobSet** launches multiple **Kubernetes Jobs**. +- Each job runs one instance of the **JAX training process** in its own pod. + +#### 6. Headless Service Connects the Jobs +- A **Headless Service** allows the pods to **communicate directly** for tasks like sharing gradients and coordinating training. + +#### 7. Training Runs Across the Cluster +- Each pod runs the training code using **JAX and Python**. +- The pods work together to complete the distributed training on the available hardware. + + +### Defining Distributed JAX with MLPolicy + +The number of **JAX hosts** is configured using the `numNodes` field in the **MLPolicy** section of the **ClusterTrainingRuntime**. Each host runs a single JAX process inside a Pod. + +### Communication Backend + +#### OpenMPI + +**Pros:** + +* Compatible with existing MPI runtime in Kubeflow Trainer v2, making deployment easier. + +**Cons:** + +* Typically requires more complex environment setup compared to simpler backends like Gloo. 
+ +**ClusterTrainingRuntime Design** + +```yaml +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: jax-distributed +spec: + mlPolicy: + numNodes: 1 + mpi: + numProcPerNode: 1 + mpiImplementation: OpenMPI + sshAuthMountPath: /home/mpiuser/.ssh + runLauncherAsNode: true + template: + spec: + network: + publishNotReadyAddresses: true + successPolicy: + operator: All + targetReplicatedJobs: + - launcher + replicatedJobs: + - name: launcher + template: + metadata: + labels: + trainer.kubeflow.org/trainjob-ancestor-step: trainer + spec: + template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/jax-runtime + securityContext: + runAsUser: 1000 + command: + - mpirun + - -n + - "1" + - bash + - -c + - | + echo "JAX Distributed Runtime" + + echo "--------------------------------------" + set -e + mpirun --version + python --version + pip list + - name: node + template: + spec: + template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/jax-runtime + securityContext: + runAsUser: 1000 + command: + - /usr/sbin/sshd + args: + - -De + - -f + - /home/mpiuser/.sshd_config + readinessProbe: + tcpSocket: + port: 2222 + initialDelaySeconds: 5 +``` + + +### Gloo + +**Pros:** + +* Lightweight and simple to use. + +**Cons:** + +* Significantly slower than OpenMPI (10–20×) for distributed JAX training on CPUs and GPUs. +* Less optimized for multi-node scaling and lacks native support for high-speed interconnects like InfiniBand. + +**ClusterTrainingRuntime Design** + +```yaml +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: jax-distributed +spec: + mlPolicy: + numNodes: 4 + template: + spec: + successPolicy: + operator: All + replicatedJobs: + - name: process + template: + spec: + template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/jax-runtime + securityContext: + runAsUser: 1000 + command: + - bash + - -c + - | + echo "JAX Distributed Runtime" + echo "--------------------------------------" + set -e + python --version + pip list | grep jax +``` + +## Test Plan + +The testing strategy will focus on validating functionality, usability, and integration of the proposed `TrainingRuntime` mechanism for distributed training workloads. It includes the following components: + +### End-to-End (E2E) Tests + +* **Environment**: Deploy workloads in lightweight local Kubernetes clusters using tools like `kind` or `minikube`. +* **Workloads**: Run simple distributed training examples such as MNIST **JAX**. +* **Validation Goals**: + + * Ensure correct creation of `JobSet` resources. + * Validate successful job execution and error handling paths. + * Confirm compatibility with `TrainingRuntime` configurations. + +### Working Examples + +* Provide clear, runnable examples: + + * **Kubeflow SDK and notebook examples** that demonstrate creating and running training jobs using the new interface. +* These examples will serve as both test cases and documentation to support user onboarding. + +### Unit and Integration Tests + +* For any controller or plugin logic introduced: + + * Write targeted **unit tests** in Go to validate business logic and failure scenarios. + * Use mocks/fakes where needed to simulate cluster conditions and resource state. +* Ensure **controller reconciliation logic** is tested thoroughly. + +## Implementation History + +- 2025-05-28: Initial KEP draft created. 
diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg new file mode 100644 index 0000000000..a21453b7ab --- /dev/null +++ b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg @@ -0,0 +1,4 @@ + + + +
[drawing.drawio.svg: architecture diagram text labels — Platform Engineer, Data Scientist / ML Research Scientist, Kubeflow Python SDK, kubectl, Create TrainJob, TrainJob, Cluster Training Runtime, Manage Runtime, Fetch Spec, JobSet, Headless Service, JAX Processes, Cloud GPUs.]
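The `jax_train_mnist` stub in the SDK example above can be fleshed out along these lines. This is only a minimal sketch of a per-process entrypoint under the multi-controller model described in the Key Concepts section; the environment variable names `COORDINATOR_ADDRESS`, `NUM_PROCESSES`, and `PROCESS_ID` are illustrative placeholders (not defined by this proposal), and a runtime could instead rely on JAX's automatic cluster detection.

```python
import os

import jax
import jax.numpy as jnp


def jax_train_mnist():
    # One JAX process per host: initialize the multi-controller runtime.
    # The environment variable names here are illustrative placeholders.
    jax.distributed.initialize(
        coordinator_address=os.environ["COORDINATOR_ADDRESS"],  # e.g. "<pod-0 hostname>:1234"
        num_processes=int(os.environ["NUM_PROCESSES"]),
        process_id=int(os.environ["PROCESS_ID"]),
    )

    # Each process manages only its local devices, but collectives span all hosts.
    print(
        f"process {jax.process_index()} of {jax.process_count()}, "
        f"local devices: {jax.local_device_count()}, global devices: {jax.device_count()}"
    )

    # Toy SPMD step: psum reduces across every device on every host.
    def step(x):
        return jax.lax.psum(x, axis_name="i")

    xs = jnp.ones(jax.local_device_count())
    result = jax.pmap(step, axis_name="i")(xs)
    print(f"global sum seen by process {jax.process_index()}: {result[0]}")


if __name__ == "__main__":
    jax_train_mnist()
```

When the job is launched with `mpirun` (as in the OpenMPI runtime above), recent JAX versions can usually detect the coordinator and process layout automatically, so `jax.distributed.initialize()` may be callable with no arguments.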
From e29d739311be34e0bd1ebb1411d761ddcdc48344 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:26:51 +0200 Subject: [PATCH 02/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 903bf8b404..8b942e68c0 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -118,9 +118,6 @@ Devices refer to individual compute units on a host (like CPU cores, GPUs, or TP Each JAX process has access to **all devices on its host**. There is **no need to spawn multiple processes per GPU**, unlike PyTorch. -**Controller** - -The **controller** is conceptually the same as the JAX process. In distributed setups, each host runs one controller that communicates with peers (other controllers) to coordinate the training. **Pod** From c352faf917f9f1f9f7d6e25519e8ae79d8fa446a Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:27:00 +0200 Subject: [PATCH 03/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 8b942e68c0..0291285cc7 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -100,7 +100,7 @@ In JAX, a **host** refers to a physical or virtual machine that participates in >In Kubernetes, a host maps to a Node, and in the runtime implementation, one Pod is typically scheduled per host. -**JAX Process** +**JAX Process/Controller** A **JAX process** (or sometimes called a **controller**) is a Python process running the JAX program. There is exactly one JAX process per host. This process is responsible for: From 8b9a2227150c20a2dabed9f9ad2ae158045c6712 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Thu, 10 Jul 2025 12:27:13 +0200 Subject: [PATCH 04/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 0291285cc7..f0a3efe2a9 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -102,13 +102,7 @@ In JAX, a **host** refers to a physical or virtual machine that participates in **JAX Process/Controller** -A **JAX process** (or sometimes called a **controller**) is a Python process running the JAX program. 
There is exactly one JAX process per host. This process is responsible for: - -- Executing the training loop - -- Managing all local devices - -- Communicating with other JAX processes over the network for synchronization +A **JAX process** (or sometimes called a **controller**) is a Python process running the JAX program. There is exactly one JAX process per host. This process is responsible for executing the training loop, managing all local devices, and communicating with other JAX processes over the network for synchronization JAX uses **Single Program Multiple Data (SPMD)** across these processes. From 054762d53ffb5b7e8c10ff4a1bf7a5d234c0411b Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Mon, 14 Jul 2025 14:21:54 +0200 Subject: [PATCH 05/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Shao Wang <2690692950@qq.com> Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index f0a3efe2a9..614c80ac2f 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -249,7 +249,7 @@ spec: ``` -### Gloo +#### Gloo **Pros:** From 5c0348e2ed2f4fc5ad4ba1177354f3555b99dffd Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:33:22 +0200 Subject: [PATCH 06/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Andrey Velichkevich Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 614c80ac2f..f75395251e 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -43,7 +43,7 @@ With this design, Platform Engineers can define standardized training runtimes, ### Goals -- Implement ClusterTrainingRuntime for JAX, supporting multi-controller JAX +- Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. 
multi-controller JAX) - Build the necessary Docker images for JAX worker nodes used by the runtime - Implement the solution to work on CPU and GPU - Document user guides for utilizing JAX ClusterTrainingRuntimes From dade8799da963d4e7af08649a902d4ccdca26224 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:41:14 +0200 Subject: [PATCH 07/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Andrey Velichkevich Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index f75395251e..09a95896ef 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -78,7 +78,9 @@ def jax_train_mnist(): # Select the JAX runtime client = TrainerClient() jax_runtime = next(r for r in client.list_runtimes() if r.name == "jax-distributed") - +args = { + "lr": "0.03" +} # Launch training job job_id = client.train( trainer=CustomTrainer(func=jax_train_mnist, func_args=args, num_nodes=3), From d4fbc1637a4f528a0c55244e023f2e8c8880017c Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:41:34 +0200 Subject: [PATCH 08/38] Update docs/proposals/2442-jax-runtime-trainer-v2/README.md Co-authored-by: Andrey Velichkevich Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 09a95896ef..345b9bbc18 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -71,7 +71,7 @@ The Python SDK with JAXRuntime may look as follows: ```python from kubeflow.trainer import TrainerClient, CustomTrainer -def jax_train_mnist(): +def jax_train_mnist(args): # TODO: Add training logic using JAX pass From 4b373885160bb162e00f1fee615a951188c79dab Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 11:30:00 +0200 Subject: [PATCH 09/38] update drawing Signed-off-by: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg index a21453b7ab..7207c6fa9d 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg +++ b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg @@ -1,4 +1,4 @@ -
[drawing.drawio.svg diff: diagram label text updated — "Platform Engineer" becomes "Platform Admins" and "Data Scientist / ML Research Scientist" becomes "AI Practitioners"; the remaining labels (Kubeflow Python SDK, kubectl, Create TrainJob, TrainJob, Cluster Training Runtime, Manage Runtime, Fetch Spec, JobSet, Headless Service, JAX Processes, Cloud GPUs) are unchanged.]
\ No newline at end of file From a51c86aa77be793cc7cb534782779e53bc795899 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:30:14 +0000 Subject: [PATCH 10/38] update non-goals Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 345b9bbc18..aa8f3ef0c9 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -51,8 +51,8 @@ With this design, Platform Engineers can define standardized training runtimes, ### Non-Goals -- No TPU support (duo to lack of available TPU testing infrastructure) -- No GPU testing, tests will use CPUs +- No TPU testing, tests will use CPU +- No GPU testing, tests will use CPU ## Proposal From 0d62ef66df605c1d4a2bd8ce7a7967dbe2f00d50 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:35:44 +0000 Subject: [PATCH 11/38] update goal Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index aa8f3ef0c9..262b764c81 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -46,6 +46,7 @@ With this design, Platform Engineers can define standardized training runtimes, - Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. multi-controller JAX) - Build the necessary Docker images for JAX worker nodes used by the runtime - Implement the solution to work on CPU and GPU +- Integrate with SDK and address any necessary enhancements - Document user guides for utilizing JAX ClusterTrainingRuntimes - Test the implementation thoroughly using unit tests and end-to-end (E2E) tests From d88df352b277fdbfe111d8d418c6258a235e20e2 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sun, 20 Jul 2025 10:42:12 +0000 Subject: [PATCH 12/38] update user persona names Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 262b764c81..13044cb07b 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -33,7 +33,7 @@ These usecases added on top of the new Runtime API for distributed training or c In general the motivation is to enable users to use Single-Program Multi-Data (SPMD) pattern with JAX Framework. -With this design, Platform Engineers can define standardized training runtimes, while Data Scientists can easily customize them, through a simple SDK interface, without needing to understand Kubernetes internals. +With this design, Platform Admins can define standardized training runtimes, while AI Practitioners can easily customize them, through a simple SDK interface, without needing to understand Kubernetes internals. 
**Benefits** @@ -61,11 +61,11 @@ With this design, Platform Engineers can define standardized training runtimes, #### Story 1 -As a Platform Engineer, I want to manage JAX distributed training jobs using the Kubeflow Trainer V2, so then I can provide blueprints for training of machine learning models on a kubernetes cluster to engineering teams. +As a Platform Admin, I want to manage JAX distributed training jobs using the Kubeflow Trainer V2, so then I can provide blueprints for training of machine learning models on a kubernetes cluster to engineering teams. #### Story 2 -As a Data Scientist, I want to use the Trainer V2 SDK to run a distributed training job from notebook, in this way I can incorporate multiple devices for my training task. +As an AI Practitioner, I want to use the Trainer V2 SDK to run a distributed training job from notebook, in this way I can incorporate multiple devices for my training task. The Python SDK with JAXRuntime may look as follows: @@ -130,8 +130,8 @@ This section explains the architecture and flow of executing a distributed JAX t ![user-roles](./drawing.drawio.svg) -#### 1. Platform Engineer Prepares the Training Environment -- A **Platform Engineer** sets up the **Cluster Training Runtime** with details like: +#### 1. Platform Admins Prepares the Training Environment +- A **Platform Admins** sets up the **Cluster Training Runtime** with details like: - Container image - Entrypoint - Framework (e.g., JAX) @@ -141,8 +141,8 @@ This section explains the architecture and flow of executing a distributed JAX t #### 2. Training Runtime is Retrieved - When a user requests a training job, the system **fetches the runtime spec** to know how to run the job. -#### 3. Data Scientist or ML Researcher Creates the Training Job -- A **Data Scientist** or **ML Researcher** creates a training job using: +#### 3. AI Practitioner Creates the Training Job +- A **AI Practitioners** creates a training job using: - The **Kubeflow Python SDK**, or - A `kubectl` command. - They provide the training function (e.g., `jax_train_mnist`), any needed arguments, and settings like how many nodes to use. From d3c5c6d64b1e2c3e182439e66fcad3f09241f464 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:28:24 +0000 Subject: [PATCH 13/38] improve summary Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 66 ++++++++++++++++++- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 13044cb07b..55404de108 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -15,6 +15,7 @@ - [Communication Backend](#communication-backend) - [OpenMPI](#openmpi) - [Gloo](#gloo) + - [NCCL](#nccl) - [Test Plan](#test-plan) - [End-to-End (E2E) Tests](#end-to-end-e2e-tests) - [Working Examples](#working-examples) @@ -23,7 +24,8 @@ ## Summary -This document outlines a proposal to support the JAX Runtime in Kubeflow Trainer V2. Built upon the Kubernetes JobSet API, the JAX Runtime enables training and fine-tuning workloads using the JAX framework on Kubernetes. Instead of relying on framework-specific CRDs, Trainer V2 introduces a unified abstraction through TrainingRuntime and TrainJob. The JAX Runtime implements this abstraction to serve as a reusable blueprint for model training tasks, including large language models (LLMs). 
Thanks to Kubeflow Trainer Pipeline Framework, we can seamlessly integrate the JAX runtime into Kubeflow Trainer V2 as a runtime plugin. +This document outlines a proposal to support the JAX Runtime in Kubeflow Trainer V2. Built upon the Kubernetes JobSet API, the JAX Runtime enables training and fine-tuning workloads using the JAX framework on Kubernetes. Instead of relying on framework-specific CRDs, Trainer V2 introduces a unified abstraction through TrainingRuntime and TrainJob. The JAX Runtime implements this abstraction to serve as a reusable blueprint for model training tasks, including large language models (LLMs). With the Kubeflow Trainer Pipeline Framework, we can easily integrate the JAX runtime into Kubeflow Trainer V2 as a runtime plugin. + ## Motivation @@ -173,6 +175,7 @@ The number of **JAX hosts** is configured using the `numNodes` field in the **ML **Pros:** * Compatible with existing MPI runtime in Kubeflow Trainer v2, making deployment easier. +* Leverage `mpi4jax` for HPC application **Cons:** @@ -257,6 +260,7 @@ spec: **Pros:** * Lightweight and simple to use. +* Compatible with the Trainer v1 **Cons:** @@ -273,10 +277,10 @@ metadata: spec: mlPolicy: numNodes: 4 + jax: + backend: gloo template: spec: - successPolicy: - operator: All replicatedJobs: - name: process template: @@ -299,6 +303,62 @@ spec: pip list | grep jax ``` +#### NCCL + +**Pros** +* Minimal setup; JAX/XLA usually auto-configures it. +* Optimized GPU collectives (all-reduce, etc.) that leverage NVLink/PCIe topology. +* Can use GPUDirect (incl. RDMA) for fast inter-node transfers when fabric supports it. + +**Cons** +* Performance drops on CPU-only or host-staged gradients; MPI often faster there. +* High latency for many small messages; can trail tuned OpenMPI runs (env-dependent, sometimes 10–20× in CPU-centric tests). +* Debug tooling limited; transport or fabric misconfig can silently degrade throughput. + + +**ClusterTrainingRuntime Design** + +```yaml +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: ClusterTrainingRuntime +metadata: + name: jax-distributed +spec: + mlPolicy: + numNodes: 4 + jax: + backend: nccl + envs: + - name: NCCL_DEBUG + value: "WARN" + - name: NCCL_IB_DISABLE + value: "1" + - name: NCCL_SOCKET_IFNAME + value: "eth0" + template: + spec: + replicatedJobs: + - name: process + template: + spec: + template: + spec: + containers: + - name: node + image: ghcr.io/kubeflow/trainer/jax-runtime + securityContext: + runAsUser: 1000 + command: + - bash + - -c + - | + echo "JAX Distributed Runtime with NCCL" + echo "--------------------------------------" + set -e + python --version + pip list | grep jax +``` + ## Test Plan The testing strategy will focus on validating functionality, usability, and integration of the proposed `TrainingRuntime` mechanism for distributed training workloads. 
It includes the following components: From b5ae74c8532010c3b84ee009ee2b3fd56aa21794 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:34:29 +0000 Subject: [PATCH 14/38] improve motivation Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 55404de108..d9a7065edd 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -29,13 +29,13 @@ This document outlines a proposal to support the JAX Runtime in Kubeflow Trainer ## Motivation -JAX is a powerful ML framework created by Google. It is widely used in the machine learning research and ranks as the third most widely used deep learning frameworks. JAX is not only a deep learning framework but suggests its potential in differential programming, large-scale physics simulations and many more. +JAX is a high-performance numerical computing framework created by Google. It is widely used in the machine learning research and ranks as the third most widely used deep learning frameworks. JAX also suggests its potential in differential programming, large-scale physics simulations and many more. These usecases added on top of the new Runtime API for distributed training or calculation of objectives enables new users on top of Kubeflow Trainer, like distributed simulation or training of LLM prototypes developed with JAX, like vast models from Google DeepMind. -In general the motivation is to enable users to use Single-Program Multi-Data (SPMD) pattern with JAX Framework. +In general the motivation is to enable users to use Single-Program Multi-Data (SPMD) pattern with JAX Framework, however there are other reasons like ensure backward compatibility with Trainer V1, which previously included JAX support, allowing existing users to transition smoothly while taking advantage of the enhanced Runtime API. -With this design, Platform Admins can define standardized training runtimes, while AI Practitioners can easily customize them, through a simple SDK interface, without needing to understand Kubernetes internals. +Finally with this design, Platform Admins can define standardized training runtimes, while AI Practitioners can easily customize them, through a simple SDK interface, without needing to understand Kubernetes internals. **Benefits** @@ -43,6 +43,8 @@ With this design, Platform Admins can define standardized training runtimes, whi 2. Enable distributed training or objective computation using the new Runtime API 3. Support prototyping and training of large JAX-based LLMs within Kubeflow Trainer +--- + ### Goals - Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. 
multi-controller JAX) From 71d6a49140af3d5ad844fab806305215503cb102 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:40:34 +0000 Subject: [PATCH 15/38] improve sample code in story 2 Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index d9a7065edd..5f9250d75f 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -76,16 +76,20 @@ The Python SDK with JAXRuntime may look as follows: ```python from kubeflow.trainer import TrainerClient, CustomTrainer +# Add logic using JAX methods def jax_train_mnist(args): - # TODO: Add training logic using JAX - pass + raise NotImplementedError # Select the JAX runtime client = TrainerClient() jax_runtime = next(r for r in client.list_runtimes() if r.name == "jax-distributed") + +# Custom parameters passed as arguments args = { - "lr": "0.03" + "parameter_1": "20" + "parameter_2": "MSE" } + # Launch training job job_id = client.train( trainer=CustomTrainer(func=jax_train_mnist, func_args=args, num_nodes=3), From 3ae485395881f9f7344b98e99286fed89d6224e2 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Wed, 23 Jul 2025 10:41:27 +0000 Subject: [PATCH 16/38] improve code Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 5f9250d75f..ea139abc05 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -86,8 +86,8 @@ jax_runtime = next(r for r in client.list_runtimes() if r.name == "jax-distribut # Custom parameters passed as arguments args = { - "parameter_1": "20" - "parameter_2": "MSE" + "epoch": "20" + "loss": "MSE" } # Launch training job From d1159e069725fac87ccd2ca28dc94ea010444008 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 6 Sep 2025 09:35:18 +0000 Subject: [PATCH 17/38] add table Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 32 ++++--------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index ea139abc05..0618e625cd 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -105,32 +105,14 @@ In order to address this functionality, we propose the following design: To understand the **JAX runtime** in Kubeflow Trainer V2, it's important to clarify the terminology used in JAX for distributed training: -#### Host +| Concept | Description | +|---------|-------------| +| **Host** | A physical or virtual machine participating in distributed training. Each host runs a **single JAX process**, which manages all local devices (e.g., GPUs, CPUs). JAX auto-detects and utilizes all available devices. (In Kubernetes, a host maps to a **Node**, and typically one **Pod** is scheduled per host.) 
| +| **JAX Process / Controller** | A Python process running the JAX program (exactly one per host). Responsible for executing the training loop, managing all local devices, and synchronizing with other JAX processes over the network. Uses **SPMD** across processes. | +| **Devices** | Compute units on a host (CPU cores, GPUs, TPUs). JAX detects devices automatically and runs parallel computations via `jax.pmap`, `jax.shard_map`, or `pjit`. Each JAX process accesses all devices on its host. **No need to spawn multiple processes per GPU** (unlike PyTorch). | +| **Pod** | A Kubernetes Pod runs a single JAX process. Scheduled on a node and may use one or more GPUs depending on resource specifications. | +| **Node** | A Kubernetes Node is a worker machine. In multi-node JAX jobs, each node typically runs one pod, mapping to one JAX host. | -In JAX, a **host** refers to a physical or virtual machine that participates in distributed training. Each host runs a **single JAX process**, which manages all the local devices (e.g., GPUs or CPUs). JAX automatically detects and utilizes all available devices on a host. - ->In Kubernetes, a host maps to a Node, and in the runtime implementation, one Pod is typically scheduled per host. - -**JAX Process/Controller** - -A **JAX process** (or sometimes called a **controller**) is a Python process running the JAX program. There is exactly one JAX process per host. This process is responsible for executing the training loop, managing all local devices, and communicating with other JAX processes over the network for synchronization - -JAX uses **Single Program Multiple Data (SPMD)** across these processes. - -**Devices** - -Devices refer to individual compute units on a host (like CPU cores, GPUs, or TPUs). JAX handles device detection automatically and runs parallel computations using `jax.pmap`, `jax.shard_map`, or `pjit`. - -Each JAX process has access to **all devices on its host**. There is **no need to spawn multiple processes per GPU**, unlike PyTorch. - - -**Pod** - -A **Pod** in Kubernetes runs a JAX process. It is scheduled on a node and may use one or more GPUs depending on the resource specification. - -**Node** - -In Kubernetes, a **Node** is a worker machine in the cluster. When we run multi-node JAX jobs, each node typically runs one pod, which maps to one JAX host. ### JAX Training Workflow From bc7e4bab3a82c5de43fb75c6e8cd4aeabcfc1d37 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 6 Sep 2025 12:34:43 +0000 Subject: [PATCH 18/38] wip: jax-ml-policy Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 0618e625cd..866ef94f34 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -156,6 +156,82 @@ This section explains the architecture and flow of executing a distributed JAX t The number of **JAX hosts** is configured using the `numNodes` field in the **MLPolicy** section of the **ClusterTrainingRuntime**. Each host runs a single JAX process inside a Pod. +#### JAXMLPolicySource + + JAXMLPolicySource allows detailed configuration of JAX distributed initialization, backend, devices, and precision. + +```golang +type MLPolicySource struct { + [...] 
+ + JAX *JAXMLPolicySource `json:"jax,omitempty"` +} + +type JAXMLPolicySource struct { + + // Backend for JAX distributed communication. + // +kubebuilder:default="nccl" + // +kubebuilder:validation:Enum=nccl;gloo;mpi + TargetBackend *string `json:"targetBackend,omitempty"` + + // Platforms is comma-separated list of platform names + // specifying which platforms jax should initialize + // +kubebuilder:default="gpu,tpu,cpu" + Platforms *string `json:"platform,omitempty"` + + // Whether to disable JAX compilation optimizations. + // +kubebuilder:default=false + DisableJIT *bool `json:"disableJIT,omitempty"` + + // Check for and raise errors on NaNs + // +kubebuilder:default=false + DebugNaNs *bool `json:"debugNaNs,omitempty"` + + // Set default precision for matrix multiplication + // +kubebuilder:validation:Enum=default;high;highest;bfloat16;tensorfloat32;float32; + // ANY_F8_ANY_F8_F32;ANY_F8_ANY_F8_F32_FAST_ACCUM;ANY_F8_ANY_F8_ANY; + // ANY_F8_ANY_F8_ANY_FAST_ACCUM;F16_F16_F16;F16_F16_F32;BF16_BF16_BF16; + // BF16_BF16_F32;BF16_BF16_F32_X3;BF16_BF16_F32_X6;BF16_BF16_F32_X9; + // TF32_TF32_F32;TF32_TF32_F32_X3;F32_F32_F32;F64_F64_F64 + DefaultMatMulPrecision *string `json:"defaultMatmulPrecision,omitempty"` + + // Additional specific configurations. + // +listType=map + // +listMapKey=name + ExtraEnv []corev1.EnvVar `json:"extraEnv,omitempty"` + + // Distributed contains explicit args used when calling jax.distributed.initialize(). + // This should be provided when not relying on automatic cluster detection (Slurm, MPI launcher, Cloud TPU, etc). + Distributed *DistributedConfig `json:"distributed,omitempty"` +} + +type DistributedConfig struct { + // CoordinatorAddress is the address (host:port) of the coordinator process. + // +kubebuilder:validation:Required + CoordinatorAddress string `json:"coordinatorAddress"` + + // NumProcesses is the total number of processes across all hosts. + // +kubebuilder:validation:Minimum=1 + // +kubebuilder:validation:Required + NumProcesses int `json:"numProcesses"` + + // ProcessID is the unique integer id for this process (0..NumProcesses-1). + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Required + ProcessID int `json:"processID"` + + // LocalDeviceIDs lists local device indexes assigned to this process (e.g. [0,1]). + // +kubebuilder:validation:MinItems=1 + // +kubebuilder:validation:Required + LocalDeviceIDs []int `json:"localDeviceIDs"` + + // ClusterDetectionMethod (optional) — a hint for automatic detection strategies + // (e.g., "slurm", "openmpi", "tpu", "env", "none"). If unset and not using an + // automatic launcher, Distributed must be provided with the required fields above. 
+ ClusterDetectionMethod *string `json:"clusterDetectionMethod,omitempty"` +} +``` + ### Communication Backend #### OpenMPI From 6cc6e4b87be11bfa54fd5ca1c0189bd64e07c8e1 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:17:52 +0000 Subject: [PATCH 19/38] use a table for workflow description Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 866ef94f34..5909d979b3 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -120,36 +120,15 @@ This section explains the architecture and flow of executing a distributed JAX t ![user-roles](./drawing.drawio.svg) -#### 1. Platform Admins Prepares the Training Environment -- A **Platform Admins** sets up the **Cluster Training Runtime** with details like: - - Container image - - Entrypoint - - Framework (e.g., JAX) - - Resource needs -- This setup can be reused by others to run training jobs. - -#### 2. Training Runtime is Retrieved -- When a user requests a training job, the system **fetches the runtime spec** to know how to run the job. - -#### 3. AI Practitioner Creates the Training Job -- A **AI Practitioners** creates a training job using: - - The **Kubeflow Python SDK**, or - - A `kubectl` command. -- They provide the training function (e.g., `jax_train_mnist`), any needed arguments, and settings like how many nodes to use. - -#### 4. JobSet is Created and Submitted -- The training job uses the runtime spec to create a **JobSet**, a group of jobs working together to train the model. - -#### 5. Distributed Jobs Start Running -- The **JobSet** launches multiple **Kubernetes Jobs**. -- Each job runs one instance of the **JAX training process** in its own pod. - -#### 6. Headless Service Connects the Jobs -- A **Headless Service** allows the pods to **communicate directly** for tasks like sharing gradients and coordinating training. - -#### 7. Training Runs Across the Cluster -- Each pod runs the training code using **JAX and Python**. -- The pods work together to complete the distributed training on the available hardware. +| **Actor / Component** | **Action** | **Details** | +| ----------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | +| Platform Admin | Prepares the **Cluster Training Runtime** | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. | +| System | Retrieves the **Training Runtime Spec** | Fetched automatically when a user requests a training job to determine execution details. | +| AI Practitioner | Creates the **Training Job** | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | +| Runtime + Controller | Creates and Submits a **JobSet** | Training job spec is translated into a JobSet (group of coordinated jobs). | +| JobSet Controller / K8s | Launches **Distributed Jobs** | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | +| Headless Service | Connects Pods for Communication | Enables direct pod-to-pod communication for gradient sharing and coordination. 
| +| Cluster (Pods + JAX) | Executes **Distributed Training** | Each pod runs JAX+Python code, collaborating to complete training across available hardware. | ### Defining Distributed JAX with MLPolicy From cfddcedefb439e6783eeb4b67a485733b7a16e25 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:27:40 +0000 Subject: [PATCH 20/38] remove communication backend section and add future work Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 267 +----------------- 1 file changed, 8 insertions(+), 259 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 5909d979b3..c4f34f2f91 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -12,10 +12,6 @@ - [Key Concepts in JAX Distributed Training](#key-concepts-in-jax-distributed-training) - [JAX Training Workflow](#jax-training-workflow-flow) - [Defining JAX Processes with MLPolicy](#defining-jax-processes-with-mlpolicy) - - [Communication Backend](#communication-backend) - - [OpenMPI](#openmpi) - - [Gloo](#gloo) - - [NCCL](#nccl) - [Test Plan](#test-plan) - [End-to-End (E2E) Tests](#end-to-end-e2e-tests) - [Working Examples](#working-examples) @@ -137,7 +133,7 @@ The number of **JAX hosts** is configured using the `numNodes` field in the **ML #### JAXMLPolicySource - JAXMLPolicySource allows detailed configuration of JAX distributed initialization, backend, devices, and precision. + `JAXMLPolicySource` allows detailed configuration of JAX distributed initialization, backend, devices, and precision. ```golang type MLPolicySource struct { @@ -146,260 +142,7 @@ type MLPolicySource struct { JAX *JAXMLPolicySource `json:"jax,omitempty"` } -type JAXMLPolicySource struct { - - // Backend for JAX distributed communication. - // +kubebuilder:default="nccl" - // +kubebuilder:validation:Enum=nccl;gloo;mpi - TargetBackend *string `json:"targetBackend,omitempty"` - - // Platforms is comma-separated list of platform names - // specifying which platforms jax should initialize - // +kubebuilder:default="gpu,tpu,cpu" - Platforms *string `json:"platform,omitempty"` - - // Whether to disable JAX compilation optimizations. - // +kubebuilder:default=false - DisableJIT *bool `json:"disableJIT,omitempty"` - - // Check for and raise errors on NaNs - // +kubebuilder:default=false - DebugNaNs *bool `json:"debugNaNs,omitempty"` - - // Set default precision for matrix multiplication - // +kubebuilder:validation:Enum=default;high;highest;bfloat16;tensorfloat32;float32; - // ANY_F8_ANY_F8_F32;ANY_F8_ANY_F8_F32_FAST_ACCUM;ANY_F8_ANY_F8_ANY; - // ANY_F8_ANY_F8_ANY_FAST_ACCUM;F16_F16_F16;F16_F16_F32;BF16_BF16_BF16; - // BF16_BF16_F32;BF16_BF16_F32_X3;BF16_BF16_F32_X6;BF16_BF16_F32_X9; - // TF32_TF32_F32;TF32_TF32_F32_X3;F32_F32_F32;F64_F64_F64 - DefaultMatMulPrecision *string `json:"defaultMatmulPrecision,omitempty"` - - // Additional specific configurations. - // +listType=map - // +listMapKey=name - ExtraEnv []corev1.EnvVar `json:"extraEnv,omitempty"` - - // Distributed contains explicit args used when calling jax.distributed.initialize(). - // This should be provided when not relying on automatic cluster detection (Slurm, MPI launcher, Cloud TPU, etc). - Distributed *DistributedConfig `json:"distributed,omitempty"` -} - -type DistributedConfig struct { - // CoordinatorAddress is the address (host:port) of the coordinator process. 
- // +kubebuilder:validation:Required - CoordinatorAddress string `json:"coordinatorAddress"` - - // NumProcesses is the total number of processes across all hosts. - // +kubebuilder:validation:Minimum=1 - // +kubebuilder:validation:Required - NumProcesses int `json:"numProcesses"` - - // ProcessID is the unique integer id for this process (0..NumProcesses-1). - // +kubebuilder:validation:Minimum=0 - // +kubebuilder:validation:Required - ProcessID int `json:"processID"` - - // LocalDeviceIDs lists local device indexes assigned to this process (e.g. [0,1]). - // +kubebuilder:validation:MinItems=1 - // +kubebuilder:validation:Required - LocalDeviceIDs []int `json:"localDeviceIDs"` - - // ClusterDetectionMethod (optional) — a hint for automatic detection strategies - // (e.g., "slurm", "openmpi", "tpu", "env", "none"). If unset and not using an - // automatic launcher, Distributed must be provided with the required fields above. - ClusterDetectionMethod *string `json:"clusterDetectionMethod,omitempty"` -} -``` - -### Communication Backend - -#### OpenMPI - -**Pros:** - -* Compatible with existing MPI runtime in Kubeflow Trainer v2, making deployment easier. -* Leverage `mpi4jax` for HPC application - -**Cons:** - -* Typically requires more complex environment setup compared to simpler backends like Gloo. - -**ClusterTrainingRuntime Design** - -```yaml -apiVersion: trainer.kubeflow.org/v1alpha1 -kind: ClusterTrainingRuntime -metadata: - name: jax-distributed -spec: - mlPolicy: - numNodes: 1 - mpi: - numProcPerNode: 1 - mpiImplementation: OpenMPI - sshAuthMountPath: /home/mpiuser/.ssh - runLauncherAsNode: true - template: - spec: - network: - publishNotReadyAddresses: true - successPolicy: - operator: All - targetReplicatedJobs: - - launcher - replicatedJobs: - - name: launcher - template: - metadata: - labels: - trainer.kubeflow.org/trainjob-ancestor-step: trainer - spec: - template: - spec: - containers: - - name: node - image: ghcr.io/kubeflow/trainer/jax-runtime - securityContext: - runAsUser: 1000 - command: - - mpirun - - -n - - "1" - - bash - - -c - - | - echo "JAX Distributed Runtime" - - echo "--------------------------------------" - set -e - mpirun --version - python --version - pip list - - name: node - template: - spec: - template: - spec: - containers: - - name: node - image: ghcr.io/kubeflow/trainer/jax-runtime - securityContext: - runAsUser: 1000 - command: - - /usr/sbin/sshd - args: - - -De - - -f - - /home/mpiuser/.sshd_config - readinessProbe: - tcpSocket: - port: 2222 - initialDelaySeconds: 5 -``` - - -#### Gloo - -**Pros:** - -* Lightweight and simple to use. -* Compatible with the Trainer v1 - -**Cons:** - -* Significantly slower than OpenMPI (10–20×) for distributed JAX training on CPUs and GPUs. -* Less optimized for multi-node scaling and lacks native support for high-speed interconnects like InfiniBand. - -**ClusterTrainingRuntime Design** - -```yaml -apiVersion: trainer.kubeflow.org/v1alpha1 -kind: ClusterTrainingRuntime -metadata: - name: jax-distributed -spec: - mlPolicy: - numNodes: 4 - jax: - backend: gloo - template: - spec: - replicatedJobs: - - name: process - template: - spec: - template: - spec: - containers: - - name: node - image: ghcr.io/kubeflow/trainer/jax-runtime - securityContext: - runAsUser: 1000 - command: - - bash - - -c - - | - echo "JAX Distributed Runtime" - echo "--------------------------------------" - set -e - python --version - pip list | grep jax -``` - -#### NCCL - -**Pros** -* Minimal setup; JAX/XLA usually auto-configures it. 
-* Optimized GPU collectives (all-reduce, etc.) that leverage NVLink/PCIe topology. -* Can use GPUDirect (incl. RDMA) for fast inter-node transfers when fabric supports it. - -**Cons** -* Performance drops on CPU-only or host-staged gradients; MPI often faster there. -* High latency for many small messages; can trail tuned OpenMPI runs (env-dependent, sometimes 10–20× in CPU-centric tests). -* Debug tooling limited; transport or fabric misconfig can silently degrade throughput. - - -**ClusterTrainingRuntime Design** - -```yaml -apiVersion: trainer.kubeflow.org/v1alpha1 -kind: ClusterTrainingRuntime -metadata: - name: jax-distributed -spec: - mlPolicy: - numNodes: 4 - jax: - backend: nccl - envs: - - name: NCCL_DEBUG - value: "WARN" - - name: NCCL_IB_DISABLE - value: "1" - - name: NCCL_SOCKET_IFNAME - value: "eth0" - template: - spec: - replicatedJobs: - - name: process - template: - spec: - template: - spec: - containers: - - name: node - image: ghcr.io/kubeflow/trainer/jax-runtime - securityContext: - runAsUser: 1000 - command: - - bash - - -c - - | - echo "JAX Distributed Runtime with NCCL" - echo "--------------------------------------" - set -e - python --version - pip list | grep jax +type JAXMLPolicySource struct {} ``` ## Test Plan @@ -431,6 +174,12 @@ The testing strategy will focus on validating functionality, usability, and inte * Use mocks/fakes where needed to simulate cluster conditions and resource state. * Ensure **controller reconciliation logic** is tested thoroughly. +--- + +## Future Work + +While it is possible to **configure a specific communication backend** (e.g., NCCL, MPI, Gloo) for the runtime, we have **deferred this decision** to simplify the current implementation. By default, **JAX uses Gloo** as the communication backend. The design ensures the system remains **extensible**, allowing backend selection and integration to be added in response to future feedback without major changes. + ## Implementation History - 2025-05-28: Initial KEP draft created. From c1ed5c8d3cbc6990098b06bfaeaba12993e1ba05 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:29:00 +0000 Subject: [PATCH 21/38] remove extra lines Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index c4f34f2f91..9c1ef2f539 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -39,8 +39,6 @@ Finally with this design, Platform Admins can define standardized training runti 2. Enable distributed training or objective computation using the new Runtime API 3. Support prototyping and training of large JAX-based LLMs within Kubeflow Trainer ---- - ### Goals - Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. multi-controller JAX) @@ -174,8 +172,6 @@ The testing strategy will focus on validating functionality, usability, and inte * Use mocks/fakes where needed to simulate cluster conditions and resource state. * Ensure **controller reconciliation logic** is tested thoroughly. ---- - ## Future Work While it is possible to **configure a specific communication backend** (e.g., NCCL, MPI, Gloo) for the runtime, we have **deferred this decision** to simplify the current implementation. By default, **JAX uses Gloo** as the communication backend. 
The design ensures the system remains **extensible**, allowing backend selection and integration to be added in response to future feedback without major changes. From 35f5995693e9424aadb8ddd49b10adfa5a186afa Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:32:07 +0000 Subject: [PATCH 22/38] separate jaxmlpolicy source block Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 9c1ef2f539..867633b361 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -139,7 +139,9 @@ type MLPolicySource struct { JAX *JAXMLPolicySource `json:"jax,omitempty"` } +``` +```golang type JAXMLPolicySource struct {} ``` From 396e6d0dfb94acd5527c6736167a268b14151296 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:32:59 +0000 Subject: [PATCH 23/38] update the table of contents with future work Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 867633b361..99ca736547 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -16,6 +16,7 @@ - [End-to-End (E2E) Tests](#end-to-end-e2e-tests) - [Working Examples](#working-examples) - [Unit and Integration Tests](#unit-and-integration-tests) +- [Future Work](#future-work) - [Implementation History](#implementation-history) ## Summary From 92ff200205eb5ec13733b4ca34df285d06cd59c7 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:35:25 +0000 Subject: [PATCH 24/38] simplify the testing plan section Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 99ca736547..c14021d8ba 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -148,32 +148,12 @@ type JAXMLPolicySource struct {} ## Test Plan -The testing strategy will focus on validating functionality, usability, and integration of the proposed `TrainingRuntime` mechanism for distributed training workloads. It includes the following components: +The testing strategy focuses on validating functionality and integration of the `TrainingRuntime` mechanism. -### End-to-End (E2E) Tests - -* **Environment**: Deploy workloads in lightweight local Kubernetes clusters using tools like `kind` or `minikube`. -* **Workloads**: Run simple distributed training examples such as MNIST **JAX**. -* **Validation Goals**: - - * Ensure correct creation of `JobSet` resources. - * Validate successful job execution and error handling paths. - * Confirm compatibility with `TrainingRuntime` configurations. - -### Working Examples - -* Provide clear, runnable examples: - - * **Kubeflow SDK and notebook examples** that demonstrate creating and running training jobs using the new interface. 
-* These examples will serve as both test cases and documentation to support user onboarding. - -### Unit and Integration Tests - -* For any controller or plugin logic introduced: - - * Write targeted **unit tests** in Go to validate business logic and failure scenarios. - * Use mocks/fakes where needed to simulate cluster conditions and resource state. -* Ensure **controller reconciliation logic** is tested thoroughly. +* **Environment**: Run workloads in a lightweight Kubernetes cluster in **CI actions** (e.g., using `kind` or `minikube`). +* **Workloads**: Execute simple distributed training examples such as MNIST **JAX**. +* **Validation Goals**: Ensure correct creation of `JobSet` resources, successful job execution, and compatibility with `TrainingRuntime` configurations. +* **Working Examples**: Provide runnable **notebook examples** demonstrating how to create and run training jobs. These notebooks double as test cases and user documentation. ## Future Work From 178fc66dc1a08a33df9b0aef03776e8cb7db0a3e Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:38:07 +0000 Subject: [PATCH 25/38] update the history section Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index c14021d8ba..6ccdee7072 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -162,3 +162,4 @@ While it is possible to **configure a specific communication backend** (e.g., NC ## Implementation History - 2025-05-28: Initial KEP draft created. +- 2025-09-19: Update design detail section and add future work \ No newline at end of file From 97fdf2c8a34cb135d2b5a602748987099e76bc85 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:44:59 +0000 Subject: [PATCH 26/38] improve design detail section Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 6ccdee7072..bb18b62d8b 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -132,7 +132,7 @@ The number of **JAX hosts** is configured using the `numNodes` field in the **ML #### JAXMLPolicySource - `JAXMLPolicySource` allows detailed configuration of JAX distributed initialization, backend, devices, and precision. +`JAXMLPolicySource` configures JAX distributed initialization, backend, devices, and precision. ```golang type MLPolicySource struct { @@ -146,6 +146,16 @@ type MLPolicySource struct { type JAXMLPolicySource struct {} ``` +This implementation supports `NCCL`, `libtpu`, and `Gloo` (default) backends. The plugin enables JAX distributed training but does not accept user parameters. 
To run distributed training with the default Gloo backend, the plugin automatically sets required environment variables across pods, including: + +| Env Variable | Purpose | +| --------------------- | ---------------------------------- | +| `COORDINATOR_ADDRESS` | Address of the coordinator process | +| `NUM_PROCESSES` | Total number of JAX processes | +| `PROCESS_ID` | Unique ID of each process | + +The plugin handles distributed initialization internally, allowing users to launch JAX training jobs without manual backend or process configuration while keeping the implementation extendable for future backend options. + ## Test Plan The testing strategy focuses on validating functionality and integration of the `TrainingRuntime` mechanism. From 096468a4e49f8d7a9b9a8a5903d17ea2f85acb00 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:47:02 +0000 Subject: [PATCH 27/38] update future work Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index bb18b62d8b..34717e18c4 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -167,7 +167,7 @@ The testing strategy focuses on validating functionality and integration of the ## Future Work -While it is possible to **configure a specific communication backend** (e.g., NCCL, MPI, Gloo) for the runtime, we have **deferred this decision** to simplify the current implementation. By default, **JAX uses Gloo** as the communication backend. The design ensures the system remains **extensible**, allowing backend selection and integration to be added in response to future feedback without major changes. +While it is possible to configure a specific communication backend (e.g., NCCL, MPI, Gloo) for the runtime by a parameter in `JAXMLPolicy`, we have deferred this decision to simplify the current implementation. By default, JAX uses `Gloo` as the communication backend. The design ensures the system remains extensible, allowing backend selection and integration to be added in response to future feedback without major changes. ## Implementation History From f002edc3fd994b9f725ec758195079849ac0e43f Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:47:42 +0000 Subject: [PATCH 28/38] update test plan Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 34717e18c4..7e6840d4bf 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -164,6 +164,7 @@ The testing strategy focuses on validating functionality and integration of the * **Workloads**: Execute simple distributed training examples such as MNIST **JAX**. * **Validation Goals**: Ensure correct creation of `JobSet` resources, successful job execution, and compatibility with `TrainingRuntime` configurations. * **Working Examples**: Provide runnable **notebook examples** demonstrating how to create and run training jobs. These notebooks double as test cases and user documentation. 
+* **Unit Tests**: Add unit tests for `JAXMLPolicySource` to validate correct backend selection, environment variable setup, and distributed initialization logic. ## Future Work From ae58d22d9f7d91c667a2888b1d33161c1a4bf14e Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 19:49:42 +0000 Subject: [PATCH 29/38] update workflow description Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 7e6840d4bf..fdfbf31dc1 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -120,10 +120,10 @@ This section explains the architecture and flow of executing a distributed JAX t | Platform Admin | Prepares the **Cluster Training Runtime** | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. | | System | Retrieves the **Training Runtime Spec** | Fetched automatically when a user requests a training job to determine execution details. | | AI Practitioner | Creates the **Training Job** | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | -| Runtime + Controller | Creates and Submits a **JobSet** | Training job spec is translated into a JobSet (group of coordinated jobs). | -| JobSet Controller / K8s | Launches **Distributed Jobs** | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | +| Runtime | Creates and Submits a **JobSet** | Training job spec is translated into a JobSet (group of coordinated jobs). | +| JobSet Controller | Launches **Distributed Jobs** | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | | Headless Service | Connects Pods for Communication | Enables direct pod-to-pod communication for gradient sharing and coordination. | -| Cluster (Pods + JAX) | Executes **Distributed Training** | Each pod runs JAX+Python code, collaborating to complete training across available hardware. | +| Cluster (Pods) | Executes **Distributed Training** | Each pod runs JAX+Python code, collaborating to complete training across available hardware. 
| ### Defining Distributed JAX with MLPolicy From 55eeba033edcb18bee67fd14eb6833a3715ef5f7 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Fri, 19 Sep 2025 20:02:35 +0000 Subject: [PATCH 30/38] improve design detail section Signed-off-by: Mahdi Khashan --- .../2442-jax-runtime-trainer-v2/README.md | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index fdfbf31dc1..df49f13e46 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -117,13 +117,13 @@ This section explains the architecture and flow of executing a distributed JAX t | **Actor / Component** | **Action** | **Details** | | ----------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | -| Platform Admin | Prepares the **Cluster Training Runtime** | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. | -| System | Retrieves the **Training Runtime Spec** | Fetched automatically when a user requests a training job to determine execution details. | -| AI Practitioner | Creates the **Training Job** | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | -| Runtime | Creates and Submits a **JobSet** | Training job spec is translated into a JobSet (group of coordinated jobs). | -| JobSet Controller | Launches **Distributed Jobs** | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | +| Platform Admin | Prepares the Cluster Training Runtime | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. | +| System | Retrieves the Training Runtime Spec | Fetched automatically when a user requests a training job to determine execution details. | +| AI Practitioner | Creates the Training Job | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | +| Runtime | Creates and Submits a JobSet | Training job spec is translated into a JobSet (group of coordinated jobs). | +| JobSet Controller | Launches Distributed Jobs | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | | Headless Service | Connects Pods for Communication | Enables direct pod-to-pod communication for gradient sharing and coordination. | -| Cluster (Pods) | Executes **Distributed Training** | Each pod runs JAX+Python code, collaborating to complete training across available hardware. | +| Cluster (Pods) | Executes Distributed Training | Each pod runs JAX+Python code, collaborating to complete training across available hardware. | ### Defining Distributed JAX with MLPolicy @@ -146,15 +146,15 @@ type MLPolicySource struct { type JAXMLPolicySource struct {} ``` -This implementation supports `NCCL`, `libtpu`, and `Gloo` (default) backends. The plugin enables JAX distributed training but does not accept user parameters. 
To run distributed training with the default Gloo backend, the plugin automatically sets required environment variables across pods, including: +#### JAX Distributed System -| Env Variable | Purpose | -| --------------------- | ---------------------------------- | -| `COORDINATOR_ADDRESS` | Address of the coordinator process | -| `NUM_PROCESSES` | Total number of JAX processes | -| `PROCESS_ID` | Unique ID of each process | +The plugin enables JAX distributed training and handles distributed initialization internally, allowing seamless execution of training jobs with multiple backend configurations for multi-GPU and Cloud TPU. -The plugin handles distributed initialization internally, allowing users to launch JAX training jobs without manual backend or process configuration while keeping the implementation extendable for future backend options. +| Backend | Parameters | Notes | +| ------- | ---------------------------------- | ----------------------------------------------------------------------------------------------------------------- | +| NCCL | None | No additional configuration needed. | +| LibTPU | None | No additional configuration needed. | +| Gloo | None | Environment variables (`COORDINATOR_ADDRESS`, `NUM_PROCESSES`, `PROCESS_ID`) are automatically set by the policy. | ## Test Plan @@ -163,7 +163,7 @@ The testing strategy focuses on validating functionality and integration of the * **Environment**: Run workloads in a lightweight Kubernetes cluster in **CI actions** (e.g., using `kind` or `minikube`). * **Workloads**: Execute simple distributed training examples such as MNIST **JAX**. * **Validation Goals**: Ensure correct creation of `JobSet` resources, successful job execution, and compatibility with `TrainingRuntime` configurations. -* **Working Examples**: Provide runnable **notebook examples** demonstrating how to create and run training jobs. These notebooks double as test cases and user documentation. +* **Working Examples**: Provide runnable notebook examples demonstrating how to create and run training jobs. These notebooks double as test cases and user documentation. * **Unit Tests**: Add unit tests for `JAXMLPolicySource` to validate correct backend selection, environment variable setup, and distributed initialization logic. ## Future Work From 1b0dd1fc546b16b4b606fccb43de377ca5c5433d Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:26:36 +0000 Subject: [PATCH 31/38] change system to Trainer Controller Manager Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index df49f13e46..9e45b2b6b3 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -118,7 +118,7 @@ This section explains the architecture and flow of executing a distributed JAX t | **Actor / Component** | **Action** | **Details** | | ----------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | | Platform Admin | Prepares the Cluster Training Runtime | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. 
| -| System | Retrieves the Training Runtime Spec | Fetched automatically when a user requests a training job to determine execution details. | +| Trainer Controller Manager | Retrieves the Training Runtime Spec | Fetched automatically when a user requests a training job to determine execution details. | | AI Practitioner | Creates the Training Job | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | | Runtime | Creates and Submits a JobSet | Training job spec is translated into a JobSet (group of coordinated jobs). | | JobSet Controller | Launches Distributed Jobs | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. | From effee0a8b9b91fe4f2ec54bf859cbb05f4461f77 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:29:39 +0000 Subject: [PATCH 32/38] merge runtime and trainer controller manager Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 9e45b2b6b3..1da4e5cd88 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -115,10 +115,10 @@ This section explains the architecture and flow of executing a distributed JAX t ![user-roles](./drawing.drawio.svg) -| **Actor / Component** | **Action** | **Details** | +| **Component** | **Action** | **Details** | | ----------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | | Platform Admin | Prepares the Cluster Training Runtime | Defines container image, entrypoint, framework (e.g., JAX), and resource needs. Setup reusable for training jobs. | -| Trainer Controller Manager | Retrieves the Training Runtime Spec | Fetched automatically when a user requests a training job to determine execution details. | +| Trainer Controller Manager | Retrieves the Training Runtime Spec, Creates and Submits a JobSet | Fetched automatically when a user requests a training job to determine execution details. Training job spec is translated into a JobSet (group of coordinated jobs). | | AI Practitioner | Creates the Training Job | Uses Kubeflow Python SDK or `kubectl`. Provides training function (e.g., `jax_train_mnist`), arguments, and node configuration. | | Runtime | Creates and Submits a JobSet | Training job spec is translated into a JobSet (group of coordinated jobs). | | JobSet Controller | Launches Distributed Jobs | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. 
| From 57044b87a4f69501f448d6ed0586c5165bd160b9 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:33:29 +0000 Subject: [PATCH 33/38] change plugin details Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 1da4e5cd88..3ad4b46248 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -132,7 +132,7 @@ The number of **JAX hosts** is configured using the `numNodes` field in the **ML #### JAXMLPolicySource -`JAXMLPolicySource` configures JAX distributed initialization, backend, devices, and precision. +`JAXMLPolicySource` indicates that the JAX plugin should be activated. The extension framework will set the appropriate values for JAX distributed environment, backend, devices, and precision. ```golang type MLPolicySource struct { From f9718fb2e12339b93350337a69ae95ebe3586770 Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:34:50 +0000 Subject: [PATCH 34/38] update goals Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 3ad4b46248..86b61de50a 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -44,7 +44,7 @@ Finally with this design, Platform Admins can define standardized training runti - Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. 
multi-controller JAX) - Build the necessary Docker images for JAX worker nodes used by the runtime -- Implement the solution to work on CPU and GPU +- Implement the solution to work on CPU, GPU and TPU - Integrate with SDK and address any necessary enhancements - Document user guides for utilizing JAX ClusterTrainingRuntimes - Test the implementation thoroughly using unit tests and end-to-end (E2E) tests From aecb6d2f6d4ed7b4855731fc49d4309422ba954e Mon Sep 17 00:00:00 2001 From: Mahdi Khashan <58775404+mahdikhashan@users.noreply.github.com> Date: Sat, 4 Oct 2025 13:37:29 +0000 Subject: [PATCH 35/38] use kwargs Signed-off-by: Mahdi Khashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 86b61de50a..8eeb570465 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -72,7 +72,7 @@ The Python SDK with JAXRuntime may look as follows: from kubeflow.trainer import TrainerClient, CustomTrainer # Add logic using JAX methods -def jax_train_mnist(args): +def jax_train_mnist(epoch = 10, loss: str = None): raise NotImplementedError # Select the JAX runtime From 40eb8860574f07062bf2e709708227e9eedadfaf Mon Sep 17 00:00:00 2001 From: mahdikhashan Date: Sat, 4 Oct 2025 15:59:17 +0200 Subject: [PATCH 36/38] fix ci Signed-off-by: mahdikhashan --- docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg index 7207c6fa9d..486dce8171 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg +++ b/docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg @@ -1,4 +1,4 @@ -
[drawio SVG markup omitted — the diagram labels are: Platform Admins, AI Practitioners, Kubeflow Python SDK, kubectl, Create TrainJob, TrainJob, Manage, Runtime, Fetch Spec, Cluster Training Runtime, JobSet, Headless Service, JAX Processes, Cloud GPUs. It depicts Platform Admins managing the Cluster Training Runtime while AI Practitioners create a TrainJob via the Python SDK or kubectl; the TrainJob is translated into a JobSet of JAX processes connected through a Headless Service and running on cloud GPUs.]
From c01462224a339c0a14dc43d587fb71aec0490cd5 Mon Sep 17 00:00:00 2001 From: mahdikhashan Date: Sat, 4 Oct 2025 16:22:27 +0200 Subject: [PATCH 37/38] fix ci Signed-off-by: mahdikhashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 8eeb570465..0921be58c4 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -173,4 +173,4 @@ While it is possible to configure a specific communication backend (e.g., NCCL, ## Implementation History - 2025-05-28: Initial KEP draft created. -- 2025-09-19: Update design detail section and add future work \ No newline at end of file +- 2025-09-19: Update design detail section and add future work From 5806896a58f1f356f7e31f661f91dcc61e543a38 Mon Sep 17 00:00:00 2001 From: mahdikhashan Date: Sun, 5 Oct 2025 11:17:08 +0200 Subject: [PATCH 38/38] update non-goals Signed-off-by: mahdikhashan --- docs/proposals/2442-jax-runtime-trainer-v2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/proposals/2442-jax-runtime-trainer-v2/README.md b/docs/proposals/2442-jax-runtime-trainer-v2/README.md index 0921be58c4..a100913f56 100644 --- a/docs/proposals/2442-jax-runtime-trainer-v2/README.md +++ b/docs/proposals/2442-jax-runtime-trainer-v2/README.md @@ -51,8 +51,8 @@ Finally with this design, Platform Admins can define standardized training runti ### Non-Goals -- No TPU testing, tests will use CPU -- No GPU testing, tests will use CPU +- Set up test in TPU cluster +- Add advanced JAX APIs ## Proposal
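As an illustration of the design above, a user-supplied training function could initialize JAX's distributed runtime from the environment variables this runtime injects into each pod (`COORDINATOR_ADDRESS`, `NUM_PROCESSES`, `PROCESS_ID`, as proposed for the default Gloo backend). This is a minimal sketch; the function name, arguments, and training body are placeholders rather than part of the runtime contract.

```python
import os

import jax


def jax_train_mnist(epochs: int = 10):
    # Join the distributed runtime using the variables injected by the JAX plugin.
    # With the default Gloo (CPU) backend, no extra backend configuration is needed.
    jax.distributed.initialize(
        coordinator_address=os.environ["COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["NUM_PROCESSES"]),
        process_id=int(os.environ["PROCESS_ID"]),
    )

    # One JAX process per host; each process manages all of its local devices.
    print(
        f"process {jax.process_index()} of {jax.process_count()}, "
        f"local devices: {jax.local_device_count()}, "
        f"global devices: {jax.device_count()}"
    )

    # Placeholder for the SPMD training loop (e.g. pmap/pjit over MNIST batches).
    for _ in range(epochs):
        pass
```

Because initialization is driven entirely by injected environment variables, the same function should run unchanged if a different backend (e.g. NCCL for GPUs) is later selected, which is consistent with deferring backend selection to Future Work.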