diff --git a/pkg/controller/trainjob_controller.go b/pkg/controller/trainjob_controller.go index 78cb8c28df..c4c2f61622 100644 --- a/pkg/controller/trainjob_controller.go +++ b/pkg/controller/trainjob_controller.go @@ -104,10 +104,6 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob)) ctx = ctrl.LoggerInto(ctx, log) log.V(2).Info("Reconciling TrainJob") - if isTrainJobFinished(&trainJob) { - log.V(5).Info("TrainJob has already been finished") - return ctrl.Result{}, nil - } var err error // Keep track of the origin TrainJob status @@ -138,8 +134,9 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c } setSuspendedCondition(&trainJob) - if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil { - err = errors.Join(err, terminalCondErr) + + if statusErr := setTrainJobStatus(ctx, runtime, &trainJob); statusErr != nil { + err = errors.Join(err, statusErr) } if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) { @@ -256,22 +253,17 @@ func removeFailedCondition(trainJob *trainer.TrainJob) { meta.RemoveStatusCondition(&trainJob.Status.Conditions, trainer.TrainJobFailed) } -func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { - terminalCond, err := runtime.TerminalCondition(ctx, trainJob) +func setTrainJobStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error { + status, err := runtime.TrainJobStatus(ctx, trainJob) if err != nil { return err } - if terminalCond != nil { - meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond) + if status != nil { + trainJob.Status = *status } return nil } -func isTrainJobFinished(trainJob *trainer.TrainJob) bool { - return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) || - 
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed) -} - func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error { b := builder.TypedControllerManagedBy[reconcile.Request](mgr). Named("trainjob_controller"). diff --git a/pkg/runtime/core/clustertrainingruntime.go b/pkg/runtime/core/clustertrainingruntime.go index 16214203e5..dbcb18034e 100644 --- a/pkg/runtime/core/clustertrainingruntime.go +++ b/pkg/runtime/core/clustertrainingruntime.go @@ -21,7 +21,6 @@ import ( "errors" "fmt" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" @@ -71,8 +70,8 @@ func (r *ClusterTrainingRuntime) RuntimeInfo( return r.TrainingRuntime.RuntimeInfo(trainJob, runtimeTemplateSpec, mlPolicy, podGroupPolicy) } -func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - return r.TrainingRuntime.TerminalCondition(ctx, trainJob) +func (r *ClusterTrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return r.TrainingRuntime.TrainJobStatus(ctx, trainJob) } func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index c25de2936d..521e929cf6 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -246,8 +246,8 @@ func syncPodSets(info *runtime.Info) { } } -func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - return r.framework.RunTerminalConditionPlugins(ctx, trainJob) +func (r *TrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return r.framework.RunTrainJobStatusPlugin(ctx, 
trainJob) } func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder { diff --git a/pkg/runtime/framework/core/framework.go b/pkg/runtime/framework/core/framework.go index a7567615a5..dc654d3aea 100644 --- a/pkg/runtime/framework/core/framework.go +++ b/pkg/runtime/framework/core/framework.go @@ -20,7 +20,6 @@ import ( "context" "errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -32,7 +31,7 @@ import ( index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer" ) -var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered") +var errorTooManyTrainJobStatusPlugin = errors.New("too many TrainJobStatus plugins are registered") type Framework struct { registry fwkplugins.Registry @@ -43,7 +42,7 @@ type Framework struct { watchExtensionPlugins []framework.WatchExtensionPlugin podNetworkPlugins []framework.PodNetworkPlugin componentBuilderPlugins []framework.ComponentBuilderPlugin - terminalConditionPlugins []framework.TerminalConditionPlugin + trainJobStatusPlugin framework.TrainJobStatusPlugin } func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) { @@ -79,8 +78,11 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl if p, ok := plugin.(framework.ComponentBuilderPlugin); ok { f.componentBuilderPlugins = append(f.componentBuilderPlugins, p) } - if p, ok := plugin.(framework.TerminalConditionPlugin); ok { - f.terminalConditionPlugins = append(f.terminalConditionPlugins, p) + if p, ok := plugin.(framework.TrainJobStatusPlugin); ok { + if f.trainJobStatusPlugin != nil { + return nil, errorTooManyTrainJobStatusPlugin + } + f.trainJobStatusPlugin = p } } f.plugins = plugins @@ -141,13 +143,9 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, 
info *runtim return objs, nil } -func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { - // TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points. - if len(f.terminalConditionPlugins) > 1 { - return nil, errorTooManyTerminalConditionPlugin - } - if len(f.terminalConditionPlugins) != 0 { - return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob) +func (f *Framework) RunTrainJobStatusPlugin(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + if f.trainJobStatusPlugin != nil { + return f.trainJobStatusPlugin.Status(ctx, trainJob) } return nil, nil } diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index 04262483f8..f25f0798b3 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -19,6 +19,7 @@ package core import ( "context" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -111,9 +112,7 @@ func TestNew(t *testing.T) { &jobset.JobSet{}, &mpi.MPI{}, }, - terminalConditionPlugins: []framework.TerminalConditionPlugin{ - &jobset.JobSet{}, - }, + trainJobStatusPlugin: &jobset.JobSet{}, }, }, "indexer key for trainingRuntime and runtimeClass is an empty": { @@ -136,7 +135,7 @@ func TestNew(t *testing.T) { cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, volcano.Volcano{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}), cmpopts.IgnoreFields(coscheduling.CoScheduling{}, "client"), cmpopts.IgnoreFields(volcano.Volcano{}, "client"), - cmpopts.IgnoreFields(jobset.JobSet{}, "client"), + cmpopts.IgnoreFields(jobset.JobSet{}, "client", "restMapper", "scheme", "logger"), cmpopts.IgnoreTypes(apiruntime.Scheme{}, meta.DefaultRESTMapper{}, fwkplugins.Registry{}), cmpopts.SortMaps(func(a, b string) bool { return a < b 
}), cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }), @@ -1440,28 +1439,33 @@ func TestWatchExtensionPlugins(t *testing.T) { } } -type fakeTerminalConditionPlugin struct{} +type fakeTrainJobStatusPlugin struct{} -var _ framework.TerminalConditionPlugin = (*fakeTerminalConditionPlugin)(nil) +var _ framework.TrainJobStatusPlugin = (*fakeTrainJobStatusPlugin)(nil) -func newFakeTerminalConditionPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { - return &fakeTerminalConditionPlugin{}, nil +func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) { + return &fakeTrainJobStatusPlugin{}, nil } -const fakeTerminalConditionPluginName = "fake" +const fakeJobsStatusPluginName = "fake-train-job-status" -func (f fakeTerminalConditionPlugin) Name() string { return fakeTerminalConditionPluginName } -func (f fakeTerminalConditionPlugin) TerminalCondition(context.Context, *trainer.TrainJob) (*metav1.Condition, error) { - return nil, nil +func (f fakeTrainJobStatusPlugin) Name() string { return fakeJobsStatusPluginName } +func (f fakeTrainJobStatusPlugin) Status(context.Context, *trainer.TrainJob) (*trainer.TrainJobStatus, error) { + return &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + {Name: "fake-job", Ready: 1, Succeeded: 0, Failed: 0, Active: 1, Suspended: 0}, + }, + }, nil } -func TestTerminalConditionPlugins(t *testing.T) { +func TestTrainJobStatusPlugins(t *testing.T) { + lastTransitionTime := metav1.Time{Time: time.Now()}.Rfc3339Copy() cases := map[string]struct { - registry fwkplugins.Registry - trainJob *trainer.TrainJob - jobSet *jobsetv1alpha2.JobSet - wantCondition *metav1.Condition - wantError error + registry fwkplugins.Registry + trainJob *trainer.TrainJob + jobSet *jobsetv1alpha2.JobSet + wantStatus *trainer.TrainJobStatus + wantError error }{ "jobSet has not been finalized, yet": { registry: fwkplugins.NewRegistry(), @@ -1475,6 
+1479,7 @@ func TestTerminalConditionPlugins(t *testing.T) { Status: metav1.ConditionFalse, }). Obj(), + wantStatus: &trainer.TrainJobStatus{}, }, "succeeded to obtain completed terminal condition": { registry: fwkplugins.NewRegistry(), @@ -1482,17 +1487,23 @@ func TestTerminalConditionPlugins(t *testing.T) { Obj(), jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). Conditions(metav1.Condition{ - Type: string(jobsetv1alpha2.JobSetCompleted), - Reason: jobsetconsts.AllJobsCompletedReason, - Message: jobsetconsts.AllJobsCompletedMessage, - Status: metav1.ConditionTrue, + Type: string(jobsetv1alpha2.JobSetCompleted), + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.AllJobsCompletedMessage, + Reason: jobsetconsts.AllJobsCompletedReason, + Status: metav1.ConditionTrue, }). Obj(), - wantCondition: &metav1.Condition{ - Type: trainer.TrainJobComplete, - Reason: jobsetconsts.AllJobsCompletedReason, - Message: jobsetconsts.AllJobsCompletedMessage, - Status: metav1.ConditionTrue, + wantStatus: &trainer.TrainJobStatus{ + Conditions: []metav1.Condition{ + { + Type: trainer.TrainJobComplete, + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.AllJobsCompletedMessage, + Reason: jobsetconsts.AllJobsCompletedReason, + Status: metav1.ConditionTrue, + }, + }, }, }, "succeeded to obtain failed terminal condition": { @@ -1501,25 +1512,136 @@ func TestTerminalConditionPlugins(t *testing.T) { Obj(), jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). Conditions(metav1.Condition{ - Type: string(jobsetv1alpha2.JobSetFailed), - Reason: jobsetconsts.FailedJobsReason, - Message: jobsetconsts.FailedJobsMessage, - Status: metav1.ConditionTrue, + Type: string(jobsetv1alpha2.JobSetFailed), + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.FailedJobsMessage, + Reason: jobsetconsts.FailedJobsReason, + Status: metav1.ConditionTrue, }). 
Obj(), - wantCondition: &metav1.Condition{ - Type: trainer.TrainJobFailed, - Reason: jobsetconsts.FailedJobsReason, - Message: jobsetconsts.FailedJobsMessage, - Status: metav1.ConditionTrue, + wantStatus: &trainer.TrainJobStatus{ + Conditions: []metav1.Condition{ + { + Type: trainer.TrainJobFailed, + LastTransitionTime: lastTransitionTime, + Message: jobsetconsts.FailedJobsMessage, + Reason: jobsetconsts.FailedJobsReason, + Status: metav1.ConditionTrue, + }, + }, }, }, - "failed to obtain any terminal condition due to multiple terminalCondition plugin": { + "failed to obtain TrainJob status due to multiple trainJobStatus plugin": { registry: fwkplugins.Registry{ - jobset.Name: jobset.New, - fakeTerminalConditionPluginName: newFakeTerminalConditionPlugin, + jobset.Name: jobset.New, + fakeJobsStatusPluginName: newFakeJobsStatusPlugin, }, - wantError: errorTooManyTerminalConditionPlugin, + wantError: errorTooManyTrainJobStatusPlugin, + }, + "JobSet with empty replicated jobs status": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + wantStatus: &trainer.TrainJobStatus{}, + }, + "succeeded to obtain JobsStatus from JobSet with multiple replicated jobs": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 2, + Succeeded: 0, + Failed: 0, + Active: 2, + Suspended: 0, + }, + }). 
+ Obj(), + wantStatus: &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 1, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 2, + Succeeded: 0, + Failed: 0, + Active: 2, + Suspended: 0, + }, + }, + }, + }, + "succeeded to obtain JobsStatus from JobSet with failed job": { + registry: fwkplugins.NewRegistry(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing"). + Obj(), + jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing"). + ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + }). + Obj(), + wantStatus: &trainer.TrainJobStatus{ + JobsStatus: []trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + }, + }, + }, + "failed to obtain JobsStatus due to multiple JobsStatusPlugins": { + registry: fwkplugins.Registry{ + jobset.Name: jobset.New, + fakeJobsStatusPluginName: newFakeJobsStatusPlugin, + }, + wantError: errorTooManyTrainJobStatusPlugin, }, } for name, tc := range cases { @@ -1534,16 +1656,19 @@ func TestTerminalConditionPlugins(t *testing.T) { fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder)) if err != nil { - t.Fatal(err) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error (-want,+got):\n%s", diff) + } + return } - gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob) + gotStatus, gotErr := fwk.RunTrainJobStatusPlugin(ctx, tc.trainJob) if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 { t.Errorf("Unexpected error (-want,+got):\n%s", diff) } - if diff := cmp.Diff(tc.wantCondition, 
gotCond); len(diff) != 0 { - t.Errorf("Unexpected terminal condition (-want,+got):\n%s", diff) + if diff := cmp.Diff(tc.wantStatus, gotStatus); len(diff) != 0 { + t.Errorf("Unexpected TrainJob status (-want,+got):\n%s", diff) } }) } diff --git a/pkg/runtime/framework/interface.go b/pkg/runtime/framework/interface.go index 676548ebe7..703379c1f1 100644 --- a/pkg/runtime/framework/interface.go +++ b/pkg/runtime/framework/interface.go @@ -19,7 +19,6 @@ package framework import ( "context" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -61,7 +60,7 @@ type ComponentBuilderPlugin interface { Build(ctx context.Context, info *runtime.Info, trainJob *trainer.TrainJob) ([]any, error) } -type TerminalConditionPlugin interface { +type TrainJobStatusPlugin interface { Plugin - TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) + Status(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) } diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index e829188fd8..c550eec27a 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -62,7 +62,7 @@ type JobSet struct { var _ framework.WatchExtensionPlugin = (*JobSet)(nil) var _ framework.PodNetworkPlugin = (*JobSet)(nil) var _ framework.ComponentBuilderPlugin = (*JobSet)(nil) -var _ framework.TerminalConditionPlugin = (*JobSet)(nil) +var _ framework.TrainJobStatusPlugin = (*JobSet)(nil) var _ framework.CustomValidationPlugin = (*JobSet)(nil) const Name = constants.JobSetKind @@ -289,18 +289,34 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine return []any{jobSet}, nil } -func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) { +func (j *JobSet) Status(ctx 
context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) { jobSet := &jobsetv1alpha2.JobSet{} if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil { return nil, err } + status := trainJob.Status.DeepCopy() + if completed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetCompleted)); completed != nil && completed.Status == metav1.ConditionTrue { completed.Type = trainer.TrainJobComplete - return completed, nil + meta.SetStatusCondition(&status.Conditions, *completed) } if failed := meta.FindStatusCondition(jobSet.Status.Conditions, string(jobsetv1alpha2.JobSetFailed)); failed != nil && failed.Status == metav1.ConditionTrue { failed.Type = trainer.TrainJobFailed - return failed, nil + meta.SetStatusCondition(&status.Conditions, *failed) + } + + var statuses []trainer.JobStatus + for _, status := range jobSet.Status.ReplicatedJobsStatus { + statuses = append(statuses, trainer.JobStatus{ + Name: status.Name, + Ready: status.Ready, + Succeeded: status.Succeeded, + Failed: status.Failed, + Active: status.Active, + Suspended: status.Suspended, + }) } - return nil, nil + status.JobsStatus = statuses + + return status, nil } diff --git a/pkg/runtime/interface.go b/pkg/runtime/interface.go index 6f49bace31..3e831b8fa6 100644 --- a/pkg/runtime/interface.go +++ b/pkg/runtime/interface.go @@ -19,7 +19,6 @@ package runtime import ( "context" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/validation/field" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" @@ -38,7 +37,7 @@ type Runtime interface { NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error) RuntimeInfo(trainJob *trainer.TrainJob, runtimeTemplateSpec any, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) (*Info, error) - TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) + 
TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) EventHandlerRegistrars() []ReconcilerBuilder ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) } diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index d497ecc37d..c4c3a552f9 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -503,6 +503,11 @@ func (j *JobSetWrapper) DependsOn(rJobName string, dependsOn ...jobsetv1alpha2.D return j } +func (j *JobSetWrapper) ReplicatedJobsStatuses(statuses []jobsetv1alpha2.ReplicatedJobStatus) *JobSetWrapper { + j.Status.ReplicatedJobsStatus = statuses + return j +} + func (j *JobSetWrapper) Obj() *jobsetv1alpha2.JobSet { return &j.JobSet } diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index c2fce0b209..012d6a6ab2 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -9,6 +9,7 @@ import ( jobsetconsts "sigs.k8s.io/jobset/pkg/constants" trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/constants" testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing" "github.com/kubeflow/trainer/v2/test/util" ) @@ -16,7 +17,6 @@ import ( const ( torchRuntime = "torch-distributed" deepSpeedRuntime = "deepspeed-distributed" - mlxRuntime = "mlx-distributed" ) var _ = ginkgo.Describe("TrainJob e2e", func() { @@ -57,8 +57,26 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) }) - // Wait for TrainJob to be in Succeeded status. 
- ginkgo.By("Wait for TrainJob to be in Succeeded status", func() { + // Wait for jobs to become active + ginkgo.By("Wait for TrainJob jobs to become active", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + }, util.SortJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + + // Wait for TrainJob to be in Succeeded status with all jobs succeeded. + ginkgo.By("Wait for TrainJob to be in Succeeded status with all jobs succeeded", func() { gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) @@ -70,6 +88,16 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) @@ -87,6 +115,32 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) }) + // Wait for jobs to become active + ginkgo.By("Wait for TrainJob jobs to become active", func() { + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainJob), gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Launcher, + Ready: 0, + 
Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + }, util.SortJobsStatus)) + }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) + }) + // Wait for TrainJob to be in Succeeded status. ginkgo.By("Wait for TrainJob to be in Succeeded status", func() { gomega.Eventually(func(g gomega.Gomega) { @@ -100,6 +154,24 @@ var _ = ginkgo.Describe("TrainJob e2e", func() { Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, + }, util.SortJobsStatus)) }, util.TimeoutE2E, util.Interval).Should(gomega.Succeed()) }) }) diff --git a/test/integration/controller/trainjob_controller_test.go b/test/integration/controller/trainjob_controller_test.go index 9be21beb47..62dc36dbcd 100644 --- a/test/integration/controller/trainjob_controller_test.go +++ b/test/integration/controller/trainjob_controller_test.go @@ -412,7 +412,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if TrainJob has Suspended and Created conditions") + ginkgo.By("Checking if TrainJob has Suspended=True condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -426,7 +426,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the 
TrainJob has Resumed and Created conditions after unsuspended") + ginkgo.By("Checking if the TrainJob has Suspended=False [Resumed] condition after unsuspended") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -443,7 +443,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Completed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with successful completion") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -453,10 +453,37 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.AllJobsCompletedMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Completed conditions") + ginkgo.By("Checking if the TranJob has Suspended and Complete conditions as well as Succeeded JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -474,6 +501,32 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: 
jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -499,7 +552,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Waiting for TrainJob Created=True and Suspended=False condition") + ginkgo.By("Waiting for TrainJob Suspended=False condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -513,7 +566,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Failed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with failed jobs") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -523,10 +576,36 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.FailedJobsMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + 
Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Failed conditions") + ginkgo.By("Checking if the TranJob has Suspended and Failed conditions as well as failed JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -544,6 +623,185 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { Message: jobsetconsts.FailedJobsMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + })) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + }) + + ginkgo.It("Should synchronize JobsStatus from JobSet ReplicatedJobsStatus", func() { + ginkgo.By("Creating TrainingRuntime and suspended TrainJob") + gomega.Expect(k8sClient.Create(ctx, trainingRuntime)).Should(gomega.Succeed()) + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(trainingRuntime), trainingRuntime)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + gomega.Expect(k8sClient.Create(ctx, trainJob)).Should(gomega.Succeed()) + + ginkgo.By("Checking if JobSet and PodGroup are created") + gomega.Eventually(func(g gomega.Gomega) { + g.Expect(k8sClient.Get(ctx, trainJobKey, 
&jobsetv1alpha2.JobSet{})).Should(gomega.Succeed()) + g.Expect(k8sClient.Get(ctx, trainJobKey, &schedulerpluginsv1alpha1.PodGroup{})).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Unsuspending the TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + gotTrainJob.Spec.Suspend = ptr.To(false) + g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet ReplicatedJobsStatus to simulate running jobs") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 1, + Succeeded: 0, + Failed: 0, + Active: 1, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Verifying JobsStatus synchronization in TrainJob") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 1, + Succeeded: 0, + 
Failed: 0, + Active: 1, + Suspended: 0, + }, + })) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Updating JobSet ReplicatedJobsStatus to simulate some failed jobs") + gomega.Eventually(func(g gomega.Gomega) { + jobSet := &jobsetv1alpha2.JobSet{} + g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + } + g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) + }, util.Timeout, util.Interval).Should(gomega.Succeed()) + + ginkgo.By("Verifying updated JobsStatus reflects failed jobs") + gomega.Eventually(func(g gomega.Gomega) { + gotTrainJob := &trainer.TrainJob{} + g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -889,7 +1147,7 @@ alpha-node-0-1.alpha slots=8 g.Expect(k8sClient.Get(ctx, secKey, &corev1.Secret{})).Should(gomega.Succeed()) }) - ginkgo.By("Checking if TrainJob has Suspended and Created conditions") + ginkgo.By("Checking if TrainJob has Suspended=True condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} 
g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -903,7 +1161,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TrainJob has Resumed and Created conditions after unsuspended") + ginkgo.By("Checking if the TrainJob has Suspended=False [Resumed] condition after unsuspended") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -920,7 +1178,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Completed") + ginkgo.By("Updating the JobSet conditions and ReplicatedJobsStatus with successful completion") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -930,10 +1188,44 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.AllJobsCompletedMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = []jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Completed conditions") + ginkgo.By("Checking if the TrainJob has Suspended=False and Complete=True conditions as well as succeeded JobsStatus")
gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -951,6 +1243,40 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.AllJobsCompletedMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) @@ -977,7 +1303,7 @@ alpha-node-0-1.alpha slots=8 g.Expect(k8sClient.Update(ctx, gotTrainJob)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Waiting for TrainJob Created=True and Suspended=False condition") + ginkgo.By("Waiting for TrainJob Suspended=False condition") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -991,7 +1317,7 @@ alpha-node-0-1.alpha slots=8 }, util.IgnoreConditions)) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Updating the JobSet condition with Failed") + ginkgo.By("Updating the JobSet Failed=True condition and ReplicatedJobsStatus with failed jobs") gomega.Eventually(func(g gomega.Gomega) { jobSet := &jobsetv1alpha2.JobSet{} g.Expect(k8sClient.Get(ctx, trainJobKey, jobSet)).Should(gomega.Succeed()) @@ -1001,10 +1327,44 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.FailedJobsMessage, Status: metav1.ConditionTrue, }) + jobSet.Status.ReplicatedJobsStatus = 
[]jobsetv1alpha2.ReplicatedJobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, + } g.Expect(k8sClient.Status().Update(ctx, jobSet)).Should(gomega.Succeed()) }, util.Timeout, util.Interval).Should(gomega.Succeed()) - ginkgo.By("Checking if the TranJob has Resumed, Created, and Failed conditions") + ginkgo.By("Checking if the TrainJob has Suspended=False [Resumed] and Failed=True conditions as well as failed JobsStatus") gomega.Eventually(func(g gomega.Gomega) { gotTrainJob := &trainer.TrainJob{} g.Expect(k8sClient.Get(ctx, trainJobKey, gotTrainJob)).Should(gomega.Succeed()) @@ -1022,6 +1382,40 @@ alpha-node-0-1.alpha slots=8 Message: jobsetconsts.FailedJobsMessage, }, }, util.IgnoreConditions)) + g.Expect(gotTrainJob.Status.JobsStatus).Should(gomega.BeComparableTo([]trainer.JobStatus{ + { + Name: constants.DatasetInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.ModelInitializer, + Ready: 0, + Succeeded: 1, + Failed: 0, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Launcher, + Ready: 0, + Succeeded: 0, + Failed: 1, + Active: 0, + Suspended: 0, + }, + { + Name: constants.Node, + Ready: 0, + Succeeded: 0, + Failed: 0, + Active: 0, + Suspended: 0, + }, + })) }, util.Timeout, util.Interval).Should(gomega.Succeed()) }) }) diff --git a/test/util/constants.go b/test/util/constants.go index da7d144fdb..860fa42c04 100644 --- a/test/util/constants.go +++ b/test/util/constants.go @@ -22,6 +22,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" metav1
"k8s.io/apimachinery/pkg/apis/meta/v1" + + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" ) const ( @@ -39,4 +41,7 @@ var ( IgnoreConditions = cmp.Options{ cmpopts.IgnoreFields(metav1.Condition{}, "LastTransitionTime", "ObservedGeneration"), } + SortJobsStatus = cmp.Options{ + cmpopts.SortSlices(func(a, b trainer.JobStatus) bool { return a.Name < b.Name }), + } )