Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions pkg/controller/trainjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
ctx = ctrl.LoggerInto(ctx, log)
log.V(2).Info("Reconciling TrainJob")
if isTrainJobFinished(&trainJob) {
log.V(5).Info("TrainJob has already been finished")
return ctrl.Result{}, nil
}

var err error
// Keep track of the origin TrainJob status
Expand Down Expand Up @@ -138,8 +134,9 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
}

setSuspendedCondition(&trainJob)
if terminalCondErr := setTerminalCondition(ctx, runtime, &trainJob); terminalCondErr != nil {
err = errors.Join(err, terminalCondErr)

if statusErr := setTrainJobStatus(ctx, runtime, &trainJob); statusErr != nil {
err = errors.Join(err, statusErr)
}

if !equality.Semantic.DeepEqual(&trainJob.Status, originStatus) {
Expand Down Expand Up @@ -256,22 +253,17 @@ func removeFailedCondition(trainJob *trainer.TrainJob) {
meta.RemoveStatusCondition(&trainJob.Status.Conditions, trainer.TrainJobFailed)
}

func setTerminalCondition(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
terminalCond, err := runtime.TerminalCondition(ctx, trainJob)
func setTrainJobStatus(ctx context.Context, runtime jobruntimes.Runtime, trainJob *trainer.TrainJob) error {
status, err := runtime.TrainJobStatus(ctx, trainJob)
if err != nil {
return err
}
if terminalCond != nil {
meta.SetStatusCondition(&trainJob.Status.Conditions, *terminalCond)
if status != nil {
trainJob.Status = *status
}
return nil
}

func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
return meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobComplete) ||
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
}

func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
b := builder.TypedControllerManagedBy[reconcile.Request](mgr).
Named("trainjob_controller").
Expand Down
5 changes: 2 additions & 3 deletions pkg/runtime/core/clustertrainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"errors"
"fmt"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -71,8 +70,8 @@ func (r *ClusterTrainingRuntime) RuntimeInfo(
return r.TrainingRuntime.RuntimeInfo(trainJob, runtimeTemplateSpec, mlPolicy, podGroupPolicy)
}

func (r *ClusterTrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
return r.TrainingRuntime.TerminalCondition(ctx, trainJob)
func (r *ClusterTrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
return r.TrainingRuntime.TrainJobStatus(ctx, trainJob)
}

func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand Down
4 changes: 2 additions & 2 deletions pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ func syncPodSets(info *runtime.Info) {
}
}

func (r *TrainingRuntime) TerminalCondition(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
return r.framework.RunTerminalConditionPlugins(ctx, trainJob)
func (r *TrainingRuntime) TrainJobStatus(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
return r.framework.RunTrainJobStatusPlugin(ctx, trainJob)
}

func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
Expand Down
22 changes: 10 additions & 12 deletions pkg/runtime/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"context"
"errors"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
Expand All @@ -32,7 +31,7 @@ import (
index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer"
)

var errorTooManyTerminalConditionPlugin = errors.New("too many TerminalCondition plugins are registered")
var errorTooManyTrainJobStatusPlugin = errors.New("too many TrainJobStatus plugins are registered")

type Framework struct {
registry fwkplugins.Registry
Expand All @@ -43,7 +42,7 @@ type Framework struct {
watchExtensionPlugins []framework.WatchExtensionPlugin
podNetworkPlugins []framework.PodNetworkPlugin
componentBuilderPlugins []framework.ComponentBuilderPlugin
terminalConditionPlugins []framework.TerminalConditionPlugin
trainJobStatusPlugin framework.TrainJobStatusPlugin
}

func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer client.FieldIndexer) (*Framework, error) {
Expand Down Expand Up @@ -79,8 +78,11 @@ func New(ctx context.Context, c client.Client, r fwkplugins.Registry, indexer cl
if p, ok := plugin.(framework.ComponentBuilderPlugin); ok {
f.componentBuilderPlugins = append(f.componentBuilderPlugins, p)
}
if p, ok := plugin.(framework.TerminalConditionPlugin); ok {
f.terminalConditionPlugins = append(f.terminalConditionPlugins, p)
if p, ok := plugin.(framework.TrainJobStatusPlugin); ok {
if f.trainJobStatusPlugin != nil {
return nil, errorTooManyTrainJobStatusPlugin
}
f.trainJobStatusPlugin = p
}
}
f.plugins = plugins
Expand Down Expand Up @@ -141,13 +143,9 @@ func (f *Framework) RunComponentBuilderPlugins(ctx context.Context, info *runtim
return objs, nil
}

func (f *Framework) RunTerminalConditionPlugins(ctx context.Context, trainJob *trainer.TrainJob) (*metav1.Condition, error) {
// TODO (tenzen-y): Once we provide the Configuration API, we should validate which plugin should have terminalCondition execution points.
if len(f.terminalConditionPlugins) > 1 {
Comment on lines -145 to -146
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason why we can remove the verification of the number of plugins?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed here, users won't be able to register more than one plugin for the terminal condition: #2802 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the check has been moved to initialization, rather than being performed at (first) execution.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. The new way makes more sense.
Thank you!

return nil, errorTooManyTerminalConditionPlugin
}
if len(f.terminalConditionPlugins) != 0 {
return f.terminalConditionPlugins[0].TerminalCondition(ctx, trainJob)
func (f *Framework) RunTrainJobStatusPlugin(ctx context.Context, trainJob *trainer.TrainJob) (*trainer.TrainJobStatus, error) {
if f.trainJobStatusPlugin != nil {
return f.trainJobStatusPlugin.Status(ctx, trainJob)
}
return nil, nil
}
Expand Down
Loading
Loading