Description
What happened?
The SDK currently has a hardcoded mapping of Docker image names to trainer configurations in ALL_TRAINERS. When using custom Docker images that aren't in this predefined mapping, the SDK falls back to DEFAULT_TRAINER, which uses torchrun instead of mpirun, causing MPI training jobs to fail.
Root Cause:
In sdk/python/kubeflow/trainer/utils/utils.py:
```python
# Extract image name from the container image to get appropriate Trainer.
image_name = trainer_container.image.split(":")[0]
trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER)
```

The ALL_TRAINERS dictionary in types.py only includes predefined image names like:

- ghcr.io/kubeflow/trainer/deepspeed-runtime
- ghcr.io/kubeflow/trainer/mlx-runtime
- pytorch/pytorch
When using custom images like custom-registry.example.com/my-deepspeed-runtime, the SDK falls back to DEFAULT_TRAINER, which uses the torchrun entrypoint instead of mpirun.
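For illustration, here is a minimal, self-contained sketch of that lookup (the trainer values are stand-ins for the real configurations in types.py) showing how a custom registry image silently lands on the torchrun fallback:

```python
# Stand-ins for types.ALL_TRAINERS / types.DEFAULT_TRAINER; the real values are
# trainer configurations, simplified here to just their entrypoints.
ALL_TRAINERS = {
    "ghcr.io/kubeflow/trainer/deepspeed-runtime": "mpirun",
    "ghcr.io/kubeflow/trainer/mlx-runtime": "mpirun",
    "pytorch/pytorch": "torchrun",
}
DEFAULT_TRAINER = "torchrun"


def resolve_entrypoint(image: str) -> str:
    # Same logic as utils.py: strip the tag, then do an exact-match lookup.
    image_name = image.split(":")[0]
    return ALL_TRAINERS.get(image_name, DEFAULT_TRAINER)


print(resolve_entrypoint("ghcr.io/kubeflow/trainer/deepspeed-runtime:v2"))        # mpirun
print(resolve_entrypoint("custom-registry.example.com/my-deepspeed-runtime:v2"))  # torchrun (fallback)
```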
Current Behavior:
- Custom images are not recognized in the predefined mapping
- SDK falls back to DEFAULT_TRAINER with the torchrun entrypoint
- MPI training commands are dropped from the generated TrainJob
- Distributed training fails because mpirun is missing
Additional Context:
There's already a TODO comment in the code suggesting this should be addressed:
```python
# TODO (andreyvelich): We should allow user to overrides the default image names.
```

Proposed Solutions:
- Add configuration support: Allow users to specify custom image-to-trainer mappings
- Image pattern matching: Support wildcard/regex patterns for image names (see the sketch after this list)
- Runtime-based detection: Use runtime metadata to determine trainer type instead of relying solely on image names
- Fallback logic: Improve fallback behavior for unrecognized images
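To make the pattern-matching proposal concrete, here is a rough sketch (not existing SDK code; the mapping name and entrypoint values are hypothetical) of how glob patterns over image names could resolve custom registries to the MPI entrypoint:

```python
import fnmatch

# Hypothetical pattern-keyed mapping; values simplified to entrypoint names.
TRAINER_PATTERNS = {
    "ghcr.io/kubeflow/trainer/deepspeed-runtime": "mpirun",
    "*/my-deepspeed-runtime": "mpirun",  # any registry hosting this image
    "pytorch/pytorch": "torchrun",
}
DEFAULT_ENTRYPOINT = "torchrun"


def resolve_entrypoint(image: str) -> str:
    image_name = image.split(":")[0]
    for pattern, entrypoint in TRAINER_PATTERNS.items():
        if fnmatch.fnmatch(image_name, pattern):
            return entrypoint
    return DEFAULT_ENTRYPOINT


assert resolve_entrypoint("custom-registry.example.com/my-deepspeed-runtime:v2") == "mpirun"
```

The same shape would also cover the first proposal if users could extend the pattern table themselves instead of it being hardcoded.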
This affects users who need to use custom Docker images for their training workloads, particularly in enterprise environments where standard images may not be available or suitable.
What did you expect to happen?
- Custom images should be properly detected for MPI training
- The SDK should support configurable image-to-trainer mappings (one possible shape is sketched after this list)
- MPI entrypoints should be preserved for custom DeepSpeed images
- The generated TrainJob should include the correct mpirun command regardless of image name
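One possible shape for the configurable mapping mentioned above; this is purely illustrative and none of these names exist in the SDK today:

```python
from dataclasses import dataclass, field


# Hypothetical user-facing config object (not part of the SDK).
@dataclass
class TrainerMappingConfig:
    # Explicit overrides: image name with the tag stripped -> entrypoint to use.
    overrides: dict = field(default_factory=dict)
    default_entrypoint: str = "torchrun"

    def entrypoint_for(self, image: str) -> str:
        image_name = image.split(":")[0]
        return self.overrides.get(image_name, self.default_entrypoint)


config = TrainerMappingConfig(
    overrides={"custom-registry.example.com/my-deepspeed-runtime": "mpirun"}
)
assert config.entrypoint_for("custom-registry.example.com/my-deepspeed-runtime:v2") == "mpirun"
```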
Environment
Kubeflow Python SDK version:
```bash
$ pip show kubeflow
Name: kubeflow
Version: 0.1.0
Location: /Users/subramk/source/github.com/jskswamy/101/kubeflow-trainer-example/.venv/lib/python3.12/site-packages
Requires: kubeflow-trainer-api, kubernetes, pydantic
Required-by:
```

Installed from:

```
kubeflow @ git+https://github.com/kubeflow/sdk.git@main#subdirectory=python
```
Impacted by this bug?
Give it a 👍. We prioritize the issues with the most 👍.