31 changes: 31 additions & 0 deletions README.md
@@ -71,6 +71,37 @@ TrainerClient().wait_for_job_status(job_id)
print("\n".join(TrainerClient().get_job_logs(name=job_id)))
```

### Run a custom command with CommandTrainer

CommandTrainer runs an arbitrary command under the runtime's launcher (torchrun/mpirun/python) while preserving package installation, environment variables, and per-node resource settings.

```python
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.types import types

client = TrainerClient()
rt = client.get_runtime("torch") # or "mpi", "plainml"

trainer = types.CommandTrainer(
    command=["python"],
    args=["train.py", "--epochs", "2"],
    packages_to_install=["numpy"],
    pip_index_urls=["https://pypi.org/simple"],
    num_nodes=2,
    resources_per_node={"gpu": "1"},
    env={"FOO": "bar"},
)

job = client.train(runtime=rt, trainer=trainer)
print("Job:", job)
```

> **@andreyvelich** (Member) commented on `trainer = types.CommandTrainer(`, Sep 11, 2025:
>
> This is an interesting idea and somewhat aligned with what we discussed today at the Kubeflow SDK call, with KubernetesTrainer proposed by @szaher: https://youtu.be/mv8GoWdefck?t=832
>
> Since we distinguish the runtime trainers between CustomTrainer and BuiltinTrainer, I am wondering if we want to introduce a CustomTrainerContainer() type, which gives users control to configure the image, container, and args instead of passing a training function.
>
> Would that be helpful for integration between KFP and Trainer?
>
> Thoughts @kubeflow/kubeflow-sdk-team @mprahl @franciscojavierarceo @ederign @rudeigerc?

Notes:

- The launcher is runtime-aware (torch → torchrun, mpi → mpirun, plain → python); see the sketch after this list.
- Packages are installed before the command runs; MPI installs use `--user`.
- Ensure your training script exists in the container (baked into the image, or mounted via a ConfigMap, volume, or init container).
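
For instance, omitting `command` falls back to those launcher defaults (a minimal sketch reusing the `client` and `rt` objects from the example above; the flag values are illustrative):

```python
# Hypothetical: no "command" given, so the torch runtime's launcher (torchrun)
# is used and the args are passed through to it as-is.
trainer = types.CommandTrainer(
    args=["train.py", "--epochs", "2"],
    num_nodes=2,
)
job = client.train(runtime=rt, trainer=trainer)
```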

## Supported Kubeflow Projects

| Project | Status | Version Support | Description |
2 changes: 2 additions & 0 deletions kubeflow/trainer/__init__.py
@@ -27,6 +27,7 @@
from kubeflow.trainer.types.types import (
BuiltinTrainer,
CommandTrainer,
CustomTrainer,
DataFormat,
DataType,
HuggingFaceDatasetInitializer,
@@ -43,6 +44,7 @@
__all__ = [
"BuiltinTrainer",
"CustomTrainer",
"CommandTrainer",
"DataFormat",
"DATASET_PATH",
"DataType",
10 changes: 7 additions & 3 deletions kubeflow/trainer/api/trainer_client.py
@@ -95,21 +95,25 @@ def train(
self,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
trainer: Optional[
Union[types.CustomTrainer, types.BuiltinTrainer, types.CommandTrainer]
] = None,
) -> str:
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:

- CustomTrainer: Runs training with a user-defined function that fully encapsulates the
training process.
- BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring
only parameter configuration.
- CommandTrainer: Executes an arbitrary command inside the runtime's launcher while
preserving environment and resource settings.

Args:
runtime: Optional reference to one of the existing runtimes. Defaults to the
torch-distributed runtime if not provided.
initializer: Optional configuration for the dataset and model initializers.
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
the TrainJob will use the runtime's default values.
trainer: Optional configuration for a CustomTrainer, BuiltinTrainer, or CommandTrainer.
If not specified, the TrainJob will use the runtime's default values.

Returns:
The unique name of the TrainJob that has been generated.
4 changes: 3 additions & 1 deletion kubeflow/trainer/backends/base.py
@@ -38,7 +38,9 @@ def train(
self,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
trainer: Optional[
Union[types.CustomTrainer, types.CommandTrainer, types.BuiltinTrainer]
] = None,
) -> str:
raise NotImplementedError()

10 changes: 8 additions & 2 deletions kubeflow/trainer/backends/kubernetes/backend.py
@@ -182,7 +182,7 @@ def train(
self,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
trainer: Optional[
    Union[types.CustomTrainer, types.CommandTrainer, types.BuiltinTrainer]
] = None,
) -> str:
if runtime is None:
runtime = self.get_runtime(constants.TORCH_RUNTIME)
@@ -201,6 +201,12 @@
raise ValueError(f"CustomTrainer can't be used with {runtime} runtime")
trainer_crd = utils.get_trainer_crd_from_custom_trainer(runtime, trainer)

# If users choose to use a command trainer to run a custom command.
elif isinstance(trainer, types.CommandTrainer):
if runtime.trainer.trainer_type != types.TrainerType.CUSTOM_TRAINER:
raise ValueError(f"CommandTrainer can't be used with {runtime} runtime")
trainer_crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer)

# If users choose to use a builtin trainer for post-training.
elif isinstance(trainer, types.BuiltinTrainer):
if runtime.trainer.trainer_type != types.TrainerType.BUILTIN_TRAINER:
@@ -212,7 +218,7 @@
else:
raise ValueError(
f"The trainer type {type(trainer)} is not supported. "
"Please use CustomTrainer or BuiltinTrainer."
"Please use CustomTrainer, CommandTrainer or BuiltinTrainer."
)

train_job = models.TrainerV1alpha1TrainJob(
23 changes: 20 additions & 3 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
@@ -19,17 +19,17 @@
It tests KubernetesBackend's behavior across job listing, resource creation etc
"""

from dataclasses import asdict
import datetime
import multiprocessing
import random
import string
import uuid
from dataclasses import asdict
from typing import Optional
from unittest.mock import Mock, patch
import uuid

from kubeflow_trainer_api import models
import pytest
from kubeflow_trainer_api import models

from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
@@ -821,6 +821,23 @@ def test_train(kubernetes_backend, test_case):
print("test execution complete")


def test_train_routes_command_trainer(kubernetes_backend):
"""Ensure CommandTrainer is routed to its CRD builder in backend.train."""
runtime = kubernetes_backend.get_runtime(TORCH_RUNTIME)
cmd_trainer = types.CommandTrainer(command=["python"], args=["train.py"])

fake_crd = models.TrainerV1alpha1Trainer()

with patch(
"kubeflow.trainer.utils.utils.get_trainer_crd_from_command_trainer",
return_value=fake_crd,
) as mocked_builder:
job_name = kubernetes_backend.train(runtime=runtime, trainer=cmd_trainer)

mocked_builder.assert_called_once()
assert isinstance(job_name, str) and len(job_name) > 0


@pytest.mark.parametrize(
"test_case",
[
29 changes: 29 additions & 0 deletions kubeflow/trainer/types/types.py
@@ -240,3 +240,32 @@ class Initializer:

dataset: Optional[HuggingFaceDatasetInitializer] = None
model: Optional[HuggingFaceModelInitializer] = None


@dataclass
class CommandTrainer:
"""Command Trainer configuration.

If "command" is set, it becomes the container entrypoint and "args" are passed as container args.
If "command" is not set, defaults are chosen by runtime framework (e.g., torch→torchrun,
mpi→mpirun, torch-tune→tune run, otherwise python), and "args" are passed as-is.

Args:
command (Optional[list[str]]): The command to execute (e.g., ["python"]).
args (Optional[list[str]]): Positional arguments for the command.
packages_to_install (Optional[list[str]]): Python packages to install.
pip_index_urls (list[str]): Index and extra index URLs; the first entry is used as the index URL, the rest as extra index URLs.
pip_extra_args (Optional[list[str]]): Extra pip flags (e.g., ["--no-cache-dir"]).
num_nodes (Optional[int]): Number of nodes for training.
resources_per_node (Optional[dict]): Compute resources per node (e.g., {"gpu": "1"}).
env (Optional[dict[str, str]]): Environment variables.
"""

command: Optional[list[str]] = None
args: Optional[list[str]] = None
packages_to_install: Optional[list[str]] = None
pip_index_urls: list[str] = field(
    default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)
)
pip_extra_args: Optional[list[str]] = None
num_nodes: Optional[int] = None
resources_per_node: Optional[dict] = None
env: Optional[dict[str, str]] = None
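
A quick sketch of the two modes the docstring above describes (illustrative values only; field semantics as documented):

```python
from kubeflow.trainer.types.types import CommandTrainer

# Explicit entrypoint: "command" becomes the container entrypoint,
# and "args" are passed as container args.
explicit = CommandTrainer(command=["python"], args=["train.py", "--epochs", "2"])

# Launcher-default mode: with no "command", the runtime framework picks the
# launcher (e.g., torch -> torchrun) and "args" are passed through as-is.
launcher_default = CommandTrainer(args=["train.py"], num_nodes=2)
```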
12 changes: 12 additions & 0 deletions kubeflow/trainer/types/types_test.py
@@ -0,0 +1,12 @@
from kubeflow.trainer.types import types


class TestCommandTrainerType:
def test_command_trainer_dataclass_minimal(self):
trainer = types.CommandTrainer(command=["python"], args=["train.py"])

assert trainer.command == ["python"]
assert trainer.args == ["train.py"]
assert trainer.pip_index_urls and isinstance(trainer.pip_index_urls, list)
assert trainer.packages_to_install is None
assert trainer.env is None
89 changes: 82 additions & 7 deletions kubeflow/trainer/utils/utils.py
@@ -259,6 +259,7 @@ def get_script_for_python_packages(
packages_to_install: list[str],
pip_index_urls: list[str],
is_mpi: bool,
pip_extra_args: Optional[list[str]] = None,
) -> str:
"""
Get init script to install Python packages from the given pip index URLs.
@@ -281,14 +282,16 @@
"""
)

script_for_python_packages = (
header_script
+ "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
+ "--no-warn-script-location {} {}\n".format(
" ".join(options),
packages_str,
)
extra_args = " ".join(pip_extra_args or [])
options_args = " ".join(options)

base_cmd = (
    "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
    f"--no-warn-script-location {options_args} {packages_str}"
)
if extra_args:
base_cmd = f"{base_cmd} {extra_args}"

script_for_python_packages = f"{header_script}{base_cmd}\n"

return script_for_python_packages
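
For reference, a self-contained sketch of the shell snippet this builder now emits once `pip_extra_args` are appended (the header script is omitted, and the option rendering shown here is an assumption):

```python
# Simplified stand-in for the real builder: the rendered options and package
# list are hard-coded; only the flag-appending behavior is illustrated.
options_args = "--index-url https://pypi.org/simple"
packages_str = "numpy"
extra_args = " ".join(["--no-cache-dir"])

base_cmd = (
    "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
    f"--no-warn-script-location {options_args} {packages_str}"
)
if extra_args:
    base_cmd = f"{base_cmd} {extra_args}"

print(base_cmd)
# PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
#   --no-warn-script-location --index-url https://pypi.org/simple numpy --no-cache-dir
```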

@@ -365,6 +368,78 @@ def get_command_using_train_func(
return command


def get_command_using_user_command(
runtime: types.Runtime,
command: list[str],
command_args: Optional[list[str]],
pip_index_urls: list[str],
packages_to_install: Optional[list[str]],
pip_extra_args: Optional[list[str]] = None,
) -> list[str]:
"""
Build a runtime-aware command to execute an arbitrary user command with args.
Preserves the runtime launcher (torchrun/mpirun/python) and prepends optional
pip installs using provided index URLs.
"""
if not runtime.trainer:
raise ValueError(f"Runtime must have a trainer: {runtime}")

base = list(runtime.trainer.command)
is_mpi = bool(base) and base[0] == "mpirun"

install = ""
if packages_to_install:
install = get_script_for_python_packages(
packages_to_install=packages_to_install,
pip_index_urls=pip_index_urls,
is_mpi=is_mpi,
pip_extra_args=pip_extra_args,
)

cmd_line = " ".join([*(command or []), *(((command_args) or []))])
final_script = "{}{}".format(install, cmd_line)

if not base:
return ["bash", "-c", final_script]

base[-1] = final_script
return base
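
To make the assembly concrete, here is a minimal, self-contained sketch of the slotting logic above (the real helper also prepends the pip-install script; the launcher command is an assumed example):

```python
def assemble(base: list[str], command: list[str], args: list[str]) -> list[str]:
    # Join the user command and args into one shell line.
    script = " ".join([*command, *args])
    if not base:
        # No runtime launcher: fall back to a bash-wrapped invocation.
        return ["bash", "-c", script]
    # Replace the launcher's trailing script slot with the user command line.
    return [*base[:-1], script]

# Assumed launcher whose last element is the script slot:
print(assemble(["bash", "-c", "<script>"], ["python"], ["train.py", "--epochs", "2"]))
# -> ['bash', '-c', 'python train.py --epochs 2']
```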


def get_trainer_crd_from_command_trainer(
runtime: types.Runtime,
trainer: types.CommandTrainer,
) -> models.TrainerV1alpha1Trainer:
"""
Build Trainer CRD for CommandTrainer, preserving env/resources and using the
runtime-aware user command assembly helper.
"""
trainer_crd = models.TrainerV1alpha1Trainer()

if trainer.num_nodes:
trainer_crd.num_nodes = trainer.num_nodes

if trainer.resources_per_node:
trainer_crd.resources_per_node = get_resources_per_node(trainer.resources_per_node)

# Delegate command assembly to the runtime-aware helper: the user command (plus any
# pip-install prefix) is slotted into the launcher's trailing script argument, or
# bash-wrapped when the runtime defines no launcher command.
trainer_crd.command = get_command_using_user_command(
runtime=runtime,
command=list(trainer.command or []),
command_args=trainer.args,
pip_index_urls=trainer.pip_index_urls,
packages_to_install=trainer.packages_to_install,
pip_extra_args=trainer.pip_extra_args,
)

if trainer.env:
trainer_crd.env = [
models.IoK8sApiCoreV1EnvVar(name=k, value=v) for k, v in trainer.env.items()
]

return trainer_crd


def get_trainer_crd_from_custom_trainer(
runtime: types.Runtime,
trainer: types.CustomTrainer,