Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ spec:
Defaults to OpenMPI.
enum:
- OpenMPI
- ""
type: string
numProcPerNode:
default: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ spec:
Defaults to OpenMPI.
enum:
- OpenMPI
- ""
type: string
numProcPerNode:
default: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ spec:
Defaults to OpenMPI.
enum:
- OpenMPI
- ""
type: string
numProcPerNode:
default: 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ spec:
Defaults to OpenMPI.
enum:
- OpenMPI
- ""
type: string
numProcPerNode:
default: 1
Expand Down
4 changes: 2 additions & 2 deletions pkg/apis/trainer/v1alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ type MPIMLPolicySource struct {
// mpiImplementation is the name of the MPI implementation to create the appropriate hostfile.
// Defaults to OpenMPI.
// +kubebuilder:default=OpenMPI
// +kubebuilder:validation:Enum=OpenMPI
// +kubebuilder:validation:Enum=OpenMPI;""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to allow an empty input, ""?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to keep the linter, I think we have to; otherwise it forces us to change the field to a non-pointer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. Thank you for describing that.

// +optional
MPIImplementation MPIImplementation `json:"mpiImplementation,omitempty"`
MPIImplementation *MPIImplementation `json:"mpiImplementation,omitempty"`

// sshAuthMountPath is the directory where SSH keys are mounted.
// Defaults to /root/.ssh.
Expand Down
5 changes: 5 additions & 0 deletions pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
WithName(constants.MPIHostfileVolumeName).
WithMountPath(constants.MPIHostfileDir),
)
switch info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
case trainer.MPIImplementationOpenMPI:
apply.UpsertEnvVars(
&info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env,
Expand Down Expand Up @@ -304,7 +304,7 @@ func (m *MPI) buildHostFileConfigMap(info *runtime.Info, trainJob *trainer.Train
if !isNode(runLauncherAsNode, ps) {
continue
}
switch info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
case trainer.MPIImplementationOpenMPI:
for e := range ps.Endpoints {
hostFile.WriteString(fmt.Sprintf("%s slots=%d\n", e, slots))
Expand Down
2 changes: 1 addition & 1 deletion pkg/util/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ func (m *MLPolicySourceWrapper) MPIPolicy(numProcPerNode *int32, MPImplementatio
m.MPI = &trainer.MPIMLPolicySource{}
}
m.MPI.NumProcPerNode = numProcPerNode
m.MPI.MPIImplementation = MPImplementation
m.MPI.MPIImplementation = &MPImplementation
m.MPI.SSHAuthMountPath = sshAuthMountPath
m.MPI.RunLauncherAsNode = runLauncherAsNode
return m
Expand Down
2 changes: 1 addition & 1 deletion test/integration/webhooks/trainingruntime_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ var _ = ginkgo.Describe("TrainingRuntime marker validations and defaulting", gin
WithMLPolicy(
testingutil.MakeMLPolicyWrapper().
WithMLPolicySource(*testingutil.MakeMLPolicySourceWrapper().
MPIPolicy(ptr.To[int32](1), "", ptr.To("/usr/dir"), ptr.To(false)).
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/usr/dir"), ptr.To(false)).
Obj(),
).
Obj(),
Expand Down
Loading