From 223f1b5c8da36e4977be10ef58bb4f154a8d071a Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Thu, 28 Aug 2025 13:34:16 +0200 Subject: [PATCH 1/8] feat(api): Sync TrainJob JobsStatus from JobSet ReplicatedJobsStatus Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 16 ++ pkg/runtime/core/clustertrainingruntime.go | 4 + pkg/runtime/core/trainingruntime.go | 4 + pkg/runtime/framework/core/framework.go | 19 ++- pkg/runtime/framework/core/framework_test.go | 159 +++++++++++++++++- pkg/runtime/framework/interface.go | 5 + .../framework/plugins/jobset/jobset.go | 20 +++ pkg/runtime/interface.go | 1 + pkg/util/testing/wrapper.go | 5 + 9 files changed, 231 insertions(+), 2 deletions(-) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index 78cb8c28df..7061d29d0e 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -142,6 +142,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c err = errors.Join(err, terminalCondErr) } + if jobsStatusErr := setJobsStatus(ctx, runtime, &trainJob); jobsStatusErr != nil { + err = errors.Join(err, jobsStatusErr) + } + if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) { return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob)) } @@ -267,6 +271,18 @@ func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trai return nil } +func setJobsStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { + statuses, err := runtime.JobsStatus(ctx, trainJob) + if err != nil { + return err + } + if statuses == nil { + return nil + } + trainJob.Status.JobsStatus = statuses + return nil +} + func isTrainJobFinished(trainJob *trainer.TrainJob) bool { return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) || meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed) diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 16214203e5..201ab72f82 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -75,6 +75,10 @@ func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob return r.TrainingRuntime.TerminalCondition(ctx, trainJob) } +func (r *ClusterTrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { + return r.TrainingRuntime.JobsStatus(ctx, trainJob) +} + func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { return nil } diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index c25de2936d..c90e732a66 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -250,6 +250,10 @@ func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *train return r.framework.RunTerminalConditionPlugins(ctx, trainJob) } +func (r *TrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { + return r.framework.RunJobsStatusPlugins(ctx, trainJob) +} + func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { var builders []runtime.ReconcilerBuilder for _, ex := range r.framework.WatchExtensionPlugins() { diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index a7567615a5..1712986404 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -32,7 +32,10 @@ import ( index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer" ) -var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered") +var ( + errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered") + errorTooManyJobsStatusPlugin = errors.New("too many JobsStatus plugins are registered") +) type Framework struct { registry fwkplugins.Registry @@ -44,6 +47,7 @@ type Framework struct { podNetworkPlugins []framework.PodNetworkPlugin componentBuilderPlugins []framework.ComponentBuilderPlugin terminalConditionPlugins []framework.TerminalConditionPlugin + jobsStatusPlugins []framework.JobsStatusPlugin } func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { @@ -82,6 +86,9 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl if p, ok := plugin.(framework.TerminalConditionPlugin); ok { f.terminalConditionPlugins = append(f.terminalConditionPlugins, p) } + if p, ok := plugin.(framework.JobsStatusPlugin); ok { + f.jobsStatusPlugins = append(f.jobsStatusPlugins, p) + } } f.plugins = plugins return f, nil @@ -152,6 +159,16 @@ func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *t return nil, nil } +func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { + if len(f.jobsStatusPlugins) > 1 { + return nil, errorTooManyJobsStatusPlugin + } + if len(f.jobsStatusPlugins) != 0 { + return f.jobsStatusPlugins[0].JobsStatus(ctx, trainJob) + } + return nil, nil +} + func (f *Framework) WatchExtensionPlugins() []framework.WatchExtensionPlugin { return f.watchExtensionPlugins } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 04262483f8..1884a8ad9e 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -114,6 +114,9 @@ func TestNew(t *testing.T) { terminalConditionPlugins: []framework.TerminalConditionPlugin{ &jobset.JobSet{}, }, + jobsStatusPlugins: []framework.JobsStatusPlugin{ + &jobset.JobSet{}, + }, }, }, "indexer key for trainingRuntime and runtimeClass is an empty": { @@ -136,7 +139,7 @@ func TestNew(t *testing.T) { cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, volcano.Volcano{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}), cmpopts.IgnoreFields(coscheduling.CoScheduling{}, "client"), cmpopts.IgnoreFields(volcano.Volcano{}, "client"), - cmpopts.IgnoreFields(jobset.JobSet{}, "client"), + cmpopts.IgnoreFields(jobset.JobSet{}, "client", "restMapper", "scheme", "logger"), cmpopts.IgnoreTypes(apiruntime.Scheme{}, meta.DefaultRESTMapper{}, fwkplugins.Registry{}), cmpopts.SortMaps(func(a, b string) bool { return a < b }), cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }), @@ -1549,6 +1552,160 @@ func TestTerminalConditionPlugins(t *testing.T) { } } +type fakeJobsStatusPlugin struct{} + +var _ framework.JobsStatusPlugin = (*fakeJobsStatusPlugin)(nil) + +func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { + return &fakeJobsStatusPlugin{}, nil +} + +const fakeJobsStatusPluginName = "fake-jobs-status" + +func (f fakeJobsStatusPlugin) Name() string { return fakeJobsStatusPluginName } +func (f fakeJobsStatusPlugin) JobsStatus(context.Context, *trainer.TrainJob) ([]trainer.JobStatus, error) { + return []trainer.JobStatus{ + {Name: "fake-job", Ready: 1, Succeeded: 0, Failed: 0, Active: 1, Suspended: 0}, + }, nil +} + +func TestJobsStatusPlugins(t *testing.T) { + cases := map[string]struct { + registry fwkplugins.Registry + trainJob *trainer.TrainJob + jobSet *jobsetv1alpha2.JobSet + wantStatuses []trainer.JobStatus + wantError error + }{ + "JobSet with empty replicated jobs status": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + wantStatuses: nil, + }, + "succeeded to obtain JobsStatus from JobSet with multiple replicated jobs": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 2, + Succeeded: 0, + Failed: 0, + Active: 2, + Suspended: 0, + }, + }). + Obj(), + wantStatuses: []trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 2, + Succeeded: 0, + Failed: 0, + Active: 2, + Suspended: 0, + }, + }, + }, + "succeeded to obtain JobsStatus from JobSet with failed job": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + }). + Obj(), + wantStatuses: []trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + }, + }, + "failed to obtain JobsStatus due to multiple JobsStatusPlugins": { + registry: fwkplugins.Registry{ + jobset.Name: jobset.New, + fakeJobsStatusPluginName: newFakeJobsStatusPlugin, + }, + wantError: errorTooManyJobsStatusPlugin, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + clientBuilder := testingutil.NewClientBuilder() + if tc.jobSet != nil { + clientBuilder = clientBuilder.WithObjects(tc.jobSet) + } + c := clientBuilder.Build() + + fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) + if err != nil { + t.Fatal(err) + } + + gotStatuses, gotErr := fwk.RunJobsStatusPlugins(ctx, tc.trainJob) + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + + if diff := cmp.Diff(tc.wantStatuses, gotStatuses); len(diff) != 0 { + t.Errorf("Unexpected jobs status (-want,+got):\n%s", diff) + } + }) + } +} + func TestPodNetworkPlugins(t *testing.T) { cases := map[string]struct { registry fwkplugins.Registry diff --git a/pkg/runtime/framework/interface.go b/pkg/runtime/framework/interface.go index 676548ebe7..b6ab61712a 100644 --- a/pkg/runtime/framework/interface.go +++ b/pkg/runtime/framework/interface.go @@ -65,3 +65,8 @@ type TerminalConditionPlugin interface { Plugin TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) } + +type JobsStatusPlugin interface { + Plugin + JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) +} diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index e829188fd8..d737efc603 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -63,6 +63,7 @@ var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.PodNetworkPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) var _ framework.TerminalConditionPlugin = (*JobSet)(nil) +var _ framework.JobsStatusPlugin = (*JobSet)(nil) var _ framework.CustomValidationPlugin = (*JobSet)(nil) const Name = constants.JobSetKind @@ -304,3 +305,22 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ } return nil, nil } + +func (j *JobSet) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { + jobSet := &jobsetv1alpha2.JobSet{} + if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil { + return nil, err + } + var statuses []trainer.JobStatus + for _, status := range jobSet.Status.ReplicatedJobsStatus { + statuses = append(statuses, trainer.JobStatus{ + Name: status.Name, + Ready: status.Ready, + Succeeded: status.Succeeded, + Failed: status.Failed, + Active: status.Active, + Suspended: status.Suspended, + }) + } + return statuses, nil +} diff --git a/pkg/runtime/interface.go b/pkg/runtime/interface.go index 6f49bace31..6150a8c821 100644 --- a/pkg/runtime/interface.go +++ b/pkg/runtime/interface.go @@ -39,6 +39,7 @@ type Runtime interface { NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) RuntimeInfo(trainJob *trainer.TrainJob, runtimeTemplateSpec any, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) (*Info, error) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) + JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) } diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index d497ecc37d..c4c3a552f9 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -503,6 +503,11 @@ func (j *JobSetWrapper) DependsOn(rJobName string, dependsOn ...jobsetv1alpha2.D return j } +func (j *JobSetWrapper) ReplicatedJobsStatuses(statuses []jobsetv1alpha2.ReplicatedJobStatus) *JobSetWrapper { + j.Status.ReplicatedJobsStatus = statuses + return j +} + func (j *JobSetWrapper) Obj() *jobsetv1alpha2.JobSet { return &j.JobSet } From 783630d1d424b2775d28a39e916b23ce156531f0 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 1 Sep 2025 10:29:12 +0200 Subject: [PATCH 2/8] Add integration tests Signed-off-by: Antonin Stefanutti --- .../controller/trainjob_controller_test.go | 406 +++++++++++++++++- 1 file changed, 392 insertions(+), 14 deletions(-) diff --git a/test/integration/controller/trainjob_controller_test.go b/test/integration/controller/trainjob_controller_test.go index 9be21beb47..dda0f3774f 100644 --- a/test/integration/controller/trainjob_controller_test.go +++ b/test/integration/controller/trainjob_controller_test.go @@ -412,7 +412,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if TrainJob has Suspended and Created conditions") + ginkgo.By("Checking if TrainJob has Suspended=True condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -426,7 +426,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TrainJob has Resumed and Created conditions after unsuspended") + ginkgo.By("Checking if the TrainJob has Suspended=False [Resumed] condition after unsuspended") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -443,7 +443,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Completed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with successful completion") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -453,10 +453,37 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.AllJobsCompletedMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Completed conditions") + ginkgo.By("Checking if the TranJob has Suspended and Complete conditions as well as Succeeded JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -474,6 +501,32 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -499,7 +552,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Waiting for TrainJob Created=True and Suspended=False condition") + ginkgo.By("Waiting for TrainJob Suspended=False condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -513,7 +566,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Failed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with failed jobs") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -523,10 +576,36 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.FailedJobsMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Failed conditions") + ginkgo.By("Checking if the TranJob has Suspended and Failed conditions as well as failed JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -544,6 +623,185 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.FailedJobsMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + })) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should synchronize JobsStatus from JobSet ReplicatedJobsStatus", func() { + ginkgo.By("Creating TrainingRuntime and suspended TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, &jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending the TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet ReplicatedJobsStatus to simulate running jobs") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 1, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Verifying JobsStatus synchronization in TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 1, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + })) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet ReplicatedJobsStatus to simulate some failed jobs") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Verifying updated JobsStatus reflects failed jobs") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -889,7 +1147,7 @@ alpha-node-0-1.alpha slots=8 g.Expect(k8sClient.Get(ctx, secKey, &corev1.Secret{})).Should(gomega.Succeed()) }) - ginkgo.By("Checking if TrainJob has Suspended and Created conditions") + ginkgo.By("Checking if TrainJob has Suspended=True condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -903,7 +1161,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TrainJob has Resumed and Created conditions after unsuspended") + ginkgo.By("Checking if the TrainJob has Suspended=False [Resumed] condition after unsuspended") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -920,7 +1178,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Completed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with successful completion") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -930,10 +1188,36 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.AllJobsCompletedMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Completed conditions") + ginkgo.By("Checking if the TranJob has Suspended=False and Complete=True conditions as well as succeeded JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -951,6 +1235,32 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -977,7 +1287,7 @@ alpha-node-0-1.alpha slots=8 g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Waiting for TrainJob Created=True and Suspended=False condition") + ginkgo.By("Waiting for TrainJob Suspended=False condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -991,7 +1301,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Failed") + ginkgo.By("Updating the JobSet Failed=True condition and ReplicatedJobsStatus with failed jobs") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -1001,10 +1311,44 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.FailedJobsMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Failed conditions") + ginkgo.By("Checking if the TranJob has Suspended=False [Resumed] and Failed=True conditions as well as failed JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -1022,6 +1366,40 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.FailedJobsMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) }) From bfe42343c2d6937277d0d1fc89cdbfae194d61da Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 1 Sep 2025 11:02:59 +0200 Subject: [PATCH 3/8] Update e2e tests Signed-off-by: Antonin Stefanutti --- test/e2e/e2e_test.go | 86 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 83 insertions(+), 3 deletions(-) diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index c2fce0b209..8abcee8558 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -9,6 +9,7 @@ import ( jobsetconsts "sigs.k8s.io/jobset/pkg/constants" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/constants" testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing" "github.com/kubeflow/trainer/v2/test/util" ) @@ -16,7 +17,6 @@ import ( const ( torchRuntime = "torch-distributed" deepSpeedRuntime = "deepspeed-distributed" - mlxRuntime = "mlx-distributed" ) var _ = ginkgo.Describe("TrainJob e2e", func() { @@ -57,8 +57,28 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) }) - // Wait for TrainJob to be in Succeeded status. - ginkgo.By("Wait for TrainJob to be in Succeeded status", func() { + // Wait for jobs to become active + ginkgo.By("Wait for TrainJob jobs to become active", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + + expectedJobsStatus := []trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + } + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + + // Wait for TrainJob to be in Succeeded status with all jobs succeeded. + ginkgo.By("Wait for TrainJob to be in Succeeded status with all jobs succeeded", func() { gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) @@ -70,6 +90,18 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + + expectedJobsStatus := []trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) @@ -87,6 +119,34 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) }) + // Wait for jobs to become active + ginkgo.By("Wait for TrainJob jobs to become active", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + + expectedJobsStatus := []trainer.JobStatus{ + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + } + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + // Wait for TrainJob to be in Succeeded status. ginkgo.By("Wait for TrainJob to be in Succeeded status", func() { gomega.Eventually(func(g gomega.Gomega) { @@ -100,6 +160,26 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + + expectedJobsStatus := []trainer.JobStatus{ + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) From a1287d3a8a7a1437618eec2ec8ee21791b47533a Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 1 Sep 2025 11:04:28 +0200 Subject: [PATCH 4/8] Remove extra check Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index 7061d29d0e..b9f5e3ad3f 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -276,9 +276,6 @@ func setJobsStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *t if err != nil { return err } - if statuses == nil { - return nil - } trainJob.Status.JobsStatus = statuses return nil } From efc5e9d054f97dbef8183f4af734fdef3a9ea41b Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 1 Sep 2025 12:03:26 +0200 Subject: [PATCH 5/8] Sort JobsStatus in e2e tests Signed-off-by: Antonin Stefanutti --- test/e2e/e2e_test.go | 24 ++++++++---------------- test/util/constants.go | 5 +++++ 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 8abcee8558..a2d3d3a33b 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -62,8 +62,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) - - expectedJobsStatus := []trainer.JobStatus{ + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ { Name: constants.Node, Ready: 0, @@ -72,8 +71,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Active: 1, Suspended: 0, }, - } - g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) @@ -90,8 +88,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) - - expectedJobsStatus := []trainer.JobStatus{ + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ { Name: constants.Node, Ready: 0, @@ -100,8 +97,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Active: 0, Suspended: 0, }, - } - g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) @@ -124,8 +120,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) - - expectedJobsStatus := []trainer.JobStatus{ + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ { Name: constants.Launcher, Ready: 0, @@ -142,8 +137,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Active: 1, Suspended: 0, }, - } - g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) @@ -160,8 +154,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) - - expectedJobsStatus := []trainer.JobStatus{ + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ { Name: constants.Launcher, Ready: 0, @@ -178,8 +171,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Active: 0, Suspended: 0, }, - } - g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo(expectedJobsStatus)) + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) diff --git a/test/util/constants.go b/test/util/constants.go index da7d144fdb..860fa42c04 100644 --- a/test/util/constants.go +++ b/test/util/constants.go @@ -22,6 +22,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" ) const ( @@ -39,4 +41,7 @@ var ( IgnoreConditions = cmp.Options{ cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime", "ObservedGeneration"), } + SortJobsStatus = cmp.Options{ + cmpopts.SortSlices(func(a, b trainer.JobStatus) bool { return a.Name < b.Name }), + } ) From ab46fa61c1afada4e74f12d1b66f52437843ae3d Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Mon, 1 Sep 2025 17:53:18 +0200 Subject: [PATCH 6/8] Fix e2e test for MPI job Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 9 ------ test/e2e/e2e_test.go | 6 ++-- .../controller/trainjob_controller_test.go | 32 ++++++++++++++----- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index b9f5e3ad3f..6ec51ed8ff 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -104,10 +104,6 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob)) ctx = ctrl.LoggerInto(ctx, log) log.V(2).Info("Reconciling TrainJob") - if isTrainJobFinished(&trainJob) { - log.V(5).Info("TrainJob has already been finished") - return ctrl.Result{}, nil - } var err error // Keep track of the origin TrainJob status @@ -280,11 +276,6 @@ func setJobsStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *t return nil } -func isTrainJobFinished(trainJob *trainer.TrainJob) bool { - return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) || - meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed) -} - func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error { b := builder.TypedControllerManagedBy[reconcile.Request](mgr). Named("trainjob_controller"). diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index a2d3d3a33b..012d6a6ab2 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -124,9 +124,9 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { { Name: constants.Launcher, Ready: 0, - Succeeded: 1, + Succeeded: 0, Failed: 0, - Active: 0, + Active: 1, Suspended: 0, }, { @@ -166,7 +166,7 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { { Name: constants.Node, Ready: 0, - Succeeded: 1, + Succeeded: 0, Failed: 0, Active: 0, Suspended: 0, diff --git a/test/integration/controller/trainjob_controller_test.go b/test/integration/controller/trainjob_controller_test.go index dda0f3774f..62dc36dbcd 100644 --- a/test/integration/controller/trainjob_controller_test.go +++ b/test/integration/controller/trainjob_controller_test.go @@ -1206,13 +1206,21 @@ alpha-node-0-1.alpha slots=8 Suspended: 0, }, { - Name: constants.Node, + Name: constants.Launcher, Ready: 0, Succeeded: 1, Failed: 0, Active: 0, Suspended: 0, }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) @@ -1253,13 +1261,21 @@ alpha-node-0-1.alpha slots=8 Suspended: 0, }, { - Name: constants.Node, + Name: constants.Launcher, Ready: 0, Succeeded: 1, Failed: 0, Active: 0, Suspended: 0, }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -1331,8 +1347,8 @@ alpha-node-0-1.alpha slots=8 { Name: constants.Launcher, Ready: 0, - Succeeded: 1, - Failed: 0, + Succeeded: 0, + Failed: 1, Active: 0, Suspended: 0, }, @@ -1340,7 +1356,7 @@ alpha-node-0-1.alpha slots=8 Name: constants.Node, Ready: 0, Succeeded: 0, - Failed: 1, + Failed: 0, Active: 0, Suspended: 0, }, @@ -1386,8 +1402,8 @@ alpha-node-0-1.alpha slots=8 { Name: constants.Launcher, Ready: 0, - Succeeded: 1, - Failed: 0, + Succeeded: 0, + Failed: 1, Active: 0, Suspended: 0, }, @@ -1395,7 +1411,7 @@ alpha-node-0-1.alpha slots=8 Name: constants.Node, Ready: 0, Succeeded: 0, - Failed: 1, + Failed: 0, Active: 0, Suspended: 0, }, From fc55beb61e35af486fdecae01d7b7b2c7d50683e Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 2 Sep 2025 09:03:46 +0200 Subject: [PATCH 7/8] Fail-fast when multiple terminal condition and JobsStatus plugins exist Signed-off-by: Antonin Stefanutti --- pkg/runtime/framework/core/framework.go | 29 ++++++++++---------- pkg/runtime/framework/core/framework_test.go | 18 ++++++------ 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 1712986404..8470babfe0 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -46,8 +46,8 @@ type Framework struct { watchExtensionPlugins []framework.WatchExtensionPlugin podNetworkPlugins []framework.PodNetworkPlugin componentBuilderPlugins []framework.ComponentBuilderPlugin - terminalConditionPlugins []framework.TerminalConditionPlugin - jobsStatusPlugins []framework.JobsStatusPlugin + terminalConditionPlugin framework.TerminalConditionPlugin + jobsStatusPlugin framework.JobsStatusPlugin } func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { @@ -84,10 +84,16 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl f.componentBuilderPlugins = append(f.componentBuilderPlugins, p) } if p, ok := plugin.(framework.TerminalConditionPlugin); ok { - f.terminalConditionPlugins = append(f.terminalConditionPlugins, p) + if f.terminalConditionPlugin != nil { + return nil, errorTooManyTerminalConditionPlugin + } + f.terminalConditionPlugin = p } if p, ok := plugin.(framework.JobsStatusPlugin); ok { - f.jobsStatusPlugins = append(f.jobsStatusPlugins, p) + if f.jobsStatusPlugin != nil { + return nil, errorTooManyJobsStatusPlugin + } + f.jobsStatusPlugin = p } } f.plugins = plugins @@ -149,22 +155,15 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim } func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - // TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points. - if len(f.terminalConditionPlugins) > 1 { - return nil, errorTooManyTerminalConditionPlugin - } - if len(f.terminalConditionPlugins) != 0 { - return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob) + if f.terminalConditionPlugin != nil { + return f.terminalConditionPlugin.TerminalCondition(ctx, trainJob) } return nil, nil } func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { - if len(f.jobsStatusPlugins) > 1 { - return nil, errorTooManyJobsStatusPlugin - } - if len(f.jobsStatusPlugins) != 0 { - return f.jobsStatusPlugins[0].JobsStatus(ctx, trainJob) + if f.jobsStatusPlugin != nil { + return f.jobsStatusPlugin.JobsStatus(ctx, trainJob) } return nil, nil } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 1884a8ad9e..3c4d7d7cf2 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -111,12 +111,8 @@ func TestNew(t *testing.T) { &jobset.JobSet{}, &mpi.MPI{}, }, - terminalConditionPlugins: []framework.TerminalConditionPlugin{ - &jobset.JobSet{}, - }, - jobsStatusPlugins: []framework.JobsStatusPlugin{ - &jobset.JobSet{}, - }, + terminalConditionPlugin: &jobset.JobSet{}, + jobsStatusPlugin: &jobset.JobSet{}, }, }, "indexer key for trainingRuntime and runtimeClass is an empty": { @@ -1537,7 +1533,10 @@ func TestTerminalConditionPlugins(t *testing.T) { fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) if err != nil { - t.Fatal(err) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + return } gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob) @@ -1691,7 +1690,10 @@ func TestJobsStatusPlugins(t *testing.T) { fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) if err != nil { - t.Fatal(err) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + return } gotStatuses, gotErr := fwk.RunJobsStatusPlugins(ctx, tc.trainJob) From 7d34b6ea2ef9bfdf29a0d9e84e2837aec4b119c4 Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Tue, 7 Oct 2025 14:48:29 +0200 Subject: [PATCH 8/8] Fold TerminalCondition and JobsStatus plugins Signed-off-by: Antonin Stefanutti --- pkg/controller/trainjob_controller.go | 24 +- pkg/runtime/core/clustertrainingruntime.go | 9 +- pkg/runtime/core/trainingruntime.go | 8 +- pkg/runtime/framework/core/framework.go | 36 +-- pkg/runtime/framework/core/framework_test.go | 228 ++++++++---------- pkg/runtime/framework/interface.go | 10 +- .../framework/plugins/jobset/jobset.go | 22 +- pkg/runtime/interface.go | 4 +- 8 files changed, 128 insertions(+), 213 deletions(-) diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index 6ec51ed8ff..c4c2f61622 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -134,12 +134,9 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c } setSuspendedCondition(&trainJob) - if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil { - err = errors.Join(err, terminalCondErr) - } - if jobsStatusErr := setJobsStatus(ctx, runtime, &trainJob); jobsStatusErr != nil { - err = errors.Join(err, jobsStatusErr) + if statusErr := setTrainJobStatus(ctx, runtime, &trainJob); statusErr != nil { + err = errors.Join(err, statusErr) } if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) { @@ -256,23 +253,14 @@ func removeFailedCondition(trainJob *trainer.TrainJob) { meta.RemoveStatusCondition(&trainJob.Status.Conditions, trainer.TrainJobFailed) } -func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { - terminalCond, err := runtime.TerminalCondition(ctx, trainJob) +func setTrainJobStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { + status, err := runtime.TrainJobStatus(ctx, trainJob) if err != nil { return err } - if terminalCond != nil { - meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond) - } - return nil -} - -func setJobsStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { - statuses, err := runtime.JobsStatus(ctx, trainJob) - if err != nil { - return err + if status != nil { + trainJob.Status = *status } - trainJob.Status.JobsStatus = statuses return nil } diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 201ab72f82..dbcb18034e 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -71,12 +70,8 @@ func (r *ClusterTrainingRuntime) RuntimeInfo( return r.TrainingRuntime.RuntimeInfo(trainJob, runtimeTemplateSpec, mlPolicy, podGroupPolicy) } -func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - return r.TrainingRuntime.TerminalCondition(ctx, trainJob) -} - -func (r *ClusterTrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { - return r.TrainingRuntime.JobsStatus(ctx, trainJob) +func (r *ClusterTrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return r.TrainingRuntime.TrainJobStatus(ctx, trainJob) } func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index c90e732a66..521e929cf6 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -246,12 +246,8 @@ func syncPodSets(info *runtime.Info) { } } -func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - return r.framework.RunTerminalConditionPlugins(ctx, trainJob) -} - -func (r *TrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { - return r.framework.RunJobsStatusPlugins(ctx, trainJob) +func (r *TrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return r.framework.RunTrainJobStatusPlugin(ctx, trainJob) } func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index 8470babfe0..dc654d3aea 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -20,7 +20,6 @@ import ( "context" "errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -32,10 +31,7 @@ import ( index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer" ) -var ( - errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered") - errorTooManyJobsStatusPlugin = errors.New("too many JobsStatus plugins are registered") -) +var errorTooManyTrainJobStatusPlugin = errors.New("too many TrainJobStatus plugins are registered") type Framework struct { registry fwkplugins.Registry @@ -46,8 +42,7 @@ type Framework struct { watchExtensionPlugins []framework.WatchExtensionPlugin podNetworkPlugins []framework.PodNetworkPlugin componentBuilderPlugins []framework.ComponentBuilderPlugin - terminalConditionPlugin framework.TerminalConditionPlugin - jobsStatusPlugin framework.JobsStatusPlugin + trainJobStatusPlugin framework.TrainJobStatusPlugin } func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { @@ -83,17 +78,11 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl if p, ok := plugin.(framework.ComponentBuilderPlugin); ok { f.componentBuilderPlugins = append(f.componentBuilderPlugins, p) } - if p, ok := plugin.(framework.TerminalConditionPlugin); ok { - if f.terminalConditionPlugin != nil { - return nil, errorTooManyTerminalConditionPlugin - } - f.terminalConditionPlugin = p - } - if p, ok := plugin.(framework.JobsStatusPlugin); ok { - if f.jobsStatusPlugin != nil { - return nil, errorTooManyJobsStatusPlugin + if p, ok := plugin.(framework.TrainJobStatusPlugin); ok { + if f.trainJobStatusPlugin != nil { + return nil, errorTooManyTrainJobStatusPlugin } - f.jobsStatusPlugin = p + f.trainJobStatusPlugin = p } } f.plugins = plugins @@ -154,16 +143,9 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim return objs, nil } -func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - if f.terminalConditionPlugin != nil { - return f.terminalConditionPlugin.TerminalCondition(ctx, trainJob) - } - return nil, nil -} - -func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { - if f.jobsStatusPlugin != nil { - return f.jobsStatusPlugin.JobsStatus(ctx, trainJob) +func (f *Framework) RunTrainJobStatusPlugin(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + if f.trainJobStatusPlugin != nil { + return f.trainJobStatusPlugin.Status(ctx, trainJob) } return nil, nil } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 3c4d7d7cf2..f25f0798b3 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -19,6 +19,7 @@ package core import ( "context" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -111,8 +112,7 @@ func TestNew(t *testing.T) { &jobset.JobSet{}, &mpi.MPI{}, }, - terminalConditionPlugin: &jobset.JobSet{}, - jobsStatusPlugin: &jobset.JobSet{}, + trainJobStatusPlugin: &jobset.JobSet{}, }, }, "indexer key for trainingRuntime and runtimeClass is an empty": { @@ -1439,28 +1439,33 @@ func TestWatchExtensionPlugins(t *testing.T) { } } -type fakeTerminalConditionPlugin struct{} +type fakeTrainJobStatusPlugin struct{} -var _ framework.TerminalConditionPlugin = (*fakeTerminalConditionPlugin)(nil) +var _ framework.TrainJobStatusPlugin = (*fakeTrainJobStatusPlugin)(nil) -func newFakeTerminalConditionPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { - return &fakeTerminalConditionPlugin{}, nil +func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { + return &fakeTrainJobStatusPlugin{}, nil } -const fakeTerminalConditionPluginName = "fake" +const fakeJobsStatusPluginName = "fake-train-job-status" -func (f fakeTerminalConditionPlugin) Name() string { return fakeTerminalConditionPluginName } -func (f fakeTerminalConditionPlugin) TerminalCondition(context.Context, *trainer.TrainJob) (*metav1.Condition, error) { - return nil, nil +func (f fakeTrainJobStatusPlugin) Name() string { return fakeJobsStatusPluginName } +func (f fakeTrainJobStatusPlugin) Status(context.Context, *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + {Name: "fake-job", Ready: 1, Succeeded: 0, Failed: 0, Active: 1, Suspended: 0}, + }, + }, nil } -func TestTerminalConditionPlugins(t *testing.T) { +func TestTrainJobStatusPlugins(t *testing.T) { + lastTransitionTime := metav1.Time{Time: time.Now()}.Rfc3339Copy() cases := map[string]struct { - registry fwkplugins.Registry - trainJob *trainer.TrainJob - jobSet *jobsetv1alpha2.JobSet - wantCondition *metav1.Condition - wantError error + registry fwkplugins.Registry + trainJob *trainer.TrainJob + jobSet *jobsetv1alpha2.JobSet + wantStatus *trainer.TrainJobStatus + wantError error }{ "jobSet has not been finalized, yet": { registry: fwkplugins.NewRegistry(), @@ -1474,6 +1479,7 @@ func TestTerminalConditionPlugins(t *testing.T) { Status: metav1.ConditionFalse, }). Obj(), + wantStatus: &trainer.TrainJobStatus{}, }, "succeeded to obtain completed terminal condition": { registry: fwkplugins.NewRegistry(), @@ -1481,17 +1487,23 @@ func TestTerminalConditionPlugins(t *testing.T) { Obj(), jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). Conditions(metav1.Condition{ - Type: string(jobsetv1alpha2.JobSetCompleted), - Reason: jobsetconsts.AllJobsCompletedReason, - Message: jobsetconsts.AllJobsCompletedMessage, - Status: metav1.ConditionTrue, + Type: string(jobsetv1alpha2.JobSetCompleted), + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.AllJobsCompletedMessage, + Reason: jobsetconsts.AllJobsCompletedReason, + Status: metav1.ConditionTrue, }). Obj(), - wantCondition: &metav1.Condition{ - Type: trainer.TrainJobComplete, - Reason: jobsetconsts.AllJobsCompletedReason, - Message: jobsetconsts.AllJobsCompletedMessage, - Status: metav1.ConditionTrue, + wantStatus: &trainer.TrainJobStatus{ + Conditions: []metav1.Condition{ + { + Type: trainer.TrainJobComplete, + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.AllJobsCompletedMessage, + Reason: jobsetconsts.AllJobsCompletedReason, + Status: metav1.ConditionTrue, + }, + }, }, }, "succeeded to obtain failed terminal condition": { @@ -1500,89 +1512,39 @@ func TestTerminalConditionPlugins(t *testing.T) { Obj(), jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). Conditions(metav1.Condition{ - Type: string(jobsetv1alpha2.JobSetFailed), - Reason: jobsetconsts.FailedJobsReason, - Message: jobsetconsts.FailedJobsMessage, - Status: metav1.ConditionTrue, + Type: string(jobsetv1alpha2.JobSetFailed), + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.FailedJobsMessage, + Reason: jobsetconsts.FailedJobsReason, + Status: metav1.ConditionTrue, }). Obj(), - wantCondition: &metav1.Condition{ - Type: trainer.TrainJobFailed, - Reason: jobsetconsts.FailedJobsReason, - Message: jobsetconsts.FailedJobsMessage, - Status: metav1.ConditionTrue, + wantStatus: &trainer.TrainJobStatus{ + Conditions: []metav1.Condition{ + { + Type: trainer.TrainJobFailed, + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.FailedJobsMessage, + Reason: jobsetconsts.FailedJobsReason, + Status: metav1.ConditionTrue, + }, + }, }, }, - "failed to obtain any terminal condition due to multiple terminalCondition plugin": { + "failed to obtain TrainJob status due to multiple trainJobStatus plugin": { registry: fwkplugins.Registry{ - jobset.Name: jobset.New, - fakeTerminalConditionPluginName: newFakeTerminalConditionPlugin, + jobset.Name: jobset.New, + fakeJobsStatusPluginName: newFakeJobsStatusPlugin, }, - wantError: errorTooManyTerminalConditionPlugin, + wantError: errorTooManyTrainJobStatusPlugin, }, - } - for name, tc := range cases { - t.Run(name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - clientBuilder := testingutil.NewClientBuilder() - if tc.jobSet != nil { - clientBuilder = clientBuilder.WithObjects(tc.jobSet) - } - c := clientBuilder.Build() - - fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) - if err != nil { - if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { - t.Errorf("Unexpected error (-want,+got):\n%s", diff) - } - return - } - - gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob) - if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { - t.Errorf("Unexpected error (-want,+got):\n%s", diff) - } - - if diff := cmp.Diff(tc.wantCondition, gotCond); len(diff) != 0 { - t.Errorf("Unexpected terminal condition (-want,+got):\n%s", diff) - } - }) - } -} - -type fakeJobsStatusPlugin struct{} - -var _ framework.JobsStatusPlugin = (*fakeJobsStatusPlugin)(nil) - -func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { - return &fakeJobsStatusPlugin{}, nil -} - -const fakeJobsStatusPluginName = "fake-jobs-status" - -func (f fakeJobsStatusPlugin) Name() string { return fakeJobsStatusPluginName } -func (f fakeJobsStatusPlugin) JobsStatus(context.Context, *trainer.TrainJob) ([]trainer.JobStatus, error) { - return []trainer.JobStatus{ - {Name: "fake-job", Ready: 1, Succeeded: 0, Failed: 0, Active: 1, Suspended: 0}, - }, nil -} - -func TestJobsStatusPlugins(t *testing.T) { - cases := map[string]struct { - registry fwkplugins.Registry - trainJob *trainer.TrainJob - jobSet *jobsetv1alpha2.JobSet - wantStatuses []trainer.JobStatus - wantError error - }{ "JobSet with empty replicated jobs status": { registry: fwkplugins.NewRegistry(), trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). Obj(), jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). Obj(), - wantStatuses: nil, + wantStatus: &trainer.TrainJobStatus{}, }, "succeeded to obtain JobsStatus from JobSet with multiple replicated jobs": { registry: fwkplugins.NewRegistry(), @@ -1616,30 +1578,32 @@ func TestJobsStatusPlugins(t *testing.T) { }, }). Obj(), - wantStatuses: []trainer.JobStatus{ - { - Name: constants.DatasetInitializer, - Ready: 1, - Succeeded: 1, - Failed: 0, - Active: 0, - Suspended: 0, - }, - { - Name: constants.ModelInitializer, - Ready: 1, - Succeeded: 1, - Failed: 0, - Active: 0, - Suspended: 0, - }, - { - Name: constants.Node, - Ready: 2, - Succeeded: 0, - Failed: 0, - Active: 2, - Suspended: 0, + wantStatus: &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 2, + Succeeded: 0, + Failed: 0, + Active: 2, + Suspended: 0, + }, }, }, }, @@ -1659,14 +1623,16 @@ func TestJobsStatusPlugins(t *testing.T) { }, }). Obj(), - wantStatuses: []trainer.JobStatus{ - { - Name: constants.Node, - Ready: 0, - Succeeded: 0, - Failed: 1, - Active: 0, - Suspended: 0, + wantStatus: &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, }, }, }, @@ -1675,7 +1641,7 @@ func TestJobsStatusPlugins(t *testing.T) { jobset.Name: jobset.New, fakeJobsStatusPluginName: newFakeJobsStatusPlugin, }, - wantError: errorTooManyJobsStatusPlugin, + wantError: errorTooManyTrainJobStatusPlugin, }, } for name, tc := range cases { @@ -1696,13 +1662,13 @@ func TestJobsStatusPlugins(t *testing.T) { return } - gotStatuses, gotErr := fwk.RunJobsStatusPlugins(ctx, tc.trainJob) + gotStatus, gotErr := fwk.RunTrainJobStatusPlugin(ctx, tc.trainJob) if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - if diff := cmp.Diff(tc.wantStatuses, gotStatuses); len(diff) != 0 { - t.Errorf("Unexpected jobs status (-want,+got):\n%s", diff) + if diff := cmp.Diff(tc.wantStatus, gotStatus); len(diff) != 0 { + t.Errorf("Unexpected TrainJob status (-want,+got):\n%s", diff) } }) } diff --git a/pkg/runtime/framework/interface.go b/pkg/runtime/framework/interface.go index b6ab61712a..703379c1f1 100644 --- a/pkg/runtime/framework/interface.go +++ b/pkg/runtime/framework/interface.go @@ -19,7 +19,6 @@ package framework import ( "context" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -61,12 +60,7 @@ type ComponentBuilderPlugin interface { Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) } -type TerminalConditionPlugin interface { +type TrainJobStatusPlugin interface { Plugin - TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) -} - -type JobsStatusPlugin interface { - Plugin - JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) + Status(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) } diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index d737efc603..c550eec27a 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -62,8 +62,7 @@ type JobSet struct { var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.PodNetworkPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) -var _ framework.TerminalConditionPlugin = (*JobSet)(nil) -var _ framework.JobsStatusPlugin = (*JobSet)(nil) +var _ framework.TrainJobStatusPlugin = (*JobSet)(nil) var _ framework.CustomValidationPlugin = (*JobSet)(nil) const Name = constants.JobSetKind @@ -290,27 +289,22 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine return []any{jobSet}, nil } -func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { +func (j *JobSet) Status(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { jobSet := &jobsetv1alpha2.JobSet{} if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil { return nil, err } + status := trainJob.Status.DeepCopy() + if completed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetCompleted)); completed != nil && completed.Status == metav1.ConditionTrue { completed.Type = trainer.TrainJobComplete - return completed, nil + meta.SetStatusCondition(&status.Conditions, *completed) } if failed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetFailed)); failed != nil && failed.Status == metav1.ConditionTrue { failed.Type = trainer.TrainJobFailed - return failed, nil + meta.SetStatusCondition(&status.Conditions, *failed) } - return nil, nil -} -func (j *JobSet) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) { - jobSet := &jobsetv1alpha2.JobSet{} - if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil { - return nil, err - } var statuses []trainer.JobStatus for _, status := range jobSet.Status.ReplicatedJobsStatus { statuses = append(statuses, trainer.JobStatus{ @@ -322,5 +316,7 @@ func (j *JobSet) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([] Suspended: status.Suspended, }) } - return statuses, nil + status.JobsStatus = statuses + + return status, nil } diff --git a/pkg/runtime/interface.go b/pkg/runtime/interface.go index 6150a8c821..3e831b8fa6 100644 --- a/pkg/runtime/interface.go +++ b/pkg/runtime/interface.go @@ -19,7 +19,6 @@ package runtime import ( "context" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -38,8 +37,7 @@ type Runtime interface { NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) RuntimeInfo(trainJob *trainer.TrainJob, runtimeTemplateSpec any, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) (*Info, error) - TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) - JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) + TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) }