Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
if isTrainJobFinished(&trainJob) {
log.V(5).Info("TrainJob has already been finished")
return ctrl.Result{}, nil
}

var err error
// Keep track of the origin TrainJob status
Expand Down Expand Up @@ -142,6 +138,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
err = errors.Join(err, terminalCondErr)
}

if jobsStatusErr := setJobsStatus(ctx, runtime, &trainJob); jobsStatusErr != nil {
err = errors.Join(err, jobsStatusErr)
}

if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) {
return ctrl.Result{}, errors.Join(err, r.client.Status().Update(ctx, &trainJob))
}
Expand Down Expand Up @@ -267,9 +267,13 @@ func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trai
return nil
}

func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) ||
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
func setJobsStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
statuses, err := runtime.JobsStatus(ctx, trainJob)
if err != nil {
return err
}
trainJob.Status.JobsStatus = statuses
return nil
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
Expand Down
4 changes: 4 additions & 0 deletions pkg/runtime/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
return r.TrainingRuntime.JobsStatus(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
return nil
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *train
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
}

func (r *TrainingRuntime) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
return r.framework.RunJobsStatusPlugins(ctx, trainJob)
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
var builders []runtime.ReconcilerBuilder
for _, ex := range r.framework.WatchExtensionPlugins() {
Expand Down
32 changes: 24 additions & 8 deletions pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ import (
index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer"
)

var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
var (
errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
errorTooManyJobsStatusPlugin = errors.New("too many JobsStatus plugins are registered")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we can only register one plugin for TerminalCondition and JobsStatus, do we need these errors ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's actually the error that's returned at init time in case more than one plugin is being registered.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it won't be possible since we initialize the Framework here:

f := &Framework{
registry: r,
}

So this condition will never be executed:

if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
if f.terminalConditionPlugin != nil {
return nil, errorTooManyTerminalConditionPlugin
}

Am I missing something ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right and that's where the errors are returned.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@astefanutti Can you explain the flow when f.terminalConditionPlugin != nil when we execute framework's New() function ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tenzen-y I am wondering if there is any need to distinguish plugins that we define in the Registry and plugins that are part of those "registered" plugins (e.g. TerminalCondition plugin is part of JobSet plugin).

I could not catch the reason why you want to do that.
I think the current @astefanutti 's implementations are a straightforward way.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tenzen-y I was asking what does "Plugin" mean in the Extension Framework from your point of view ?
And how are we going to allow users to extend the framework with their own plugins ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it. The trainer provides the 2 level abstraction layers which is Runtime Framework and Plugin.
But, sometimes, the plugin is only for a specific Runtime Framework like the JobSet plugin.

I guess that we might want to have mechanisms to represent the tie between the Runtime Framework and the Plugin, but it would be better to avoid adding any validation relationship to allow them to reuse the implemented Plugin wherever the Runtime Framework.

Anyway, we can consider that as an another enhancement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good.

)

type Framework struct {
registry fwkplugins.Registry
Expand All @@ -43,7 +46,8 @@ type Framework struct {
watchExtensionPlugins []framework.WatchExtensionPlugin
podNetworkPlugins []framework.PodNetworkPlugin
componentBuilderPlugins []framework.ComponentBuilderPlugin
terminalConditionPlugins []framework.TerminalConditionPlugin
terminalConditionPlugin framework.TerminalConditionPlugin
jobsStatusPlugin framework.JobsStatusPlugin
}

func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
Expand Down Expand Up @@ -80,7 +84,16 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
}
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
if f.terminalConditionPlugin != nil {
return nil, errorTooManyTerminalConditionPlugin
}
f.terminalConditionPlugin = p
}
if p, ok := plugin.(framework.JobsStatusPlugin); ok {
if f.jobsStatusPlugin != nil {
return nil, errorTooManyJobsStatusPlugin
}
f.jobsStatusPlugin = p
}
}
f.plugins = plugins
Expand Down Expand Up @@ -142,12 +155,15 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
}

