Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/openapi-spec/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,9 @@ spec:
- message: ManagedBy value is immutable
rule: self == oldSelf
podSpecOverrides:
description: Custom overrides for the training runtime.
description: |-
Custom overrides for the training runtime.
When multiple overrides apply to the same targetJob, later entries in the slice override earlier field values.
items:
description: PodSpecOverride represents the custom overrides that
will be applied for the TrainJob's resources.
Expand Down
4 changes: 3 additions & 1 deletion manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,9 @@ spec:
- message: ManagedBy value is immutable
rule: self == oldSelf
podSpecOverrides:
description: Custom overrides for the training runtime.
description: |-
Custom overrides for the training runtime.
When multiple overrides apply to the same targetJob, later entries in the slice override earlier field values.
items:
description: PodSpecOverride represents the custom overrides that
will be applied for the TrainJob's resources.
Expand Down
1 change: 1 addition & 0 deletions pkg/apis/trainer/v1alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ type TrainJobSpec struct {
Annotations map[string]string `json:"annotations,omitempty"`

// Custom overrides for the training runtime.
// When multiple overrides apply to the same targetJob, later entries in the slice override earlier field values.
// +listType=atomic
PodSpecOverrides []PodSpecOverride `json:"podSpecOverrides,omitempty"`

Expand Down
2 changes: 1 addition & 1 deletion pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 75 additions & 0 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,81 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
},
},
"succeeded to build JobSet with TrainJob's PodSpecOverrides containing duplicate TargetJobs": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
PodSpecOverrides([]trainer.PodSpecOverride{
{
TargetJobs: []trainer.PodSpecOverrideTargetJob{{Name: constants.Node}},
NodeSelector: map[string]string{
"node.kubernetes.io/instance-type": "p5.48xlarge",
},
},
{
TargetJobs: []trainer.PodSpecOverrideTargetJob{{Name: constants.Node}},
ServiceAccountName: ptr.To("test-sa"),
},
}).
Obj(),
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
NodeSelector(constants.Node,
map[string]string{
"node.kubernetes.io/instance-type": "p5.48xlarge",
}).
ServiceAccountName(constants.Node, "test-sa").
Obj(),
},
},
"succeeded to build JobSet with TrainJob's PodSpecOverrides targeting the same job with different values": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
PodSpecOverrides([]trainer.PodSpecOverride{
{
TargetJobs: []trainer.PodSpecOverrideTargetJob{{Name: constants.Node}},
NodeSelector: map[string]string{
"node.kubernetes.io/instance-type": "p5.48xlarge",
},
},
{
TargetJobs: []trainer.PodSpecOverrideTargetJob{{Name: constants.Node}},
NodeSelector: map[string]string{
"node.kubernetes.io/instance-type": "p5en.48xlarge",
},
},
}).
Obj(),
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Node).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
NodeSelector(constants.Node,
map[string]string{
"node.kubernetes.io/instance-type": "p5en.48xlarge",
}).
Obj(),
},
},
"succeeded to build JobSet with dataset and model initializer from the TrainJob.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
Expand Down
9 changes: 0 additions & 9 deletions pkg/runtime/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,7 @@ func (j *JobSet) Validate(ctx context.Context, info *runtime.Info, oldObj, newOb
allErrs = append(allErrs, j.checkPodSpecOverridesImmutability(ctx, oldObj, newObj)...)

// TODO (andreyvelich): Validate Volumes, VolumeMounts, and Tolerations.
targetJobNames := sets.New[string]()
for _, podSpecOverride := range newObj.Spec.PodSpecOverrides {
// Validate that no target job name is repeated across all PodSpecOverrides of the TrainJob
for _, targetJob := range podSpecOverride.TargetJobs {
if targetJobNames.Has(targetJob.Name) {
allErrs = append(allErrs, field.Duplicate(podSpecOverridePath, targetJob.Name))
}
targetJobNames.Insert(targetJob.Name)
}

for _, targetJob := range podSpecOverride.TargetJobs {
containers, ok := rJobContainerNames[targetJob.Name]
if !ok {
Expand Down
6 changes: 3 additions & 3 deletions test/integration/webhooks/trainjob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord
Obj()
},
testingutil.BeInvalidError()),
ginkgo.Entry("Should fail in creating trainJob with podSpecOverrides have duplicated targetJob",
ginkgo.Entry("Should succeed to create trainJob with podSpecOverrides containing duplicate targetJob",
func() *trainer.TrainJob {
return testingutil.MakeTrainJobWrapper(ns.Name, "invalid-pod-spec-overrides").
return testingutil.MakeTrainJobWrapper(ns.Name, "duplicated-podspecoverrides-target-jobs").
RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "testing").
PodSpecOverrides([]trainer.PodSpecOverride{
{
Expand All @@ -367,7 +367,7 @@ var _ = ginkgo.Describe("TrainJob marker validations and defaulting", ginkgo.Ord
}).
Obj()
},
testingutil.BeForbiddenError()),
gomega.Succeed()),
)
ginkgo.DescribeTable("Defaulting TrainJob on creation", func(trainJob func() *trainer.TrainJob, wantTrainJob func() *trainer.TrainJob) {
created := trainJob()
Expand Down
Loading