From c33a4b6fb6673dd325958d832ea90a9b5a829ee9 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 17 Oct 2025 16:14:40 +0200 Subject: [PATCH 1/2] fix(api): Keep mpiImplementation field a pointer Signed-off-by: Antonin Stefanutti --- .../crds/trainer.kubeflow.org_clustertrainingruntimes.yaml | 1 + .../crds/trainer.kubeflow.org_trainingruntimes.yaml | 1 + .../crds/trainer.kubeflow.org_clustertrainingruntimes.yaml | 1 + .../base/crds/trainer.kubeflow.org_trainingruntimes.yaml | 1 + pkg/apis/trainer/v1alpha1/trainingruntime_types.go | 4 ++-- pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go | 5 +++++ pkg/runtime/framework/plugins/mpi/mpi.go | 4 ++-- pkg/util/testing/wrapper.go | 2 +- 8 files changed, 14 insertions(+), 5 deletions(-) diff --git a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml index c7d676f130..e7d2d544c9 100644 --- a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml +++ b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml @@ -57,6 +57,7 @@ spec: Defaults to OpenMPI. enum: - OpenMPI + - "" type: string numProcPerNode: default: 1 diff --git a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainingruntimes.yaml b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainingruntimes.yaml index 2f0e5a1529..6e3a42c308 100644 --- a/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainingruntimes.yaml +++ b/charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainingruntimes.yaml @@ -57,6 +57,7 @@ spec: Defaults to OpenMPI. enum: - OpenMPI + - "" type: string numProcPerNode: default: 1 diff --git a/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml b/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml index c7d676f130..e7d2d544c9 100644 --- a/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml +++ b/manifests/base/crds/trainer.kubeflow.org_clustertrainingruntimes.yaml @@ -57,6 +57,7 @@ spec: Defaults to OpenMPI. enum: - OpenMPI + - "" type: string numProcPerNode: default: 1 diff --git a/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml b/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml index 2f0e5a1529..6e3a42c308 100644 --- a/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml +++ b/manifests/base/crds/trainer.kubeflow.org_trainingruntimes.yaml @@ -57,6 +57,7 @@ spec: Defaults to OpenMPI. enum: - OpenMPI + - "" type: string numProcPerNode: default: 1 diff --git a/pkg/apis/trainer/v1alpha1/trainingruntime_types.go b/pkg/apis/trainer/v1alpha1/trainingruntime_types.go index bfe225f892..221e70d8c5 100644 --- a/pkg/apis/trainer/v1alpha1/trainingruntime_types.go +++ b/pkg/apis/trainer/v1alpha1/trainingruntime_types.go @@ -251,9 +251,9 @@ type MPIMLPolicySource struct { // mpiImplementation is the name of the MPI implementation to create the appropriate hostfile. // Defaults to OpenMPI. // +kubebuilder:default=OpenMPI - // +kubebuilder:validation:Enum=OpenMPI + // +kubebuilder:validation:Enum=OpenMPI;"" // +optional - MPIImplementation MPIImplementation `json:"mpiImplementation,omitempty"` + MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"` // sshAuthMountPath is the directory where SSH keys are mounted. // Defaults to /root/.ssh. diff --git a/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go b/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go index 7ac798629e..93602c41aa 100644 --- a/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go @@ -313,6 +313,11 @@ func (in *MPIMLPolicySource) DeepCopyInto(out *MPIMLPolicySource) { *out = new(int32) **out = **in } + if in.MPIImplementation != nil { + in, out := &in.MPIImplementation, &out.MPIImplementation + *out = new(MPIImplementation) + **out = **in + } if in.SSHAuthMountPath != nil { in, out := &in.SSHAuthMountPath, &out.SSHAuthMountPath *out = new(string) diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index d154af5dab..ed32344da3 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -190,7 +190,7 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er WithName(constants.MPIHostfileVolumeName). WithMountPath(constants.MPIHostfileDir), ) - switch info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { + switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { case trainer.MPIImplementationOpenMPI: apply.UpsertEnvVars( &info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env, @@ -304,7 +304,7 @@ func (m *MPI) buildHostFileConfigMap(info *runtime.Info, trainJob *trainer.Train if !isNode(runLauncherAsNode, ps) { continue } - switch info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { + switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { case trainer.MPIImplementationOpenMPI: for e := range ps.Endpoints { hostFile.WriteString(fmt.Sprintf("%s slots=%d\n", e, slots)) diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index cf31f26daa..b1e2f2c201 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -1239,7 +1239,7 @@ func (m *MLPolicySourceWrapper) MPIPolicy(numProcPerNode *int32, MPImplementatio m.MPI = &trainer.MPIMLPolicySource{} } m.MPI.NumProcPerNode = numProcPerNode - m.MPI.MPIImplementation = MPImplementation + m.MPI.MPIImplementation = &MPImplementation m.MPI.SSHAuthMountPath = sshAuthMountPath m.MPI.RunLauncherAsNode = runLauncherAsNode return m From 138ddc6fa059250c9b48c13a5a8e65cf1c60d9d3 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 17 Oct 2025 17:44:40 +0200 Subject: [PATCH 2/2] Fix integration test Signed-off-by: Antonin Stefanutti --- test/integration/webhooks/trainingruntime_webhook_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/webhooks/trainingruntime_webhook_test.go b/test/integration/webhooks/trainingruntime_webhook_test.go index 718671bbcb..f1a39b4f1f 100644 --- a/test/integration/webhooks/trainingruntime_webhook_test.go +++ b/test/integration/webhooks/trainingruntime_webhook_test.go @@ -202,7 +202,7 @@ var _ = ginkgo.Describe("TrainingRuntime marker validations and defaulting", gin WithMLPolicy( testingutil.MakeMLPolicyWrapper(). WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper(). - MPIPolicy(ptr.To[int32](1), "", ptr.To("/usr/dir"), ptr.To(false)). + MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/usr/dir"), ptr.To(false)). Obj(), ). Obj(),