func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
if len(f.terminalConditionPlugins) > 1 {
Comment on lines -145 to -146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we can remove the verification of the number of plugins?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed here, users won't be able to register more than one plugin for terminal condition: #2802 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the check is now moved at initialization, rather than at (first) execution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. The new way makes more sense.
Thank you!

return nil, errorTooManyTerminalConditionPlugin
if f.terminalConditionPlugin != nil {
return f.terminalConditionPlugin.TerminalCondition(ctx, trainJob)
}
if len(f.terminalConditionPlugins) != 0 {
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
return nil, nil
}

func (f *Framework) RunJobsStatusPlugins(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
if f.jobsStatusPlugin != nil {
return f.jobsStatusPlugin.JobsStatus(ctx, trainJob)
}
return nil, nil
}
Expand Down
169 changes: 164 additions & 5 deletions pkg/runtime/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,8 @@ func TestNew(t *testing.T) {
&jobset.JobSet{},
&mpi.MPI{},
},
terminalConditionPlugins: []framework.TerminalConditionPlugin{
&jobset.JobSet{},
},
terminalConditionPlugin: &jobset.JobSet{},
jobsStatusPlugin: &jobset.JobSet{},
},
},
"indexer key for trainingRuntime and runtimeClass is an empty": {
Expand All @@ -136,7 +135,7 @@ func TestNew(t *testing.T) {
cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, volcano.Volcano{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}),
cmpopts.IgnoreFields(coscheduling.CoScheduling{}, "client"),
cmpopts.IgnoreFields(volcano.Volcano{}, "client"),
cmpopts.IgnoreFields(jobset.JobSet{}, "client"),
cmpopts.IgnoreFields(jobset.JobSet{}, "client", "restMapper", "scheme", "logger"),
cmpopts.IgnoreTypes(apiruntime.Scheme{}, meta.DefaultRESTMapper{}, fwkplugins.Registry{}),
cmpopts.SortMaps(func(a, b string) bool { return a < b }),
cmpopts.SortSlices(func(a, b framework.Plugin) bool { return a.Name() < b.Name() }),
Expand Down Expand Up @@ -1534,7 +1533,10 @@ func TestTerminalConditionPlugins(t *testing.T) {

fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder))
if err != nil {
t.Fatal(err)
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
return
}

gotCond, gotErr := fwk.RunTerminalConditionPlugins(ctx, tc.trainJob)
Expand All @@ -1549,6 +1551,163 @@ func TestTerminalConditionPlugins(t *testing.T) {
}
}

type fakeJobsStatusPlugin struct{}

var _ framework.JobsStatusPlugin = (*fakeJobsStatusPlugin)(nil)

func newFakeJobsStatusPlugin(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
return &fakeJobsStatusPlugin{}, nil
}

const fakeJobsStatusPluginName = "fake-jobs-status"

func (f fakeJobsStatusPlugin) Name() string { return fakeJobsStatusPluginName }
func (f fakeJobsStatusPlugin) JobsStatus(context.Context, *trainer.TrainJob) ([]trainer.JobStatus, error) {
return []trainer.JobStatus{
{Name: "fake-job", Ready: 1, Succeeded: 0, Failed: 0, Active: 1, Suspended: 0},
}, nil
}

func TestJobsStatusPlugins(t *testing.T) {
cases := map[string]struct {
registry fwkplugins.Registry
trainJob *trainer.TrainJob
jobSet *jobsetv1alpha2.JobSet
wantStatuses []trainer.JobStatus
wantError error
}{
"JobSet with empty replicated jobs status": {
registry: fwkplugins.NewRegistry(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing").
Obj(),
jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing").
Obj(),
wantStatuses: nil,
},
"succeeded to obtain JobsStatus from JobSet with multiple replicated jobs": {
registry: fwkplugins.NewRegistry(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing").
Obj(),
jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing").
ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{
{
Name: constants.DatasetInitializer,
Ready: 1,
Succeeded: 1,
Failed: 0,
Active: 0,
Suspended: 0,
},
{
Name: constants.ModelInitializer,
Ready: 1,
Succeeded: 1,
Failed: 0,
Active: 0,
Suspended: 0,
},
{
Name: constants.Node,
Ready: 2,
Succeeded: 0,
Failed: 0,
Active: 2,
Suspended: 0,
},
}).
Obj(),
wantStatuses: []trainer.JobStatus{
{
Name: constants.DatasetInitializer,
Ready: 1,
Succeeded: 1,
Failed: 0,
Active: 0,
Suspended: 0,
},
{
Name: constants.ModelInitializer,
Ready: 1,
Succeeded: 1,
Failed: 0,
Active: 0,
Suspended: 0,
},
{
Name: constants.Node,
Ready: 2,
Succeeded: 0,
Failed: 0,
Active: 2,
Suspended: 0,
},
},
},
"succeeded to obtain JobsStatus from JobSet with failed job": {
registry: fwkplugins.NewRegistry(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "testing").
Obj(),
jobSet: testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "testing").
ReplicatedJobsStatuses([]jobsetv1alpha2.ReplicatedJobStatus{
{
Name: constants.Node,
Ready: 0,
Succeeded: 0,
Failed: 1,
Active: 0,
Suspended: 0,
},
}).
Obj(),
wantStatuses: []trainer.JobStatus{
{
Name: constants.Node,
Ready: 0,
Succeeded: 0,
Failed: 1,
Active: 0,
Suspended: 0,
},
},
},
"failed to obtain JobsStatus due to multiple JobsStatusPlugins": {
registry: fwkplugins.Registry{
jobset.Name: jobset.New,
fakeJobsStatusPluginName: newFakeJobsStatusPlugin,
},
wantError: errorTooManyJobsStatusPlugin,
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
clientBuilder := testingutil.NewClientBuilder()
if tc.jobSet != nil {
clientBuilder = clientBuilder.WithObjects(tc.jobSet)
}
c := clientBuilder.Build()

fwk, err := New(ctx, c, tc.registry, testingutil.AsIndex(clientBuilder))
if err != nil {
if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
return
}

gotStatuses, gotErr := fwk.RunJobsStatusPlugins(ctx, tc.trainJob)
if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}

if diff := cmp.Diff(tc.wantStatuses, gotStatuses); len(diff) != 0 {
t.Errorf("Unexpected jobs status (-want,+got):\n%s", diff)
}
})
}
}

