Merged — 38 commits
- `3edd6db` proposal (mahdikhashan, May 10, 2025)
- `e29d739` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 10, 2025)
- `c352faf` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 10, 2025)
- `8b9a222` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 10, 2025)
- `054762d` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 14, 2025)
- `5c0348e` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 20, 2025)
- `dade879` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 20, 2025)
- `d4fbc16` Update docs/proposals/2442-jax-runtime-trainer-v2/README.md (mahdikhashan, Jul 20, 2025)
- `4b37388` update drawing (mahdikhashan, Jul 20, 2025)
- `a51c86a` update non-goals (mahdikhashan, Jul 20, 2025)
- `0d62ef6` update goal (mahdikhashan, Jul 20, 2025)
- `d88df35` update user persona names (mahdikhashan, Jul 20, 2025)
- `d3c5c6d` improve summary (mahdikhashan, Jul 23, 2025)
- `b5ae74c` improve motivation (mahdikhashan, Jul 23, 2025)
- `71d6a49` improve sample code in story 2 (mahdikhashan, Jul 23, 2025)
- `3ae4853` improve code (mahdikhashan, Jul 23, 2025)
- `d1159e0` add table (mahdikhashan, Sep 6, 2025)
- `bc7e4ba` wip: jax-ml-policy (mahdikhashan, Sep 6, 2025)
- `6cc6e4b` use a table for workflow description (mahdikhashan, Sep 19, 2025)
- `cfddced` remove communication backend section and add future work (mahdikhashan, Sep 19, 2025)
- `c1ed5c8` remove extra lines (mahdikhashan, Sep 19, 2025)
- `35f5995` separate jaxmlpolicy source block (mahdikhashan, Sep 19, 2025)
- `396e6d0` update the table of contents with future work (mahdikhashan, Sep 19, 2025)
- `92ff200` simplify the testing plan section (mahdikhashan, Sep 19, 2025)
- `178fc66` update the history section (mahdikhashan, Sep 19, 2025)
- `97fdf2c` improve design detail section (mahdikhashan, Sep 19, 2025)
- `096468a` update future work (mahdikhashan, Sep 19, 2025)
- `f002edc` update test plan (mahdikhashan, Sep 19, 2025)
- `ae58d22` update workflow description (mahdikhashan, Sep 19, 2025)
- `55eeba0` improve design detail section (mahdikhashan, Sep 19, 2025)
- `1b0dd1f` change system to Trainer Controller Manager (mahdikhashan, Oct 4, 2025)
- `effee0a` merge runtime and trainer controller manager (mahdikhashan, Oct 4, 2025)
- `57044b8` change plugin details (mahdikhashan, Oct 4, 2025)
- `f9718fb` update goals (mahdikhashan, Oct 4, 2025)
- `aecb6d2` use kwargs (mahdikhashan, Oct 4, 2025)
- `40eb886` fix ci (mahdikhashan, Oct 4, 2025)
- `c014622` fix ci (mahdikhashan, Oct 4, 2025)
- `5806896` update non-goals (mahdikhashan, Oct 5, 2025)
176 changes: 176 additions & 0 deletions docs/proposals/2442-jax-runtime-trainer-v2/README.md
**Member:** We also need test plans section.
**Member Author:** you are right. i'll add it.
**Member Author:** done.
@@ -0,0 +1,176 @@
# KEP-2442: JAX Runtime for Trainer V2

- [Summary](#summary)
- [Motivation](#motivation)
- [Goals](#goals)
- [Non-Goals](#non-goals)
- [Proposal](#proposal)
- [User Stories](#user-stories)
- [Story 1](#story-1)
- [Story 2](#story-2)
- [Design Details](#design-details)
- [Key Concepts in JAX Distributed Training](#key-concepts-in-jax-distributed-training)
- [JAX Training Workflow](#jax-training-workflow)
- [Defining Distributed JAX with MLPolicy](#defining-distributed-jax-with-mlpolicy)
- [Test Plan](#test-plan)
- [Future Work](#future-work)
- [Implementation History](#implementation-history)

## Summary

This document outlines a proposal to support the JAX Runtime in Kubeflow Trainer V2. Built upon the Kubernetes JobSet API, the JAX Runtime enables training and fine-tuning workloads using the JAX framework on Kubernetes. Instead of relying on framework-specific CRDs, Trainer V2 introduces a unified abstraction through TrainingRuntime and TrainJob. The JAX Runtime implements this abstraction to serve as a reusable blueprint for model training tasks, including large language models (LLMs). With the Kubeflow Trainer Pipeline Framework, we can easily integrate the JAX runtime into Kubeflow Trainer V2 as a runtime plugin.


## Motivation

JAX is a high-performance numerical computing framework created by Google. It is widely used in machine learning research and ranks as the third most widely used deep learning framework. Beyond deep learning, JAX also shows strong potential in differentiable programming, large-scale physics simulations, and more.

Combined with the new Runtime API for distributed training or distributed computation of objectives, these use cases bring new users to Kubeflow Trainer, for example for distributed simulation or for training LLM prototypes developed with JAX, such as large models from Google DeepMind.

The primary motivation is to enable users to apply the Single-Program Multiple-Data (SPMD) pattern with the JAX framework. A further motivation is backward compatibility with Trainer V1, which previously included JAX support, allowing existing users to transition smoothly while taking advantage of the enhanced Runtime API.

Finally, with this design, Platform Admins can define standardized training runtimes, while AI Practitioners can easily customize them through a simple SDK interface without needing to understand Kubernetes internals.

**Benefits**

1. Leverage JAX for differentiable programming and large-scale simulations
2. Enable distributed training or objective computation using the new Runtime API
3. Support prototyping and training of large JAX-based LLMs within Kubeflow Trainer

### Goals

- Implement ClusterTrainingRuntime for JAX, supporting distributed training with JAX (e.g. multi-controller JAX)
- Build the necessary Docker images for JAX worker nodes used by the runtime
- Ensure the solution works on CPU, GPU, and TPU
- Integrate with SDK and address any necessary enhancements
- Document user guides for utilizing JAX ClusterTrainingRuntimes
- Test the implementation thoroughly using unit tests and end-to-end (E2E) tests

### Non-Goals

- Set up tests in a TPU cluster
- Add advanced JAX APIs
**Member:** Now we have access to a GPU test cluster, so this might not be accurate anymore. How about stating the non-goals as:

1. Set up tests in a TPU cluster
2. Add advanced JAX APIs

WDYT? @andreyvelich @mahdikhashan

**Member Author:** you are right.
**Member Author:** done.

## Proposal

### User Stories

#### Story 1

As a Platform Admin, I want to manage JAX distributed training jobs using Kubeflow Trainer V2, so that I can provide engineering teams with blueprints for training machine learning models on a Kubernetes cluster.

#### Story 2

As an AI Practitioner, I want to use the Trainer V2 SDK to run a distributed training job from a notebook, so that I can use multiple devices for my training task.

The Python SDK with JAXRuntime may look as follows:

```python
from kubeflow.trainer import TrainerClient, CustomTrainer

# Add logic using JAX methods
def jax_train_mnist(epoch: int = 10, loss: str = "MSE"):
    raise NotImplementedError

# Select the JAX runtime
client = TrainerClient()
jax_runtime = next(r for r in client.list_runtimes() if r.name == "jax-distributed")

# Custom parameters passed as arguments
args = {
    "epoch": "20",
    "loss": "MSE",
}

# Launch the training job
job_id = client.train(
    trainer=CustomTrainer(func=jax_train_mnist, func_args=args, num_nodes=3),
    runtime=jax_runtime,
)
```

## Design Details
**Member:** I remember that we need to specify lots of env variables for distributed JAX training. Where would you upload these env variables to training pods?

https://github.com/kubeflow/trainer/blob/d24e7a46b20b268af66d59e099e41ec2f4032378/docs/proposals/2145-jax-integration/README.md#design-details

**Member Author:** all of the envs are optional, however, i'll dig deeper and update you.

**Member Author:** JAX has 3 ways to define configurations. In general, I'm in favor of preventing any decision paralysis for end users in this scope; however, I'd suggest leaving this to users: they can either pass it from Python code (runtime configuration in your Python code, check the link for details) or use the pod-level env in replicatedJobs.

ref: https://docs.jax.dev/en/latest/config_options.html

WDYT @Electronic-Waste

**Member:**

> use the pod-level env in replicatedJobs

I meant this. Shouldn't we provide some default envs in the runtime? We can also mutate some of these envs in the runtime plugin, as the torch plugin does.

In the torch plugin, an env similar to `JAX_COORDINATOR_ADDRESS` is `PET_MASTER_ADDR`:

```golang
if !slices.Equal(trainJob.Spec.Trainer.Command, constants.TorchTuneEntrypoint) {
	// Add PET_MASTER_ADDR and PET_MASTER_PORT envs for torchrun.
	apply.UpsertEnvVar(&trainerContainer.Env,
		*corev1ac.EnvVar().
			WithName(constants.TorchEnvMasterAddr).
			WithValue(fmt.Sprintf("%s-%s-0-0.%s", trainJob.Name, constants.Node, trainJob.Name)),
		*corev1ac.EnvVar().
			WithName(constants.TorchEnvMasterPort).
			WithValue(fmt.Sprintf("%d", constants.ContainerTrainerPort)),
	)
}
```

WDYT?

**Member:** @mahdikhashan can you also share how you are going to extend the plugins?
@Doris-xm has done a good job explaining it here for the Volcano KEP: https://github.com/kubeflow/trainer/blob/dad72fd21caa1cda07f1d402f6c37c792b9004b0/docs/proposals/2437-volcano-scheduler/README.md#volcano-runtime-plugin

**Member Author:** yes, sure. I'll add the details for the multiple-backend support @Electronic-Waste suggested, then I'll consider the details in a plugin. hope it is fine with you.

**Member:** Could you also add some descriptions about the image you describe in: #2643 (comment)

**Member Author:** yes. you are right. i'll add more details. thanks for the heads up.


In order to address this functionality, we propose the following design:

### Key Concepts in JAX Distributed Training

To understand the **JAX runtime** in Kubeflow Trainer V2, it's important to clarify the terminology used in JAX for distributed training:

| Concept | Description |
|---------|-------------|
| **Host** | A physical or virtual machine participating in distributed training. Each host runs a **single JAX process**, which manages all local devices (e.g., GPUs, CPUs). JAX auto-detects and utilizes all available devices. (In Kubernetes, a host maps to a **Node**, and typically one **Pod** is scheduled per host.) |
| **JAX Process / Controller** | A Python process running the JAX program (exactly one per host). Responsible for executing the training loop, managing all local devices, and synchronizing with other JAX processes over the network. Uses **SPMD** across processes. |
| **Devices** | Compute units on a host (CPU cores, GPUs, TPUs). JAX detects devices automatically and runs parallel computations via `jax.pmap`, `jax.shard_map`, or `pjit`. Each JAX process accesses all devices on its host. **No need to spawn multiple processes per GPU** (unlike PyTorch). |
| **Pod** | A Kubernetes Pod runs a single JAX process. Scheduled on a node and may use one or more GPUs depending on resource specifications. |
| **Node** | A Kubernetes Node is a worker machine. In multi-node JAX jobs, each node typically runs one pod, mapping to one JAX host. |
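
To make these concepts concrete, the following sketch (illustrative only, not part of the runtime) forces a single JAX process to expose several CPU "devices" and runs the same program on each of them with `jax.pmap`, the SPMD primitive mentioned above:

```python
import os

# Simulate 4 devices on a CPU-only host; must be set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

# One JAX process manages all local devices on its host.
print(jax.local_device_count())  # 4

# SPMD: the same function runs on every device, each with its own data shard.
doubled = jax.pmap(lambda x: x * 2.0)(jnp.arange(4.0))
print(doubled)  # [0. 2. 4. 6.]
```

In a real multi-host job, each Pod would additionally call `jax.distributed.initialize()` so that the per-host processes form one global device mesh.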


### JAX Training Workflow

This section explains the architecture and flow of executing a distributed JAX training job using Kubeflow, as depicted in the diagram.

![user-roles](./drawing.drawio.svg)
**Member:** If you can align with the persona names as here, that would be awesome!
https://www.kubeflow.org/docs/components/trainer/overview/#who-is-this-for

**Member Author:** sure. consider it done.

**Member Author:** done!


| **Component** | **Action** | **Details** |
| -------------------------- | ------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
| Platform Admin | Prepares the ClusterTrainingRuntime | Defines the container image, entrypoint, framework (e.g., JAX), and resource needs. The setup is reusable across training jobs. |
| AI Practitioner | Creates the TrainJob | Uses the Kubeflow Python SDK or `kubectl`. Provides the training function (e.g., `jax_train_mnist`), arguments, and node configuration. |
| Trainer Controller Manager | Retrieves the runtime spec, creates and submits a JobSet | Combines the TrainJob spec with the runtime spec and translates them into a JobSet (a group of coordinated jobs). |
**Member:** Same as above. "Runtime" is too vague, and I think the actor/component for "System" and "Runtime" are the same. Can you combine them? Maybe "Trainer Controller Manager" is better. WDYT? @andreyvelich @mahdikhashan

**Member:** Agree, it is the Trainer Controller Manager which combines specs from the TrainJob and Runtime to create the JobSet.

**Member Author:** done.

| JobSet Controller | Launches Distributed Jobs | JobSet spawns multiple Kubernetes Jobs, each pod runs a JAX training process instance. |
| Headless Service | Connects Pods for Communication | Enables direct pod-to-pod communication for gradient sharing and coordination. |
| Cluster (Pods) | Executes Distributed Training | Each pod runs JAX+Python code, collaborating to complete training across available hardware. |


### Defining Distributed JAX with MLPolicy

The number of **JAX hosts** is configured using the `numNodes` field in the **MLPolicy** section of the **ClusterTrainingRuntime**. Each host runs a single JAX process inside a Pod.
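
For illustration, a ClusterTrainingRuntime for JAX might look as follows. This is a hedged sketch: the runtime name (`jax-distributed`), image reference, and the empty `jax: {}` policy field are assumptions based on this proposal, not final API decisions.

```yaml
apiVersion: trainer.kubeflow.org/v1alpha1
kind: ClusterTrainingRuntime
metadata:
  name: jax-distributed        # assumed runtime name, matching the SDK example above
spec:
  mlPolicy:
    numNodes: 3                # number of JAX hosts (one JAX process per Pod)
    jax: {}                    # activates the JAX plugin (JAXMLPolicySource)
  template:
    spec:
      replicatedJobs:
        - name: node
          template:
            spec:
              template:
                spec:
                  containers:
                    - name: node
                      image: ghcr.io/kubeflow/trainer/jax-runtime  # assumed image name
```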
**Member** (on lines +129 to +131): Could you specify the API design details? Like:

```golang
// Only one of its members may be specified.
type PodGroupPolicySource struct {
	// Coscheduling plugin from the Kubernetes scheduler-plugins for gang-scheduling.
	Coscheduling *CoschedulingPodGroupPolicySource `json:"coscheduling,omitempty"`

	// Volcano plugin from the Volcano scheduler for gang-scheduling and advanced queue-based scheduling.
	Volcano *VolcanoPodGroupPolicySource `json:"volcano,omitempty"`
}

// VolcanoPodPolicySource configures scheduling behavior for Volcano.
type VolcanoPodPolicySource struct {
	// NetworkTopology defines the NetworkTopology config; this field works in conjunction with the network topology feature and the hyperNode CRD.
	NetworkTopology *volcanov1beta1.NetworkTopologySpec `json:"networkTopology,omitempty"`
}
```

**Member Author:** yes. consider it done.


#### JAXMLPolicySource

`JAXMLPolicySource` indicates that the JAX plugin should be activated. The extension framework will set the appropriate values for JAX distributed environment, backend, devices, and precision.

```golang
type MLPolicySource struct {
	[...]

	JAX *JAXMLPolicySource `json:"jax,omitempty"`
}
```

```golang
type JAXMLPolicySource struct {}
```

#### JAX Distributed System

The plugin enables JAX distributed training and handles distributed initialization internally, allowing seamless execution of training jobs with multiple backend configurations for multi-GPU and Cloud TPU.

| Backend | Parameters | Notes |
| ------- | ---------------------------------- | ----------------------------------------------------------------------------------------------------------------- |
| NCCL | None | No additional configuration needed. |
| LibTPU | None | No additional configuration needed. |
| Gloo | None | Environment variables (`COORDINATOR_ADDRESS`, `NUM_PROCESSES`, `PROCESS_ID`) are automatically set by the policy. |
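
To make the Gloo row concrete, here is a sketch of how the plugin could derive those environment variables for each Pod. The helper name, the `node-0-0` coordinator convention, and the default port are hypothetical, loosely mirroring the headless-service DNS pattern the torch plugin uses for `PET_MASTER_ADDR`:

```python
def jax_distributed_env(trainjob_name: str, num_nodes: int, node_rank: int,
                        port: int = 29400) -> dict:
    """Hypothetical helper: build Gloo-backend env vars for one JAX Pod.

    Assumes the coordinator is the first replica of the "node" replicated
    job, reachable via the JobSet headless service; the port is an
    assumption, not a Trainer constant.
    """
    coordinator = f"{trainjob_name}-node-0-0.{trainjob_name}:{port}"
    return {
        "COORDINATOR_ADDRESS": coordinator,
        "NUM_PROCESSES": str(num_nodes),
        "PROCESS_ID": str(node_rank),
    }

# Example: env for the second of three nodes in a TrainJob named "mnist-jax".
print(jax_distributed_env("mnist-jax", 3, 1, port=1234))
```

Each Pod would receive the same `COORDINATOR_ADDRESS` and `NUM_PROCESSES`, while `PROCESS_ID` varies with the Pod's completion index.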

## Test Plan

The testing strategy focuses on validating functionality and integration of the `TrainingRuntime` mechanism.

* **Environment**: Run workloads in a lightweight Kubernetes cluster in **CI actions** (e.g., using `kind` or `minikube`).
* **Workloads**: Execute simple distributed training examples, such as MNIST training with **JAX**.
* **Validation Goals**: Ensure correct creation of `JobSet` resources, successful job execution, and compatibility with `TrainingRuntime` configurations.
* **Working Examples**: Provide runnable notebook examples demonstrating how to create and run training jobs. These notebooks double as test cases and user documentation.
* **Unit Tests**: Add unit tests for `JAXMLPolicySource` to validate correct backend selection, environment variable setup, and distributed initialization logic.

## Future Work

While it would be possible to configure a specific communication backend (e.g., NCCL, MPI, Gloo) for the runtime via a parameter in `JAXMLPolicy`, we have deferred this decision to simplify the current implementation. By default, JAX uses `Gloo` as the communication backend. The design keeps the system extensible, so backend selection and integration can be added in response to future feedback without major changes.

## Implementation History

- 2025-05-28: Initial KEP draft created.
- 2025-09-19: Updated the design details section and added future work.
4 changes: 4 additions & 0 deletions docs/proposals/2442-jax-runtime-trainer-v2/drawing.drawio.svg
*(Binary SVG diagram; preview not available.)*