diff --git a/README.md b/README.md index f5272161..3d136efd 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,37 @@ TrainerClient().wait_for_job_status(job_id) print("\n".join(TrainerClient().get_job_logs(name=job_id))) ``` +### Run a custom command with CommandTrainer + +CommandTrainer runs an arbitrary command inside the runtime’s launcher (torchrun/mpirun/python) while preserving package installation, env vars, and resources. + +```python +from kubeflow.trainer import TrainerClient +from kubeflow.trainer.types import types + +client = TrainerClient() +rt = client.get_runtime("torch") # or "mpi", "plainml" + +trainer = types.CommandTrainer( + command=["python"], + args=["train.py", "--epochs", "2"], + packages_to_install=["numpy"], + pip_index_urls=["https://pypi.org/simple"], + num_nodes=2, + resources_per_node={"gpu": "1"}, + env={"FOO": "bar"}, +) + +job = client.train(runtime=rt, trainer=trainer) +print("Job:", job) +``` + +Notes: + +- Launcher is runtime-aware (torch → torchrun, mpi → mpirun, plain → python). +- Packages are installed before the command; MPI installs use `--user`. +- Ensure your script exists in the container (image/ConfigMap/volume/init). + ## Supported Kubeflow Projects | Project | Status | Version Support | Description | diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py index 61a46b26..908d7ad1 100644 --- a/kubeflow/trainer/__init__.py +++ b/kubeflow/trainer/__init__.py @@ -27,6 +27,7 @@ from kubeflow.trainer.types.types import ( BuiltinTrainer, CustomTrainer, + CommandTrainer, DataFormat, DataType, HuggingFaceDatasetInitializer, @@ -43,6 +44,7 @@ __all__ = [ "BuiltinTrainer", "CustomTrainer", + "CommandTrainer", "DataFormat", "DATASET_PATH", "DataType", diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index 6b564c90..18df8a1e 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -95,7 +95,9 @@ def train( self, runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, - trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + trainer: Optional[ + Union[types.CustomTrainer, types.BuiltinTrainer, types.CommandTrainer] + ] = None, ) -> str: """Create a TrainJob. You can configure the TrainJob using one of these trainers: @@ -103,13 +105,15 @@ def train( training process. - BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring only parameter configuration. + - CommandTrainer: Executes an arbitrary command inside the runtime's launcher while + preserving environment and resource settings. Args: runtime: Optional reference to one of the existing runtimes. Defaults to the torch-distributed runtime if not provided. initializer: Optional configuration for the dataset and model initializers. - trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified, - the TrainJob will use the runtime's default values. + trainer: Optional configuration for a CustomTrainer, BuiltinTrainer, or CommandTrainer. + If not specified, the TrainJob will use the runtime's default values. Returns: The unique name of the TrainJob that has been generated. diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index 0316b7b6..220b6fb1 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -38,7 +38,9 @@ def train( self, runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, - trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + trainer: Optional[ + Union[types.CustomTrainer, types.CommandTrainer, types.BuiltinTrainer] + ] = None, ) -> str: raise NotImplementedError() diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 4310182b..eef64cc8 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -182,7 +182,7 @@ def train( self, runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, - trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + trainer: Optional[Union[types.CustomTrainer, types.CommandTrainer, types.BuiltinTrainer]] = None, ) -> str: if runtime is None: runtime = self.get_runtime(constants.TORCH_RUNTIME) @@ -201,6 +201,12 @@ def train( raise ValueError(f"CustomTrainer can't be used with {runtime} runtime") trainer_crd = utils.get_trainer_crd_from_custom_trainer(runtime, trainer) + # If users choose to use a command trainer to run custom command. + elif isinstance(trainer, types.CommandTrainer): + if runtime.trainer.trainer_type != types.TrainerType.CUSTOM_TRAINER: + raise ValueError(f"CommandTrainer can't be used with {runtime} runtime") + trainer_crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer) + # If users choose to use a builtin trainer for post-training. elif isinstance(trainer, types.BuiltinTrainer): if runtime.trainer.trainer_type != types.TrainerType.BUILTIN_TRAINER: @@ -212,7 +218,7 @@ def train( else: raise ValueError( f"The trainer type {type(trainer)} is not supported. " - "Please use CustomTrainer or BuiltinTrainer." + "Please use CustomTrainer, CommandTrainer or BuiltinTrainer." ) train_job = models.TrainerV1alpha1TrainJob( diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 700739a6..78bc14ea 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -19,17 +19,17 @@ It tests KubernetesBackend's behavior across job listing, resource creation etc """ -from dataclasses import asdict import datetime import multiprocessing import random import string +import uuid +from dataclasses import asdict from typing import Optional from unittest.mock import Mock, patch -import uuid -from kubeflow_trainer_api import models import pytest +from kubeflow_trainer_api import models from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig @@ -821,6 +821,23 @@ def test_train(kubernetes_backend, test_case): print("test execution complete") +def test_train_routes_command_trainer(kubernetes_backend): + """Ensure CommandTrainer is routed to its CRD builder in backend.train.""" + runtime = kubernetes_backend.get_runtime(TORCH_RUNTIME) + cmd_trainer = types.CommandTrainer(command=["python"], args=["train.py"]) + + fake_crd = models.TrainerV1alpha1Trainer() + + with patch( + "kubeflow.trainer.utils.utils.get_trainer_crd_from_command_trainer", + return_value=fake_crd, + ) as mocked_builder: + job_name = kubernetes_backend.train(runtime=runtime, trainer=cmd_trainer) + + mocked_builder.assert_called_once() + assert isinstance(job_name, str) and len(job_name) > 0 + + @pytest.mark.parametrize( "test_case", [ diff --git a/kubeflow/trainer/types/types.py b/kubeflow/trainer/types/types.py index 7485ba65..14a16823 100644 --- a/kubeflow/trainer/types/types.py +++ b/kubeflow/trainer/types/types.py @@ -240,3 +240,32 @@ class Initializer: dataset: Optional[HuggingFaceDatasetInitializer] = None model: Optional[HuggingFaceModelInitializer] = None + + +@dataclass +class CommandTrainer: + """Command Trainer configuration. + + If "command" is set, it becomes the container entrypoint and "args" are passed as container args. + If "command" is not set, defaults are chosen by runtime framework (e.g., torch→torchrun, + mpi→mpirun, torch-tune→tune run, otherwise python), and "args" are passed as-is. + + Args: + command (Optional[List[str]]): The command to execute (e.g., ["python"]). + args (Optional[List[str]]): Positional arguments for the command. + packages_to_install (Optional[List[str]]): Python packages to install. + pip_index_urls (List[str]): Index and extra index URLs; first is index-url. + pip_extra_args (Optional[List[str]]): Extra pip flags (e.g., ["--no-cache-dir"]). + num_nodes (Optional[int]): Number of nodes for training. + resources_per_node (Optional[Dict]): Resources per node. + env (Optional[Dict[str, str]]): Environment variables. + """ + + command: Optional[list[str]] = None + args: Optional[list[str]] = None + packages_to_install: Optional[list[str]] = None + pip_index_urls: list[str] = field(default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)) + pip_extra_args: Optional[list[str]] = None + num_nodes: Optional[int] = None + resources_per_node: Optional[dict] = None + env: Optional[dict[str, str]] = None diff --git a/kubeflow/trainer/types/types_test.py b/kubeflow/trainer/types/types_test.py new file mode 100644 index 00000000..c1d43be4 --- /dev/null +++ b/kubeflow/trainer/types/types_test.py @@ -0,0 +1,12 @@ +from kubeflow.trainer.types import types + + +class TestCommandTrainerType: + def test_command_trainer_dataclass_minimal(self): + trainer = types.CommandTrainer(command=["python"], args=["train.py"]) + + assert trainer.command == ["python"] + assert trainer.args == ["train.py"] + assert trainer.pip_index_urls and isinstance(trainer.pip_index_urls, list) + assert trainer.packages_to_install is None + assert trainer.env is None diff --git a/kubeflow/trainer/utils/utils.py b/kubeflow/trainer/utils/utils.py index 663109cb..e9d9aa50 100644 --- a/kubeflow/trainer/utils/utils.py +++ b/kubeflow/trainer/utils/utils.py @@ -259,6 +259,7 @@ def get_script_for_python_packages( packages_to_install: list[str], pip_index_urls: list[str], is_mpi: bool, + pip_extra_args: Optional[list[str]] = None, ) -> str: """ Get init script to install Python packages from the given pip index URLs. @@ -281,14 +282,16 @@ def get_script_for_python_packages( """ ) - script_for_python_packages = ( - header_script - + "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet " - + "--no-warn-script-location {} {}\n".format( - " ".join(options), - packages_str, - ) + extra_args = " ".join(pip_extra_args or []) + options_args = " ".join(options) + + base_cmd = ( + f"PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet --no-warn-script-location {options_args} {packages_str}" # noqa: E501 ) + if extra_args: + base_cmd = f"{base_cmd} {extra_args}" + + script_for_python_packages = f"{header_script}{base_cmd}\n" return script_for_python_packages @@ -365,6 +368,78 @@ def get_command_using_train_func( return command +def get_command_using_user_command( + runtime: types.Runtime, + command: list[str], + command_args: Optional[list[str]], + pip_index_urls: list[str], + packages_to_install: Optional[list[str]], + pip_extra_args: Optional[list[str]] = None, +) -> list[str]: + """ + Build a runtime-aware command to execute an arbitrary user command with args. + Preserves the runtime launcher (torchrun/mpirun/python) and prepends optional + pip installs using provided index URLs. + """ + if not runtime.trainer: + raise ValueError(f"Runtime must have a trainer: {runtime}") + + base = list(runtime.trainer.command) + is_mpi = base and base[0] == "mpirun" + + install = "" + if packages_to_install: + install = get_script_for_python_packages( + packages_to_install=packages_to_install, + pip_index_urls=pip_index_urls, + is_mpi=is_mpi, + pip_extra_args=pip_extra_args, + ) + + cmd_line = " ".join([*(command or []), *(((command_args) or []))]) + final_script = "{}{}".format(install, cmd_line) + + if not base: + return ["bash", "-c", final_script] + + base[-1] = final_script + return base + + +def get_trainer_crd_from_command_trainer( + runtime: types.Runtime, + trainer: types.CommandTrainer, +) -> models.TrainerV1alpha1Trainer: + """ + Build Trainer CRD for CommandTrainer, preserving env/resources and using the + runtime-aware user command assembly helper. + """ + trainer_crd = models.TrainerV1alpha1Trainer() + + if trainer.num_nodes: + trainer_crd.num_nodes = trainer.num_nodes + + if trainer.resources_per_node: + trainer_crd.resources_per_node = get_resources_per_node(trainer.resources_per_node) + + # Always produce a bash-wrapped command to ensure shell interpolation (e.g. ${VAR}) + # and to preserve runtime launcher behavior consistently. + trainer_crd.command = get_command_using_user_command( + runtime=runtime, + command=list(trainer.command or []), + command_args=trainer.args, + pip_index_urls=trainer.pip_index_urls, + packages_to_install=trainer.packages_to_install, + pip_extra_args=trainer.pip_extra_args, + ) + + if trainer.env: + trainer_crd.env = [ + models.IoK8sApiCoreV1EnvVar(name=k, value=v) for k, v in trainer.env.items() + ] + + return trainer_crd + def get_trainer_crd_from_custom_trainer( runtime: types.Runtime, trainer: types.CustomTrainer, diff --git a/kubeflow/trainer/utils/utils_test.py b/kubeflow/trainer/utils/utils_test.py index 5a1a821a..c220163c 100644 --- a/kubeflow/trainer/utils/utils_test.py +++ b/kubeflow/trainer/utils/utils_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import textwrap + import pytest from kubeflow.trainer.constants import constants @@ -255,3 +257,217 @@ def test_get_command_using_train_func(test_case: TestCase): except Exception as e: assert type(e) is test_case.expected_error print("test execution complete") + +def _build_plain_runtime() -> types.Runtime: + trainer = types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="plainml", + num_nodes=1, + ) + trainer.set_command(constants.DEFAULT_COMMAND) + return types.Runtime(name="test-runtime", trainer=trainer) + +def _build_mpi_runtime() -> types.Runtime: + trainer = types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="mpi", + num_nodes=2, + ) + trainer.set_command(constants.MPI_COMMAND) + return types.Runtime(name="mpi-runtime", trainer=trainer) + + +class TestGetCommandUsingUserCommand: + def test_plain_with_installs(self): + runtime = _build_plain_runtime() + command = ["python"] + command_args = ["train.py", "--epochs", "2"] + pip_index_urls = [ + "https://pypi.org/simple", + "https://private.repo.com/simple", + ] + packages_to_install = ["torch", "numpy"] + + result = utils.get_command_using_user_command( + runtime=runtime, + command=command, + command_args=command_args, + pip_index_urls=pip_index_urls, + packages_to_install=packages_to_install, + ) + + expected = textwrap.dedent( + """bash +-c + +if ! [ -x "$(command -v pip)" ]; then + python -m ensurepip || python -m ensurepip --user || apt-get install python-pip +fi + +PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet --no-warn-script-location --index-url https://pypi.org/simple --extra-index-url https://private.repo.com/simple torch numpy +python train.py --epochs 2""" + ) + assert "\n".join(result) == expected + + def test_plain_with_installs_and_pip_extra_args(self): + runtime = _build_plain_runtime() + command = ["python"] + command_args = ["train.py"] + pip_index_urls = [ + "https://pypi.org/simple", + ] + packages_to_install = ["torch"] + + result = utils.get_command_using_user_command( + runtime=runtime, + command=command, + command_args=command_args, + pip_index_urls=pip_index_urls, + packages_to_install=packages_to_install, + pip_extra_args=["--no-cache-dir", "--find-links", "/wheels"], + ) + + joined = "\n".join(result) + assert "--no-warn-script-location --index-url https://pypi.org/simple torch --no-cache-dir --find-links /wheels" in joined + + +class TestGetTrainerCRDFromCommandTrainer: + def test_plain_runtime_builds_crd_with_env_and_resources(self): + runtime = _build_plain_runtime() + trainer = types.CommandTrainer( + command=["python"], + args=["main.py", "--epochs", "3"], + packages_to_install=["numpy"], + pip_index_urls=["https://pypi.org/simple"], + num_nodes=2, + resources_per_node={"gpu": "1"}, + env={"FOO": "bar"}, + ) + + crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer) + expected_command = utils.get_command_using_user_command( + runtime=runtime, + command=trainer.command, + command_args=trainer.args, + pip_index_urls=trainer.pip_index_urls, + packages_to_install=trainer.packages_to_install, + ) + + assert crd.num_nodes == 2 + assert crd.command == expected_command + assert any(ev.name == "FOO" and ev.value == "bar" for ev in (crd.env or [])) + assert crd.resources_per_node is not None + + def test_mpi_runtime_builds_crd_uses_user_flag(self): + runtime = _build_mpi_runtime() + trainer = types.CommandTrainer( + command=["python"], + args=["train.py"], + packages_to_install=["torch"], + pip_index_urls=["https://pypi.org/simple"], + num_nodes=4, + ) + + crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer) + + expected_command = utils.get_command_using_user_command( + runtime=runtime, + command=trainer.command, + command_args=trainer.args, + pip_index_urls=trainer.pip_index_urls, + packages_to_install=trainer.packages_to_install, + ) + + assert crd.num_nodes == 4 + assert crd.command == expected_command + + def test_defaults_to_bash_when_command_missing(self): + runtime = _build_plain_runtime() + trainer = types.CommandTrainer( + args=["-lc", "echo hello"], + ) + crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer) + expected_command = utils.get_command_using_user_command( + runtime=runtime, + command=[], + command_args=["-lc", "echo hello"], + pip_index_urls=trainer.pip_index_urls, + packages_to_install=trainer.packages_to_install, + ) + assert crd.command == expected_command + + def test_always_bash_wrapped_even_without_installs(self): + runtime = _build_plain_runtime() + trainer = types.CommandTrainer( + command=["python"], + args=["main.py"], + ) + crd = utils.get_trainer_crd_from_command_trainer(runtime, trainer) + # Should wrap into bash -c preserving python main.py + assert crd.command == ["bash", "-c", "python main.py"] + + def test_preserves_prefix_plain(self): + runtime = _build_plain_runtime() + + result = utils.get_command_using_user_command( + runtime=runtime, + command=["python"], + command_args=["main.py"], + pip_index_urls=[constants.DEFAULT_PIP_INDEX_URLS[0]], + packages_to_install=None, + ) + + expected = "bash\n-c\npython main.py" + assert "\n".join(result) == expected + + def test_preserves_prefix_mpi_and_user_flag(self): + runtime = _build_mpi_runtime() + + result = utils.get_command_using_user_command( + runtime=runtime, + command=["python"], + command_args=["train.py"], + pip_index_urls=["https://pypi.org/simple"], + packages_to_install=["torch"], + ) + + expected = textwrap.dedent( + """mpirun +--hostfile +/etc/mpi/hostfile +bash +-c + +if ! [ -x "$(command -v pip)" ]; then + python -m ensurepip || python -m ensurepip --user || apt-get install python-pip +fi + +PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet --no-warn-script-location --index-url https://pypi.org/simple --user torch +python train.py""" + ) + assert "\n".join(result) == expected + + def test_fallback_when_no_runtime_command(self): + trainer = types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework="plainml", + num_nodes=1, + ) + # Explicitly set launcher for this runtime + trainer.set_command(constants.DEFAULT_COMMAND) + runtime = types.Runtime(name="with-launcher", trainer=trainer) + + result = utils.get_command_using_user_command( + runtime=runtime, + command=["echo"], + command_args=["hello"], + pip_index_urls=[constants.DEFAULT_PIP_INDEX_URLS[0]], + packages_to_install=None, + ) + + expected = textwrap.dedent( + """bash +-c +echo hello""" + ) + assert "\n".join(result) == expected \ No newline at end of file