func TestPodNetworkPlugins(t *testing.T) {
cases := map[string]struct {
registry fwkplugins.Registry
Expand Down
5 changes: 5 additions & 0 deletions pkg/runtime/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,8 @@ type TerminalConditionPlugin interface {
Plugin
TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error)
}

type JobsStatusPlugin interface {
Plugin
JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto related to consolidation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated with the two plugins together. PTAL.

20 changes: 20 additions & 0 deletions pkg/runtime/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
var _ framework.PodNetworkPlugin = (*JobSet)(nil)
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
var _ framework.TerminalConditionPlugin = (*JobSet)(nil)
var _ framework.JobsStatusPlugin = (*JobSet)(nil)
var _ framework.CustomValidationPlugin = (*JobSet)(nil)

const Name = constants.JobSetKind
Expand Down Expand Up @@ -304,3 +305,22 @@ func (j *JobSet) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJ
}
return nil, nil
}

func (j *JobSet) JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error) {
jobSet := &jobsetv1alpha2.JobSet{}
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), jobSet); err != nil {
return nil, err
}
var statuses []trainer.JobStatus
for _, status := range jobSet.Status.ReplicatedJobsStatus {
statuses = append(statuses, trainer.JobStatus{
Name: status.Name,
Ready: status.Ready,
Succeeded: status.Succeeded,
Failed: status.Failed,
Active: status.Active,
Suspended: status.Suspended,
Comment on lines +311 to +316
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering whether we will introduce a breaking change after this issue: kubernetes-sigs/jobset#723 ?
I proposed that we don't use Succeeded as Job condition in rJob.
Any thoughts @astefanutti @tenzen-y @kannon92 @kubeflow/kubeflow-trainer-team ?

Copy link
Contributor Author

@astefanutti astefanutti Sep 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this would be better to be consistent across projects, thought I'd suggest not to hold this PR as it'll be a breaking change anyway and have it addressed separately so we can clearly "release-note" it as a breaking change. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed offline in Slack, we will use Succeeded condition for now.
We will change this API to Complete in a future version, once JobSet migrates its APIs.

})
}
return statuses, nil
}
1 change: 1 addition & 0 deletions pkg/runtime/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Runtime interface {
NewObjects(ctx context.Context, trainJob *trainer.TrainJob) ([]any, error)
RuntimeInfo(trainJob *trainer.TrainJob, runtimeTemplateSpec any, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) (*Info, error)
TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error)
JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error)
JobsStatus(ctx context.Context, trainJob *trainer.TrainJob) ([]trainer.JobStatus, error)
UnderlyingJobState(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, []trainer.JobStatus, error)

I think that TerminalCondition and JobsStatus responsibilities are almost same which is just propagate the Job state from underlying one to TrainJob.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense for me to consolidate those two plugins together.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've folded the two plugins together. PTAL.

EventHandlerRegistrars() []ReconcilerBuilder
ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList)
}
5 changes: 5 additions & 0 deletions pkg/util/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ func (j *JobSetWrapper) DependsOn(rJobName string, dependsOn ...jobsetv1alpha2.D
return j
}

func (j *JobSetWrapper) ReplicatedJobsStatuses(statuses []jobsetv1alpha2.ReplicatedJobStatus) *JobSetWrapper {
j.Status.ReplicatedJobsStatus = statuses
return j
}

func (j *JobSetWrapper) Obj() *jobsetv1alpha2.JobSet {
return &j.JobSet
}
Expand Down
Loading
Loading