
Support custom Docker images for MPI training runtimes #29

@jskswamy

Description


What happened?

The SDK currently has a hardcoded mapping of Docker image names to trainer configurations in ALL_TRAINERS. When using custom Docker images that aren't in this predefined mapping, the SDK falls back to DEFAULT_TRAINER, which uses torchrun instead of mpirun, causing MPI training jobs to fail.

Root Cause:
In sdk/python/kubeflow/trainer/utils/utils.py:

```python
# Extract image name from the container image to get appropriate Trainer.
image_name = trainer_container.image.split(":")[0]
trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER)
```

The ALL_TRAINERS dictionary in types.py only includes predefined image names like:

  • ghcr.io/kubeflow/trainer/deepspeed-runtime
  • ghcr.io/kubeflow/trainer/mlx-runtime
  • pytorch/pytorch

When using a custom image such as custom-registry.example.com/my-deepspeed-runtime, the SDK falls back to DEFAULT_TRAINER, which uses the torchrun entrypoint instead of mpirun.
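To make the failure mode concrete, here is a minimal, self-contained illustration of the lookup. ALL_TRAINERS, DEFAULT_TRAINER, and the string trainer values are simplified placeholders, not the real SDK definitions:

```python
# Illustration only: simplified stand-ins for types.ALL_TRAINERS and
# types.DEFAULT_TRAINER; the real entries hold trainer configurations,
# not strings.
ALL_TRAINERS = {
    "ghcr.io/kubeflow/trainer/deepspeed-runtime": "mpirun trainer",
    "ghcr.io/kubeflow/trainer/mlx-runtime": "mpirun trainer",
    "pytorch/pytorch": "torchrun trainer",
}
DEFAULT_TRAINER = "torchrun trainer"


def resolve_trainer(image: str) -> str:
    # Mirrors the SDK logic: strip the tag, then do an exact-match lookup.
    image_name = image.split(":")[0]
    return ALL_TRAINERS.get(image_name, DEFAULT_TRAINER)


# A custom registry image misses the exact match and silently falls back
# to the torchrun default, so the mpirun entrypoint is never selected.
print(resolve_trainer("custom-registry.example.com/my-deepspeed-runtime:v1"))
# -> torchrun trainer
```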

Current Behavior:

  • Custom images are not recognized in the predefined mapping
  • SDK falls back to DEFAULT_TRAINER with torchrun entrypoint
  • MPI training commands are dropped from the generated TrainJob
  • Distributed training fails because mpirun is missing

Additional Context:
There's already a TODO comment in the code suggesting this should be addressed:

```python
# TODO (andreyvelich): We should allow user to overrides the default image names.
```

Proposed Solutions:

  1. Add configuration support: Allow users to specify custom image-to-trainer mappings (a rough sketch of options 1 and 2 is included below)
  2. Image pattern matching: Support wildcard/regex patterns for image names
  3. Runtime-based detection: Use runtime metadata to determine trainer type instead of relying solely on image names
  4. Fallback logic: Improve fallback behavior for unrecognized images

This affects users who need to use custom Docker images for their training workloads, particularly in enterprise environments where standard images may not be available or suitable.
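
A minimal sketch of what options 1 and 2 could look like; register_trainer_pattern, CUSTOM_TRAINER_PATTERNS, and the string trainer values are hypothetical names used here for illustration only, not existing SDK APIs:

```python
import re

# Hypothetical extension point, not an existing SDK API: callers register
# extra image-to-trainer mappings as regex patterns.
CUSTOM_TRAINER_PATTERNS: list[tuple[re.Pattern, str]] = []


def register_trainer_pattern(pattern: str, trainer: str) -> None:
    """Map any image whose name matches `pattern` to `trainer`."""
    CUSTOM_TRAINER_PATTERNS.append((re.compile(pattern), trainer))


def resolve_trainer(image: str, all_trainers: dict, default_trainer):
    image_name = image.split(":")[0]
    # 1. Exact match against the built-in mapping (current behavior).
    if image_name in all_trainers:
        return all_trainers[image_name]
    # 2. User-registered patterns (proposed options 1 and 2).
    for pattern, trainer in CUSTOM_TRAINER_PATTERNS:
        if pattern.search(image_name):
            return trainer
    # 3. Unchanged fallback for unrecognized images (option 4 could add a warning here).
    return default_trainer


# Usage: route any image whose name mentions "deepspeed" to an MPI trainer.
register_trainer_pattern(r"deepspeed", "mpirun trainer")
```

Exact matching would still win, so existing behavior is preserved; the user-supplied patterns only apply to images the built-in table does not recognize.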

What did you expect to happen?

  • Custom images should be properly detected for MPI training
  • The SDK should support configurable image-to-trainer mappings
  • MPI entrypoints should be preserved for custom DeepSpeed images
  • The generated TrainJob should include the correct mpirun command regardless of image name

Environment

Kubeflow Python SDK version:
```bash
$ pip show kubeflow

Name: kubeflow
Version: 0.1.0
Location: /Users/subramk/source/github.com/jskswamy/101/kubeflow-trainer-example/.venv/lib/python3.12/site-packages
Requires: kubeflow-trainer-api, kubernetes, pydantic
Required-by:
```

Installed via the dependency specifier `kubeflow @ git+https://github.com/kubeflow/sdk.git@main#subdirectory=python`.

Impacted by this bug?

Give it a 👍. We prioritize the issues with the most 👍.
