From f26d91480b3c4c96b84f2100cb4b3d19beee0b02 Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Wed, 8 Oct 2025 20:01:38 +0530 Subject: [PATCH 1/7] Add simplified Training Options for TrainJob labels, annotations and podSpecOverride Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/__init__.py | 20 + kubeflow/trainer/api/trainer_client.py | 14 +- kubeflow/trainer/api/trainer_client_test.py | 282 ++++++++ kubeflow/trainer/backends/base.py | 78 +++ .../trainer/backends/kubernetes/backend.py | 54 +- .../backends/kubernetes/backend_test.py | 602 +++++++++++++++++- .../trainer/backends/kubernetes/options.py | 166 +++++ .../backends/kubernetes/options_test.py | 215 +++++++ .../trainer/backends/localprocess/backend.py | 23 +- .../backends/localprocess/backend_test.py | 126 ++++ .../trainer/backends/localprocess/options.py | 49 ++ .../backends/localprocess/options_test.py | 110 ++++ .../trainer/backends/localprocess/utils.py | 11 +- kubeflow/trainer/types/__init__.py | 53 ++ kubeflow/trainer/types/types.py | 2 +- kubeflow/trainer/utils/utils.py | 18 +- scripts/gen-changelog.py | 0 17 files changed, 1769 insertions(+), 54 deletions(-) create mode 100644 kubeflow/trainer/api/trainer_client_test.py create mode 100644 kubeflow/trainer/backends/kubernetes/options.py create mode 100644 kubeflow/trainer/backends/kubernetes/options_test.py create mode 100644 kubeflow/trainer/backends/localprocess/backend_test.py create mode 100644 kubeflow/trainer/backends/localprocess/options.py create mode 100644 kubeflow/trainer/backends/localprocess/options_test.py mode change 100755 => 100644 scripts/gen-changelog.py diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py index 7caebc2d4..017f047b3 100644 --- a/kubeflow/trainer/__init__.py +++ b/kubeflow/trainer/__init__.py @@ -16,6 +16,18 @@ # Import the Kubeflow Trainer client. from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401 +# Import common training options (defaults to Kubernetes backend) +from kubeflow.trainer.backends.kubernetes.options import ( + PodSpecOverride, + WithAnnotations, + WithLabels, + WithName, + WithPodSpecOverrides, + WithTrainerArgs, + WithTrainerCommand, + WithTrainerImage, +) + # import backends and its associated configs from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig @@ -55,6 +67,7 @@ "LoraConfig", "Loss", "MODEL_PATH", + "PodSpecOverride", "Runtime", "TorchTuneConfig", "TorchTuneInstructDataset", @@ -63,4 +76,11 @@ "TrainerType", "LocalProcessBackendConfig", "KubernetesBackendConfig", + "WithAnnotations", + "WithLabels", + "WithName", + "WithPodSpecOverrides", + "WithTrainerArgs", + "WithTrainerCommand", + "WithTrainerImage", ] diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index 6b564c90a..ea87d16c4 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -96,6 +96,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + options: Optional[list] = None, ) -> str: """Create a TrainJob. You can configure the TrainJob using one of these trainers: @@ -110,6 +111,8 @@ def train( initializer: Optional configuration for the dataset and model initializers. trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified, the TrainJob will use the runtime's default values. + options: Optional list of configuration options to apply to the TrainJob. Use + WithLabels and WithAnnotations for basic metadata configuration. Returns: The unique name of the TrainJob that has been generated. @@ -119,7 +122,16 @@ def train( TimeoutError: Timeout to create TrainJobs. RuntimeError: Failed to create TrainJobs. """ - return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer) + # Validate options compatibility with backend + if options: + self.backend.validate_options(options) + + return self.backend.train( + runtime=runtime, + initializer=initializer, + trainer=trainer, + options=options, + ) def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]: """List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with diff --git a/kubeflow/trainer/api/trainer_client_test.py b/kubeflow/trainer/api/trainer_client_test.py new file mode 100644 index 000000000..85e3709fa --- /dev/null +++ b/kubeflow/trainer/api/trainer_client_test.py @@ -0,0 +1,282 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for TrainerClient option handling and error messages. +""" + +from unittest.mock import Mock, patch + +import pytest + +from kubeflow.trainer.api.trainer_client import TrainerClient +from kubeflow.trainer.backends.kubernetes.options import WithAnnotations, WithLabels +from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig +from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig +from kubeflow.trainer.types import types + + +class TestTrainerClientOptionValidation: + """Test TrainerClient option validation integration.""" + + def test_trainer_client_passes_options_to_backend(self): + """Test that TrainerClient passes options to backend correctly.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [WithLabels({"app": "test"})] + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + assert "The following options are not compatible with this backend" in error_msg + assert "WithLabels (labels)" in error_msg + + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.CustomObjectsApi") + @patch("kubernetes.client.CoreV1Api") + def test_trainer_client_with_kubernetes_backend( + self, mock_core_api, mock_custom_api, mock_load_config + ): + """Test TrainerClient with KubernetesBackend and compatible options.""" + mock_custom_api.return_value = Mock() + mock_core_api.return_value = Mock() + + config = KubernetesBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})] + + with pytest.raises((ValueError, RuntimeError)) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + # Should either fail with runtime requirement or K8s connection error + assert ( + "Runtime is required" in error_msg + or "Failed to get clustertrainingruntimes" in error_msg + ) + + def test_trainer_client_empty_options(self): + """Test TrainerClient with empty options.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=[]) + + error_msg = str(exc_info.value) + assert "Runtime must be provided for LocalProcessBackend" in error_msg + + +class TestTrainerClientErrorHandling: + """Test TrainerClient error handling improvements.""" + + def test_missing_runtime_error_message(self): + """Test improved error message for missing runtime.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer) + + error_msg = str(exc_info.value) + # The error message should contain the runtime requirement + assert "Runtime must be provided for LocalProcessBackend" in error_msg + + def test_option_validation_error_propagation(self): + """Test that option validation errors are properly propagated.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})] + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + assert "The following options are not compatible with this backend" in error_msg + assert "WithLabels (labels)" in error_msg + assert "WithAnnotations (annotations)" in error_msg + assert "The following options are not compatible with this backend" in error_msg + + def test_error_message_does_not_contain_runtime_help_for_option_errors(self): + """Test that option validation errors don't get runtime help text.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [WithLabels({"app": "test"})] + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + assert "The following options are not compatible with this backend" in error_msg + assert "Example usage:" not in error_msg + + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.CustomObjectsApi") + @patch("kubernetes.client.CoreV1Api") + def test_kubernetes_backend_error_handling( + self, mock_core_api, mock_custom_api, mock_load_config + ): + """Test error handling with KubernetesBackend.""" + mock_custom_api.return_value = Mock() + mock_core_api.return_value = Mock() + + config = KubernetesBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + + with pytest.raises((ValueError, RuntimeError)) as exc_info: + client.train(trainer=trainer) + + error_msg = str(exc_info.value) + # Should either fail with runtime requirement or K8s connection error + assert ( + "Runtime is required" in error_msg + or "Failed to get clustertrainingruntimes" in error_msg + ) + + +class TestTrainerClientBackendSelection: + """Test TrainerClient backend selection and configuration.""" + + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.CustomObjectsApi") + @patch("kubernetes.client.CoreV1Api") + def test_default_backend_is_kubernetes(self, mock_core_api, mock_custom_api, mock_load_config): + """Test that default backend is Kubernetes.""" + mock_custom_api.return_value = Mock() + mock_core_api.return_value = Mock() + + client = TrainerClient() + + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + assert isinstance(client.backend, KubernetesBackend) + + def test_local_process_backend_selection(self): + """Test LocalProcess backend selection.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend + + assert isinstance(client.backend, LocalProcessBackend) + + @patch("kubernetes.config.load_kube_config") + @patch("kubernetes.client.CustomObjectsApi") + @patch("kubernetes.client.CoreV1Api") + def test_kubernetes_backend_selection(self, mock_core_api, mock_custom_api, mock_load_config): + """Test Kubernetes backend selection.""" + mock_custom_api.return_value = Mock() + mock_core_api.return_value = Mock() + + config = KubernetesBackendConfig() + client = TrainerClient(backend_config=config) + + from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend + + assert isinstance(client.backend, KubernetesBackend) + + +class TestTrainerClientOptionFlow: + """Test the complete option flow through TrainerClient.""" + + def test_option_validation_happens_early(self): + """Test that option validation happens before other validations.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [WithLabels({"app": "test"})] + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + assert "The following options are not compatible with this backend" in error_msg + + def test_multiple_option_types_validation(self): + """Test validation with multiple different option types.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + options = [ + WithLabels({"app": "test"}), + WithAnnotations({"desc": "test"}), + ] + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=options) + + error_msg = str(exc_info.value) + assert "The following options are not compatible with this backend" in error_msg + assert "WithLabels (labels)" in error_msg + assert "WithAnnotations (annotations)" in error_msg + + def test_none_options_handling(self): + """Test that None options are handled correctly.""" + config = LocalProcessBackendConfig() + client = TrainerClient(backend_config=config) + + def simple_func(): + return "test" + + trainer = types.CustomTrainer(func=simple_func) + + with pytest.raises(ValueError) as exc_info: + client.train(trainer=trainer, options=None) + + error_msg = str(exc_info.value) + assert "Runtime must be provided for LocalProcessBackend" in error_msg diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index 0316b7b61..b2821c9ce 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -14,13 +14,90 @@ import abc from collections.abc import Iterator +from enum import Enum from typing import Optional, Union from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types +class OptionType(Enum): + """Enumeration of available option types for backends.""" + + LABELS = "labels" + ANNOTATIONS = "annotations" + POD_SPEC_OVERRIDES = "pod_spec_overrides" + NAME = "name" + TRAINER_IMAGE = "trainer_image" + TRAINER_COMMAND = "trainer_command" + TRAINER_ARGS = "trainer_args" + + +class BackendCapabilities: + """Backend capabilities for validating option compatibility.""" + + def __init__(self, supported_options: set[OptionType]): + self.supported_options = supported_options + self._supported_types = frozenset(opt.value for opt in supported_options) + + def supports(self, option_type: OptionType) -> bool: + """Check if option type is supported.""" + return option_type.value in self._supported_types + + def supports_option(self, option) -> bool: + """Check if option is supported.""" + return hasattr(option, "option_type") and self.supports(option.option_type) + + def check_compatibility(self, options: list) -> tuple[bool, list[str]]: + """Check compatibility of options with this backend.""" + if not options: + return True, [] + + for option in options: + if not self.supports_option(option): + unsupported = [ + f"{opt.__class__.__name__} ({opt.option_type.value})" + for opt in options + if not self.supports_option(opt) + ] + return False, unsupported + + return True, [] + + +# Backend capability definitions +KUBERNETES_CAPABILITIES = BackendCapabilities( + { + OptionType.LABELS, + OptionType.ANNOTATIONS, + OptionType.POD_SPEC_OVERRIDES, + OptionType.NAME, + OptionType.TRAINER_IMAGE, + OptionType.TRAINER_COMMAND, + OptionType.TRAINER_ARGS, + } +) + +LOCAL_PROCESS_CAPABILITIES = BackendCapabilities(set()) + + class ExecutionBackend(abc.ABC): + @property + @abc.abstractmethod + def capabilities(self) -> BackendCapabilities: + """Return the capabilities of this backend.""" + pass + + def validate_options(self, options: Optional[list] = None) -> None: + if not options: + return + is_compatible, incompatible_options = self.capabilities.check_compatibility(options) + if not is_compatible: + raise ValueError( + f"The following options are not compatible with this backend: " + f"{incompatible_options}" + ) + @abc.abstractmethod def list_runtimes(self) -> list[types.Runtime]: raise NotImplementedError() @@ -39,6 +116,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + options: Optional[list] = None, ) -> str: raise NotImplementedError() diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 4310182bb..2f24a72c8 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -26,7 +26,7 @@ from kubeflow_trainer_api import models from kubernetes import client, config, watch -from kubeflow.trainer.backends.base import ExecutionBackend +from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES, ExecutionBackend from kubeflow.trainer.backends.kubernetes import types as k8s_types from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -57,6 +57,11 @@ def __init__( self.namespace = cfg.namespace + @property + def capabilities(self): + """Return the capabilities of this backend.""" + return KUBERNETES_CAPABILITIES + def list_runtimes(self) -> list[types.Runtime]: result = [] try: @@ -118,13 +123,11 @@ def get_runtime(self, name: str) -> types.Runtime: except multiprocessing.TimeoutError as e: raise TimeoutError( - f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: " - f"{self.namespace}/{name}" + f"Timeout to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: {name}" ) from e except Exception as e: raise RuntimeError( - f"Failed to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: " - f"{self.namespace}/{name}" + f"Failed to get {constants.CLUSTER_TRAINING_RUNTIME_PLURAL}: {name}" ) from e return self.__get_runtime_from_crd(runtime) # type: ignore @@ -183,13 +186,36 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + options: Optional[list] = None, ) -> str: + self.validate_options(options) if runtime is None: runtime = self.get_runtime(constants.TORCH_RUNTIME) - # Generate unique name for the TrainJob. - # TODO (andreyvelich): Discuss this TrainJob name generation. - train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] + # Process options to extract configuration + job_spec = {} + labels = None + annotations = None + name = None + trainer_overrides = {} + + if options: + for option in options: + option(job_spec) + + metadata_section = job_spec.get("metadata", {}) + labels = metadata_section.get("labels") + annotations = metadata_section.get("annotations") + name = metadata_section.get("name") + + # Extract trainer-specific overrides + spec_section = job_spec.get("spec", {}) + trainer_spec = spec_section.get("trainer", {}) + if trainer_spec: + trainer_overrides = trainer_spec + + # Generate unique name for the TrainJob if not provided + train_job_name = name or (random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]) # Build the Trainer. trainer_crd = models.TrainerV1alpha1Trainer() @@ -215,10 +241,20 @@ def train( "Please use CustomTrainer or BuiltinTrainer." ) + if trainer_overrides: + if "image" in trainer_overrides: + trainer_crd.image = trainer_overrides["image"] + if "command" in trainer_overrides: + trainer_crd.command = trainer_overrides["command"] + if "args" in trainer_overrides: + trainer_crd.args = trainer_overrides["args"] + train_job = models.TrainerV1alpha1TrainJob( apiVersion=constants.API_VERSION, kind=constants.TRAINJOB_KIND, - metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=train_job_name, labels=labels, annotations=annotations + ), spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime.name), trainer=(trainer_crd if trainer_crd != models.TrainerV1alpha1Trainer() else None), diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 85c71c461..522549c29 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -32,6 +32,13 @@ import pytest from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend +from kubeflow.trainer.backends.kubernetes.options import ( + WithAnnotations, + WithLabels, + WithTrainerArgs, + WithTrainerCommand, + WithTrainerImage, +) from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.constants import constants from kubeflow.trainer.test.common import ( @@ -207,6 +214,25 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements: ) +def get_basic_custom_trainer( + env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None, +) -> models.TrainerV1alpha1Trainer: + """ + Get a basic custom trainer for the TrainJob without packages or func_args. + """ + return models.TrainerV1alpha1Trainer( + command=[ + "bash", + "-c", + "\nread -r -d '' SCRIPT << EOM\n\n" + '"trainer", types.CustomTrainer(func=lambda: print("Hello World"))\n\n' + "()\n\n" + 'EOM\nprintf "%s" "$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"', + ], + env=env, + ) + + def get_custom_trainer( env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None, pip_index_urls: Optional[list[str]] = constants.DEFAULT_PIP_INDEX_URLS, @@ -255,6 +281,8 @@ def get_train_job( runtime_name: str, train_job_name: str = BASIC_TRAIN_JOB_NAME, train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, + labels: Optional[dict[str, str]] = None, + annotations: Optional[dict[str, str]] = None, ) -> models.TrainerV1alpha1TrainJob: """ Create a mock TrainJob object with optional trainer configurations. @@ -262,7 +290,9 @@ def get_train_job( train_job = models.TrainerV1alpha1TrainJob( apiVersion=constants.API_VERSION, kind=constants.TRAINJOB_KIND, - metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), + metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( + name=train_job_name, labels=labels, annotations=annotations + ), spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, @@ -500,16 +530,21 @@ def get_container() -> models.IoK8sApiCoreV1Container: def create_runtime_type( name: str, + trainer_type: types.TrainerType = types.TrainerType.CUSTOM_TRAINER, ) -> types.Runtime: """Create a mock Runtime object for testing.""" trainer = types.RuntimeTrainer( - trainer_type=types.TrainerType.CUSTOM_TRAINER, + trainer_type=trainer_type, framework=name, num_nodes=2, device="gpu", device_count=RUNTIME_DEVICES, ) - trainer.set_command(constants.TORCH_COMMAND) + # Set command based on trainer type and framework + if trainer_type == types.TrainerType.BUILTIN_TRAINER and name == TORCH_TUNE_RUNTIME: + trainer.set_command(constants.TORCH_TUNE_COMMAND) + else: + trainer.set_command(constants.TORCH_COMMAND) return types.Runtime( name=name, pretrained_model=None, @@ -568,33 +603,432 @@ def get_train_job_data_type( # -------------------------- -# Tests +# Test Cases # -------------------------- +# Test cases for get_runtime method +GET_RUNTIME_TEST_CASES = [ + TestCase( + name="get_runtime_success", + expected_status=SUCCESS, + config={"name": TORCH_RUNTIME}, + expected_output=create_runtime_type(name=TORCH_RUNTIME), + ), + TestCase( + name="get_runtime_timeout", + expected_status=FAILED, + config={"name": TIMEOUT}, + expected_error=TimeoutError, + ), + TestCase( + name="get_runtime_error", + expected_status=FAILED, + config={"name": RUNTIME}, + expected_error=RuntimeError, + ), +] + +# Test cases for train method +TRAIN_TEST_CASES = [ + # Basic functionality + TestCase( + name="train_with_defaults", + expected_status=SUCCESS, + config={}, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + ), + ), + TestCase( + name="train_with_builtin_trainer", + expected_status=SUCCESS, + config={ + "trainer": types.BuiltinTrainer( + config=types.TorchTuneConfig( + num_nodes=2, + batch_size=2, + epochs=2, + loss=types.Loss.CEWithChunkedOutputLoss, + ) + ), + "runtime": create_runtime_type( + name=TORCH_TUNE_RUNTIME, trainer_type=types.TrainerType.BUILTIN_TRAINER + ), + }, + expected_output=get_train_job( + runtime_name=TORCH_TUNE_RUNTIME, + train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER, + train_job_trainer=get_builtin_trainer(), + ), + ), + TestCase( + name="train_with_custom_trainer", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer( + func=lambda: print("Hello World"), + func_args={"learning_rate": 0.001, "batch_size": 32}, + packages_to_install=["torch", "numpy"], + pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS, + num_nodes=2, + ) + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, + train_job_trainer=get_custom_trainer( + pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS, + packages_to_install=["torch", "numpy"], + ), + ), + ), + TestCase( + name="train_with_custom_trainer_and_env", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer( + func=lambda: print("Hello World"), + func_args={"learning_rate": 0.001, "batch_size": 32}, + packages_to_install=["torch", "numpy"], + pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS, + num_nodes=2, + env={ + "TEST_ENV": "test_value", + "ANOTHER_ENV": "another_value", + }, + ) + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, + train_job_trainer=get_custom_trainer( + env=[ + models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"), + models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"), + ], + pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS, + packages_to_install=["torch", "numpy"], + ), + ), + ), + # Options tests + TestCase( + name="train_with_labels_and_annotations", + expected_status=SUCCESS, + config={ + "labels": { + "kueue.x-k8s.io/queue-name": "ml-queue", + "team": "ml-engineering", + }, + "annotations": { + "experiment.id": "exp-001", + "description": "Test training job", + }, + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={ + "kueue.x-k8s.io/queue-name": "ml-queue", + "team": "ml-engineering", + }, + annotations={ + "experiment.id": "exp-001", + "description": "Test training job", + }, + ), + ), + TestCase( + name="train_with_labels_only", + expected_status=SUCCESS, + config={ + "labels": {"priority": "high"}, + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={"priority": "high"}, + ), + ), + TestCase( + name="train_with_annotations_only", + expected_status=SUCCESS, + config={ + "annotations": {"created-by": "sdk"}, + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + annotations={"created-by": "sdk"}, + ), + ), + # Error cases + TestCase( + name="train_timeout_error", + expected_status=FAILED, + config={"namespace": TIMEOUT}, + expected_error=TimeoutError, + ), + TestCase( + name="train_runtime_error", + expected_status=FAILED, + config={"namespace": RUNTIME}, + expected_error=RuntimeError, + ), + TestCase( + name="train_unsupported_trainer_type", + expected_status=FAILED, + config={ + "trainer": types.CustomTrainer( + func=lambda: print("Hello World"), + num_nodes=2, + ), + "runtime": create_runtime_type( + name=TORCH_TUNE_RUNTIME, trainer_type=types.TrainerType.BUILTIN_TRAINER + ), + }, + expected_error=ValueError, + ), +] + -@pytest.mark.parametrize( - "test_case", - [ - TestCase( - name="valid flow with all defaults", - expected_status=SUCCESS, - config={"name": TORCH_RUNTIME}, - expected_output=create_runtime_type(name=TORCH_RUNTIME), +def get_custom_trainer_with_overrides( + image: Optional[str] = None, + command: Optional[list[str]] = None, + args: Optional[list[str]] = None, + resources_per_node: Optional[dict] = None, + **kwargs, +) -> models.TrainerV1alpha1Trainer: + """Helper to create trainer with container overrides.""" + trainer = get_custom_trainer(**kwargs) + + if image: + trainer.image = image + if command: + trainer.command = command + if args: + trainer.args = args + if resources_per_node: + trainer.resources_per_node = utils.get_resources_per_node(resources_per_node) + + return trainer + + +# Test cases for backend validation +BACKEND_VALIDATION_TEST_CASES = [ + TestCase( + name="validate_compatible_options", + expected_status=SUCCESS, + config={ + "options": [ + WithLabels({"app": "test"}), + WithAnnotations({"desc": "test"}), + ] + }, + expected_output=None, # No exception expected + ), + TestCase( + name="validate_empty_options", + expected_status=SUCCESS, + config={"options": []}, + expected_output=None, + ), + TestCase( + name="validate_none_options", + expected_status=SUCCESS, + config={"options": None}, + expected_output=None, + ), +] + + +@pytest.mark.parametrize("test_case", BACKEND_VALIDATION_TEST_CASES) +def test_backend_validation(kubernetes_backend, test_case): + """Test KubernetesBackend option validation.""" + print("Executing test:", test_case.name) + try: + options = test_case.config.get("options") + kubernetes_backend.validate_options(options) + assert test_case.expected_status == SUCCESS + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +def test_backend_capabilities(kubernetes_backend): + """Test KubernetesBackend capabilities property.""" + from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES + + assert kubernetes_backend.capabilities == KUBERNETES_CAPABILITIES + + +# Test cases for new trainer container options +TRAINER_OPTIONS_TEST_CASES = [ + TestCase( + name="train_with_trainer_image_option", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer(num_nodes=2), + "options": [ + WithTrainerImage("custom/pytorch:latest"), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=models.TrainerV1alpha1Trainer( + image="custom/pytorch:latest", + num_nodes=2, + ), ), - TestCase( - name="timeout error when getting runtime", - expected_status=FAILED, - config={"name": TIMEOUT}, - expected_error=TimeoutError, + ), + TestCase( + name="train_with_trainer_command_option", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer(num_nodes=2), + "options": [ + WithTrainerCommand(["python", "train.py"]), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_custom_trainer_with_overrides(command=["python", "train.py"]), ), - TestCase( - name="runtime error when getting runtime", - expected_status=FAILED, - config={"name": RUNTIME}, - expected_error=RuntimeError, + ), + TestCase( + name="train_with_trainer_args_option", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer(num_nodes=2), + "options": [ + WithTrainerArgs(["--epochs", "10", "--lr", "0.001"]), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=models.TrainerV1alpha1Trainer( + args=["--epochs", "10", "--lr", "0.001"], + num_nodes=2, + ), ), - ], -) + ), + TestCase( + name="train_with_all_trainer_options", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer(num_nodes=2), + "options": [ + WithTrainerImage("custom/pytorch:2.0"), + WithTrainerCommand(["python", "-m", "torch.distributed.run"]), + WithTrainerArgs(["train.py", "--epochs", "5"]), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_custom_trainer_with_overrides( + image="custom/pytorch:2.0", + command=["python", "-m", "torch.distributed.run"], + args=["train.py", "--epochs", "5"], + ), + ), + ), + TestCase( + name="train_container_only_no_function", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer( + # No func parameter - container only + num_nodes=2, + resources_per_node={"cpu": "2", "memory": "4Gi"}, + ), + "options": [ + WithTrainerImage("python:3.11"), + WithTrainerCommand(["python", "-c"]), + WithTrainerArgs(["print('Container-only training!')"]), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_custom_trainer_with_overrides( + image="python:3.11", + command=["python", "-c"], + args=["print('Container-only training!')"], + resources_per_node={"cpu": "2", "memory": "4Gi"}, + ), + ), + ), +] + + +def get_custom_trainer_with_overrides( + image: Optional[str] = None, + command: Optional[list[str]] = None, + args: Optional[list[str]] = None, + resources_per_node: Optional[dict] = None, + **kwargs, +) -> models.TrainerV1alpha1Trainer: + """Helper to create trainer with container overrides.""" + trainer = get_custom_trainer(**kwargs) + + if image: + trainer.image = image + if command: + trainer.command = command + if args: + trainer.args = args + if resources_per_node: + trainer.resources_per_node = utils.get_resources_per_node(resources_per_node) + + return trainer + + +@pytest.mark.parametrize("test_case", TRAINER_OPTIONS_TEST_CASES) +def test_train_with_trainer_options(kubernetes_backend, test_case): + """Test KubernetesBackend.train with new trainer container options.""" + print("Executing test:", test_case.name) + try: + kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME)) + + options = test_case.config.get("options", []) + + train_job_name = kubernetes_backend.train( + runtime=runtime, + trainer=test_case.config.get("trainer", None), + options=options, + ) + + assert test_case.expected_status == SUCCESS + + # Verify the expected output + expected_output = test_case.expected_output + expected_output.metadata.name = train_job_name + + kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with( + constants.GROUP, + constants.VERSION, + DEFAULT_NAMESPACE, + constants.TRAINJOB_PLURAL, + expected_output.to_dict(), + ) + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +# -------------------------- +# Tests +# -------------------------- + + +@pytest.mark.parametrize("test_case", GET_RUNTIME_TEST_CASES) def test_get_runtime(kubernetes_backend, test_case): """Test KubernetesBackend.get_runtime with basic success path.""" print("Executing test:", test_case.name) @@ -690,6 +1124,7 @@ def test_get_runtime_packages(kubernetes_backend, test_case): expected_output=get_train_job( runtime_name=TORCH_RUNTIME, train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_basic_custom_trainer(), ), ), TestCase( @@ -704,7 +1139,9 @@ def test_get_runtime_packages(kubernetes_backend, test_case): loss=types.Loss.CEWithChunkedOutputLoss, ) ), - "runtime": TORCH_TUNE_RUNTIME, + "runtime": create_runtime_type( + name=TORCH_TUNE_RUNTIME, trainer_type=types.TrainerType.BUILTIN_TRAINER + ), }, expected_output=get_train_job( runtime_name=TORCH_TUNE_RUNTIME, @@ -819,12 +1256,110 @@ def test_get_runtime_packages(kubernetes_backend, test_case): func=lambda: print("Hello World"), num_nodes=2, ), - "runtime": TORCH_TUNE_RUNTIME, + "runtime": create_runtime_type( + name=TORCH_TUNE_RUNTIME, trainer_type=types.TrainerType.BUILTIN_TRAINER + ), }, expected_error=ValueError, ), + # Test cases using the new Options pattern + TestCase( + name="valid flow with WithLabels option", + expected_status=SUCCESS, + config={ + "options": [WithLabels({"team": "ml-platform", "project": "training"})], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_basic_custom_trainer(), + labels={"team": "ml-platform", "project": "training"}, + ), + ), + TestCase( + name="valid flow with multiple options", + expected_status=SUCCESS, + config={ + "options": [ + WithLabels({"team": "ml-platform"}), + WithAnnotations({"created-by": "sdk"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + train_job_trainer=get_basic_custom_trainer(), + labels={"team": "ml-platform"}, + annotations={"created-by": "sdk"}, + ), + ), ], ) +def test_train_validation(kubernetes_backend, test_case): + """Test KubernetesBackend.train validation with various scenarios.""" + print("Executing test:", test_case.name) + try: + kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + + if test_case.expected_status == SUCCESS: + train_job_name = kubernetes_backend.train( + trainer=test_case.config.get( + "trainer", types.CustomTrainer(func=lambda: print("Hello World")) + ), + runtime=test_case.config.get("runtime", create_runtime_type(name=TORCH_RUNTIME)), + options=test_case.config.get("options", []), + ) + + # Set the expected output's name to the actual job name + expected_output = test_case.expected_output + expected_output.metadata.name = train_job_name + + # Verify the mock was called with the expected output + kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with( + constants.GROUP, + constants.VERSION, + DEFAULT_NAMESPACE, + constants.TRAINJOB_PLURAL, + expected_output.to_dict(), + ) + else: + with pytest.raises(test_case.expected_error): + kubernetes_backend.train( + trainer=test_case.config.get( + "trainer", types.CustomTrainer(func=lambda: print("Hello World")) + ), + runtime=test_case.config.get( + "runtime", create_runtime_type(name=TORCH_RUNTIME) + ), + options=test_case.config.get("options", []), + ) + except Exception as e: + print(f"Test failed with error: {e}") + raise + print("test execution complete") + + +TRAIN_TEST_CASES = [ + TestCase( + name="basic train with options", + expected_status=SUCCESS, + config={ + "options": [ + WithLabels({"team": "ml-platform"}), + WithAnnotations({"created-by": "sdk"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={"team": "ml-platform"}, + annotations={"created-by": "sdk"}, + ), + ), +] + + +@pytest.mark.parametrize("test_case", TRAIN_TEST_CASES) def test_train(kubernetes_backend, test_case): """Test KubernetesBackend.train with basic success path.""" print("Executing test:", test_case.name) @@ -832,8 +1367,21 @@ def test_train(kubernetes_backend, test_case): kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME)) + options = test_case.config.get("options", []) + if test_case.config.get("labels"): + from kubeflow.trainer.backends.kubernetes.options import WithLabels + + options.append(WithLabels(test_case.config["labels"])) + + if test_case.config.get("annotations"): + from kubeflow.trainer.backends.kubernetes.options import WithAnnotations + + options.append(WithAnnotations(test_case.config["annotations"])) + train_job_name = kubernetes_backend.train( - runtime=runtime, trainer=test_case.config.get("trainer", None) + runtime=runtime, + trainer=test_case.config.get("trainer", None), + options=options, ) assert test_case.expected_status == SUCCESS diff --git a/kubeflow/trainer/backends/kubernetes/options.py b/kubeflow/trainer/backends/kubernetes/options.py new file mode 100644 index 000000000..c43b04fac --- /dev/null +++ b/kubeflow/trainer/backends/kubernetes/options.py @@ -0,0 +1,166 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Kubernetes-specific training options for the Kubeflow Trainer SDK.""" + +from dataclasses import dataclass +from typing import Optional + +from kubeflow.trainer.backends.base import OptionType + + +@dataclass +class WithLabels: + """Add labels to the TrainJob resource metadata (.metadata.labels).""" + + labels: dict[str, str] + + @property + def option_type(self) -> OptionType: + return OptionType.LABELS + + def __call__(self, job_spec: dict) -> None: + """Apply labels to the job specification.""" + metadata = job_spec.setdefault("metadata", {}) + metadata["labels"] = self.labels + + +@dataclass +class WithAnnotations: + """Add annotations to the TrainJob resource metadata (.metadata.annotations).""" + + annotations: dict[str, str] + + @property + def option_type(self) -> OptionType: + return OptionType.ANNOTATIONS + + def __call__(self, job_spec: dict) -> None: + """Apply annotations to the job specification.""" + metadata = job_spec.setdefault("metadata", {}) + metadata["annotations"] = self.annotations + + +@dataclass +class WithName: + """Set a custom name for the TrainJob resource (.metadata.name).""" + + name: str + + @property + def option_type(self) -> OptionType: + return OptionType.NAME + + def __call__(self, job_spec: dict) -> None: + """Apply custom name to the job specification.""" + metadata = job_spec.setdefault("metadata", {}) + metadata["name"] = self.name + + +@dataclass +class PodSpecOverride: + """Configuration for overriding pod specifications for specific job types.""" + + target_jobs: list[str] + volumes: Optional[list[dict]] = None + containers: Optional[list[dict]] = None + init_containers: Optional[list[dict]] = None + node_selector: Optional[dict[str, str]] = None + service_account_name: Optional[str] = None + tolerations: Optional[list[dict]] = None + + +@dataclass +class WithPodSpecOverrides: + """Add pod specification overrides to the TrainJob (.spec.podSpecOverrides).""" + + overrides: list[PodSpecOverride] + + @property + def option_type(self) -> OptionType: + return OptionType.POD_SPEC_OVERRIDES + + def __call__(self, job_spec: dict) -> None: + """Apply pod spec overrides to the job specification.""" + spec = job_spec.setdefault("spec", {}) + spec["podSpecOverrides"] = [] + + for override in self.overrides: + api_override = {"targetJobs": [{"name": job} for job in override.target_jobs]} + + if override.volumes: + api_override["volumes"] = override.volumes + if override.containers: + api_override["containers"] = override.containers + if override.init_containers: + api_override["initContainers"] = override.init_containers + if override.node_selector: + api_override["nodeSelector"] = override.node_selector + if override.service_account_name: + api_override["serviceAccountName"] = override.service_account_name + if override.tolerations: + api_override["tolerations"] = override.tolerations + + spec["podSpecOverrides"].append(api_override) + + +@dataclass +class WithTrainerImage: + """Override the trainer container image (.spec.trainer.image).""" + + image: str + + @property + def option_type(self) -> OptionType: + return OptionType.TRAINER_IMAGE + + def __call__(self, job_spec: dict) -> None: + """Apply trainer image override to the job specification.""" + spec = job_spec.setdefault("spec", {}) + trainer_spec = spec.setdefault("trainer", {}) + trainer_spec["image"] = self.image + + +@dataclass +class WithTrainerCommand: + """Override the trainer container command (.spec.trainer.command).""" + + command: list[str] + + @property + def option_type(self) -> OptionType: + return OptionType.TRAINER_COMMAND + + def __call__(self, job_spec: dict) -> None: + """Apply trainer command override to the job specification.""" + spec = job_spec.setdefault("spec", {}) + trainer_spec = spec.setdefault("trainer", {}) + trainer_spec["command"] = self.command + + +@dataclass +class WithTrainerArgs: + """Override the trainer container arguments (.spec.trainer.args).""" + + args: list[str] + + @property + def option_type(self) -> OptionType: + return OptionType.TRAINER_ARGS + + def __call__(self, job_spec: dict) -> None: + """Apply trainer args override to the job specification.""" + spec = job_spec.setdefault("spec", {}) + trainer_spec = spec.setdefault("trainer", {}) + trainer_spec["args"] = self.args diff --git a/kubeflow/trainer/backends/kubernetes/options_test.py b/kubeflow/trainer/backends/kubernetes/options_test.py new file mode 100644 index 000000000..be96f9635 --- /dev/null +++ b/kubeflow/trainer/backends/kubernetes/options_test.py @@ -0,0 +1,215 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Kubernetes backend options.""" + +from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES, OptionType +from kubeflow.trainer.backends.kubernetes.options import ( + PodSpecOverride, + WithAnnotations, + WithLabels, + WithName, + WithPodSpecOverrides, + WithTrainerArgs, + WithTrainerCommand, + WithTrainerImage, +) + + +class TestKubernetesOptionTypes: + """Test Kubernetes option types.""" + + def test_with_labels_option_type(self): + """Test WithLabels has correct option type.""" + option = WithLabels({"app": "test", "version": "v1"}) + assert option.option_type == OptionType.LABELS + + def test_with_annotations_option_type(self): + """Test WithAnnotations has correct option type.""" + option = WithAnnotations({"description": "test job"}) + assert option.option_type == OptionType.ANNOTATIONS + + def test_with_name_option_type(self): + """Test WithName has correct option type.""" + option = WithName("test-job") + assert option.option_type == OptionType.NAME + + def test_with_pod_spec_overrides_option_type(self): + """Test WithPodSpecOverrides has correct option type.""" + overrides = [PodSpecOverride(target_jobs=["node"])] + option = WithPodSpecOverrides(overrides) + assert option.option_type == OptionType.POD_SPEC_OVERRIDES + + def test_trainer_options_types(self): + """Test trainer options have correct option types.""" + image_option = WithTrainerImage("custom:latest") + command_option = WithTrainerCommand(["python", "train.py"]) + args_option = WithTrainerArgs(["--epochs", "10"]) + + assert image_option.option_type == OptionType.TRAINER_IMAGE + assert command_option.option_type == OptionType.TRAINER_COMMAND + assert args_option.option_type == OptionType.TRAINER_ARGS + + +class TestKubernetesOptionApplication: + """Test Kubernetes option application behavior.""" + + def test_labels_application(self): + """Test WithLabels applies correctly to job spec.""" + option = WithLabels({"app": "test", "version": "v1"}) + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"labels": {"app": "test", "version": "v1"}}} + assert job_spec == expected + + def test_annotations_application(self): + """Test WithAnnotations applies correctly to job spec.""" + option = WithAnnotations({"description": "test job", "owner": "team"}) + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"annotations": {"description": "test job", "owner": "team"}}} + assert job_spec == expected + + def test_name_application(self): + """Test WithName applies correctly to job spec.""" + option = WithName("my-training-job") + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"name": "my-training-job"}} + assert job_spec == expected + + def test_trainer_image_application(self): + """Test WithTrainerImage applies correctly to job spec.""" + option = WithTrainerImage("custom/pytorch:latest") + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"image": "custom/pytorch:latest"}}} + assert job_spec == expected + + def test_trainer_command_application(self): + """Test WithTrainerCommand applies correctly to job spec.""" + option = WithTrainerCommand(["python", "train.py"]) + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"command": ["python", "train.py"]}}} + assert job_spec == expected + + def test_trainer_args_application(self): + """Test WithTrainerArgs applies correctly to job spec.""" + option = WithTrainerArgs(["--epochs", "10", "--lr", "0.001"]) + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"args": ["--epochs", "10", "--lr", "0.001"]}}} + assert job_spec == expected + + def test_multiple_options_override_behavior(self): + """Test multiple options with override semantics.""" + job_spec = {} + + # Apply first set of labels + WithLabels({"app": "trainer", "env": "dev"})(job_spec) + # Apply second set of labels (should override) + WithLabels({"app": "ml-trainer", "version": "v1.0"})(job_spec) + # Apply annotations + WithAnnotations({"description": "test"})(job_spec) + + expected = { + "metadata": { + "labels": {"app": "ml-trainer", "version": "v1.0"}, # Override behavior + "annotations": {"description": "test"}, + } + } + assert job_spec == expected + + +class TestKubernetesCapabilities: + """Test Kubernetes backend capabilities.""" + + def test_kubernetes_capabilities_support(self): + """Test Kubernetes backend capabilities.""" + caps = KUBERNETES_CAPABILITIES + + # Test supported options + assert caps.supports(OptionType.LABELS) + assert caps.supports(OptionType.ANNOTATIONS) + assert caps.supports(OptionType.NAME) + assert caps.supports(OptionType.POD_SPEC_OVERRIDES) + assert caps.supports(OptionType.TRAINER_IMAGE) + assert caps.supports(OptionType.TRAINER_COMMAND) + assert caps.supports(OptionType.TRAINER_ARGS) + + def test_kubernetes_option_compatibility(self): + """Test Kubernetes option compatibility checking.""" + caps = KUBERNETES_CAPABILITIES + + # Test all Kubernetes options are supported + labels_opt = WithLabels({"app": "test"}) + annotations_opt = WithAnnotations({"desc": "test"}) + name_opt = WithName("test-job") + pod_opt = WithPodSpecOverrides([PodSpecOverride(target_jobs=["node"])]) + image_opt = WithTrainerImage("custom:latest") + command_opt = WithTrainerCommand(["python", "train.py"]) + args_opt = WithTrainerArgs(["--epochs", "10"]) + + assert caps.supports_option(labels_opt) + assert caps.supports_option(annotations_opt) + assert caps.supports_option(name_opt) + assert caps.supports_option(pod_opt) + assert caps.supports_option(image_opt) + assert caps.supports_option(command_opt) + assert caps.supports_option(args_opt) + + +class TestPodSpecOverride: + """Test PodSpecOverride dataclass.""" + + def test_pod_spec_override_creation(self): + """Test PodSpecOverride creation with various fields.""" + override = PodSpecOverride( + target_jobs=["node", "worker"], + volumes=[{"name": "data", "emptyDir": {}}], + containers=[{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], + node_selector={"gpu": "true"}, + service_account_name="training-sa", + tolerations=[{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + ) + + assert override.target_jobs == ["node", "worker"] + assert override.volumes == [{"name": "data", "emptyDir": {}}] + assert override.node_selector == {"gpu": "true"} + assert override.service_account_name == "training-sa" + + def test_pod_spec_override_minimal(self): + """Test PodSpecOverride with minimal required fields.""" + override = PodSpecOverride(target_jobs=["node"]) + + assert override.target_jobs == ["node"] + assert override.volumes is None + assert override.containers is None + assert override.init_containers is None + assert override.node_selector is None + assert override.service_account_name is None + assert override.tolerations is None diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py index d10a5b10f..aac33159a 100644 --- a/kubeflow/trainer/backends/localprocess/backend.py +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -20,7 +20,7 @@ from typing import Optional, Union import uuid -from kubeflow.trainer.backends.base import ExecutionBackend +from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES, ExecutionBackend from kubeflow.trainer.backends.localprocess import utils as local_utils from kubeflow.trainer.backends.localprocess.constants import local_runtimes from kubeflow.trainer.backends.localprocess.job import LocalJob @@ -44,6 +44,11 @@ def __init__( self.__local_jobs: list[LocalBackendJobs] = [] self.cfg = cfg + @property + def capabilities(self): + """Return the capabilities of this backend.""" + return LOCAL_PROCESS_CAPABILITIES + def list_runtimes(self) -> list[types.Runtime]: return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes] @@ -73,7 +78,12 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, + options: Optional[list] = None, ) -> str: + self.validate_options(options) + + if runtime is None: + raise ValueError("Runtime must be provided for LocalProcessBackend") # set train job name train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] # localprocess backend only supports CustomTrainer @@ -154,7 +164,11 @@ def get_job(self, name: str) -> Optional[types.TrainJob]: name=_job.name, creation_timestamp=_job.created, steps=[ - types.Step(name=_step.step_name, pod_name=_step.step_name, status=_step.job.status) + types.Step( + name=_step.step_name, + pod_name=_step.step_name, + status=_step.job.status, + ) for _step in _job.steps ], runtime=_job.runtime, @@ -195,7 +209,10 @@ def wait_for_job_status( raise ValueError(f"No TrainJob with name {name}") # find a better implementation for this for _step in _job.steps: - if _step.job.status in [constants.TRAINJOB_RUNNING, constants.TRAINJOB_CREATED]: + if _step.job.status in [ + constants.TRAINJOB_RUNNING, + constants.TRAINJOB_CREATED, + ]: _step.job.join(timeout=timeout) return self.get_job(name) diff --git a/kubeflow/trainer/backends/localprocess/backend_test.py b/kubeflow/trainer/backends/localprocess/backend_test.py new file mode 100644 index 000000000..89f6bacbe --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/backend_test.py @@ -0,0 +1,126 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the LocalProcessBackend class in the Kubeflow Trainer SDK. +""" + +import pytest + +from kubeflow.trainer.backends.kubernetes.options import ( + WithAnnotations, + WithLabels, + WithPodSpecOverrides, +) +from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend +from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig +from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase + +# Test cases for LocalProcess backend validation +LOCAL_BACKEND_VALIDATION_TEST_CASES = [ + TestCase( + name="validate_empty_options", + expected_status=SUCCESS, + config={"options": []}, + expected_output=None, + ), + TestCase( + name="validate_none_options", + expected_status=SUCCESS, + config={"options": None}, + expected_output=None, + ), + TestCase( + name="reject_incompatible_labels", + expected_status=FAILED, + config={"options": [WithLabels({"app": "test"})]}, + expected_error=ValueError, + ), + TestCase( + name="reject_multiple_incompatible_options", + expected_status=FAILED, + config={ + "options": [ + WithLabels({"app": "test"}), + WithAnnotations({"desc": "test"}), + WithPodSpecOverrides([{}]), + ] + }, + expected_error=ValueError, + ), +] + + +@pytest.fixture +def local_backend(): + """Create LocalProcessBackend for testing.""" + cfg = LocalProcessBackendConfig() + return LocalProcessBackend(cfg) + + +@pytest.mark.parametrize("test_case", LOCAL_BACKEND_VALIDATION_TEST_CASES) +def test_local_backend_validation(local_backend, test_case): + """Test LocalProcessBackend option validation.""" + print("Executing test:", test_case.name) + try: + options = test_case.config.get("options") + local_backend.validate_options(options) + assert test_case.expected_status == SUCCESS + except Exception as e: + assert type(e) is test_case.expected_error + if test_case.name == "reject_incompatible_labels": + error_msg = str(e) + assert "The following options are not compatible with this backend" in error_msg + assert "WithLabels (labels)" in error_msg + elif test_case.name == "reject_multiple_incompatible_options": + error_msg = str(e) + assert "The following options are not compatible with this backend" in error_msg + assert "WithLabels (labels)" in error_msg + assert "WithAnnotations (annotations)" in error_msg + assert "WithPodSpecOverrides (pod_spec_overrides)" in error_msg + print("test execution complete") + + +def test_local_backend_capabilities(local_backend): + """Test LocalProcessBackend capabilities property.""" + from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES + + assert local_backend.capabilities == LOCAL_PROCESS_CAPABILITIES + + +class TestLocalBackendValidationFlow: + """Test the complete validation flow for LocalProcess backend.""" + + def test_validation_happens_before_processing(self, local_backend): + """Test that validation happens before any processing.""" + incompatible_options = [WithLabels({"app": "test"})] + + with pytest.raises(ValueError) as exc_info: + local_backend.validate_options(incompatible_options) + + assert "The following options are not compatible with this backend" in str(exc_info.value) + + def test_validation_early_exit_behavior(self, local_backend): + """Test that validation reports all incompatible options.""" + mixed_options = [ + WithLabels({"app": "test"}), + WithAnnotations({"desc": "test"}), + ] + + with pytest.raises(ValueError) as exc_info: + local_backend.validate_options(mixed_options) + + error_msg = str(exc_info.value) + assert "WithLabels (labels)" in error_msg + assert "WithAnnotations (annotations)" in error_msg diff --git a/kubeflow/trainer/backends/localprocess/options.py b/kubeflow/trainer/backends/localprocess/options.py new file mode 100644 index 000000000..2a8da3200 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/options.py @@ -0,0 +1,49 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LocalProcess-specific training options for the Kubeflow Trainer SDK.""" + +from dataclasses import dataclass + +from kubeflow.trainer.backends.base import OptionType + + +@dataclass +class WithProcessTimeout: + """Set a timeout for the local training process.""" + + timeout_seconds: int + + @property + def option_type(self) -> OptionType: + return OptionType.NAME # Placeholder - would need new enum value + + def __call__(self, config: dict) -> None: + """Apply timeout to local process configuration.""" + config["timeout_seconds"] = self.timeout_seconds + + +@dataclass +class WithWorkingDirectory: + """Set the working directory for the local training process.""" + + working_dir: str + + @property + def option_type(self) -> OptionType: + return OptionType.NAME # Placeholder - would need new enum value + + def __call__(self, config: dict) -> None: + """Apply working directory to local process configuration.""" + config["working_dir"] = self.working_dir diff --git a/kubeflow/trainer/backends/localprocess/options_test.py b/kubeflow/trainer/backends/localprocess/options_test.py new file mode 100644 index 000000000..e7d725517 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/options_test.py @@ -0,0 +1,110 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for LocalProcess backend options.""" + +from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES, OptionType +from kubeflow.trainer.backends.localprocess.options import WithProcessTimeout, WithWorkingDirectory + + +class TestLocalProcessOptionTypes: + """Test LocalProcess option types.""" + + def test_with_process_timeout_type(self): + """Test WithProcessTimeout has correct option type.""" + option = WithProcessTimeout(timeout_seconds=300) + assert option.option_type == OptionType.NAME # Placeholder + + def test_with_working_directory_type(self): + """Test WithWorkingDirectory has correct option type.""" + option = WithWorkingDirectory(working_dir="/tmp/training") + assert option.option_type == OptionType.NAME # Placeholder + + +class TestLocalProcessOptionApplication: + """Test LocalProcess option application behavior.""" + + def test_process_timeout_application(self): + """Test WithProcessTimeout applies correctly to local config.""" + option = WithProcessTimeout(timeout_seconds=600) + + config = {} + option(config) + + expected = {"timeout_seconds": 600} + assert config == expected + + def test_working_directory_application(self): + """Test WithWorkingDirectory applies correctly to local config.""" + option = WithWorkingDirectory(working_dir="/home/user/training") + + config = {} + option(config) + + expected = {"working_dir": "/home/user/training"} + assert config == expected + + def test_multiple_local_options(self): + """Test multiple LocalProcess options together.""" + config = {} + + WithProcessTimeout(timeout_seconds=300)(config) + WithWorkingDirectory(working_dir="/tmp/work")(config) + + expected = {"timeout_seconds": 300, "working_dir": "/tmp/work"} + assert config == expected + + +class TestLocalProcessCapabilities: + """Test LocalProcess backend capabilities.""" + + def test_local_process_capabilities_minimal(self): + """Test LocalProcess backend has minimal capabilities.""" + caps = LOCAL_PROCESS_CAPABILITIES + + # LocalProcess backend currently supports no standard options + assert not caps.supports(OptionType.LABELS) + assert not caps.supports(OptionType.ANNOTATIONS) + assert not caps.supports(OptionType.POD_SPEC_OVERRIDES) + assert not caps.supports(OptionType.TRAINER_IMAGE) + assert not caps.supports(OptionType.TRAINER_COMMAND) + assert not caps.supports(OptionType.TRAINER_ARGS) + + def test_local_process_empty_options_compatibility(self): + """Test LocalProcess compatibility with empty options.""" + caps = LOCAL_PROCESS_CAPABILITIES + + is_compatible, unsupported = caps.check_compatibility([]) + assert is_compatible + assert unsupported == [] + + +class TestLocalProcessOptionCreation: + """Test LocalProcess option creation and validation.""" + + def test_process_timeout_creation(self): + """Test WithProcessTimeout creation with various values.""" + option = WithProcessTimeout(timeout_seconds=300) + assert option.timeout_seconds == 300 + + option = WithProcessTimeout(timeout_seconds=3600) + assert option.timeout_seconds == 3600 + + def test_working_directory_creation(self): + """Test WithWorkingDirectory creation with various paths.""" + option = WithWorkingDirectory(working_dir="/home/user/training") + assert option.working_dir == "/home/user/training" + + option = WithWorkingDirectory(working_dir="./training") + assert option.working_dir == "./training" diff --git a/kubeflow/trainer/backends/localprocess/utils.py b/kubeflow/trainer/backends/localprocess/utils.py index 1c18676e8..94c20a3a7 100644 --- a/kubeflow/trainer/backends/localprocess/utils.py +++ b/kubeflow/trainer/backends/localprocess/utils.py @@ -122,7 +122,8 @@ def get_local_runtime_trainer( Get the LocalRuntimeTrainer object. """ local_runtime = next( - (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name), None + (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name), + None, ) if not local_runtime: raise ValueError(f"Runtime {runtime_name} not found") @@ -267,9 +268,11 @@ def get_local_train_job_script( dependency_script = "\n" if trainer.packages_to_install: dependency_script = get_dependencies_command( - pip_index_urls=trainer.pip_index_urls - if trainer.pip_index_urls - else constants.DEFAULT_PIP_INDEX_URLS, + pip_index_urls=( + trainer.pip_index_urls + if trainer.pip_index_urls + else constants.DEFAULT_PIP_INDEX_URLS + ), runtime_packages=runtime_trainer.packages, trainer_packages=trainer.packages_to_install, quiet=False, diff --git a/kubeflow/trainer/types/__init__.py b/kubeflow/trainer/types/__init__.py index e69de29bb..bdb8d563b 100644 --- a/kubeflow/trainer/types/__init__.py +++ b/kubeflow/trainer/types/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2024 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Kubeflow Trainer SDK types and data structures. +""" + +from kubeflow.trainer.types.types import ( + BuiltinTrainer, + CustomTrainer, + DataFormat, + DataType, + HuggingFaceDatasetInitializer, + HuggingFaceModelInitializer, + Initializer, + Loss, + Runtime, + RuntimeTrainer, + Step, + TorchTuneConfig, + TorchTuneInstructDataset, + TrainerType, + TrainJob, +) + +__all__ = [ + "BuiltinTrainer", + "CustomTrainer", + "DataFormat", + "DataType", + "HuggingFaceDatasetInitializer", + "HuggingFaceModelInitializer", + "Initializer", + "Loss", + "Runtime", + "RuntimeTrainer", + "Step", + "TorchTuneConfig", + "TorchTuneInstructDataset", + "TrainJob", + "TrainerType", +] diff --git a/kubeflow/trainer/types/types.py b/kubeflow/trainer/types/types.py index c32c44834..72c5f433a 100644 --- a/kubeflow/trainer/types/types.py +++ b/kubeflow/trainer/types/types.py @@ -40,7 +40,7 @@ class CustomTrainer: env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes. """ - func: Callable + func: Optional[Callable] = None func_args: Optional[dict] = None packages_to_install: Optional[list[str]] = None pip_index_urls: list[str] = field( diff --git a/kubeflow/trainer/utils/utils.py b/kubeflow/trainer/utils/utils.py index 0cefd0467..f84f89bb1 100644 --- a/kubeflow/trainer/utils/utils.py +++ b/kubeflow/trainer/utils/utils.py @@ -383,15 +383,15 @@ def get_trainer_crd_from_custom_trainer( if trainer.resources_per_node: trainer_crd.resources_per_node = get_resources_per_node(trainer.resources_per_node) - # Add command to the Trainer. - # TODO: Support train function parameters. - trainer_crd.command = get_command_using_train_func( - runtime, - trainer.func, - trainer.func_args, - trainer.pip_index_urls, - trainer.packages_to_install, - ) + # Add command to the Trainer only if a function is provided. + if trainer.func: + trainer_crd.command = get_command_using_train_func( + runtime, + trainer.func, + trainer.func_args, + trainer.pip_index_urls, + trainer.packages_to_install, + ) # Add environment variables to the Trainer. if trainer.env: diff --git a/scripts/gen-changelog.py b/scripts/gen-changelog.py old mode 100755 new mode 100644 From 4aff23f5210ccbdd7be642bedec1dcfbea1024b9 Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Wed, 8 Oct 2025 22:16:12 +0530 Subject: [PATCH 2/7] Implement mixin-based options architecture (KubernetesCompatible, LocalProcessCompatible) Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/api/trainer_client.py | 3 +- kubeflow/trainer/api/trainer_client_test.py | 10 +- kubeflow/trainer/backends/base.py | 84 ++++-------- .../trainer/backends/kubernetes/backend.py | 12 +- .../backends/kubernetes/backend_test.py | 7 +- .../trainer/backends/kubernetes/options.py | 95 +++++++------- .../backends/kubernetes/options_test.py | 121 ++++++++++++------ .../trainer/backends/localprocess/backend.py | 14 +- .../backends/localprocess/backend_test.py | 19 +-- .../trainer/backends/localprocess/options.py | 14 +- .../backends/localprocess/options_test.py | 35 +++-- 11 files changed, 218 insertions(+), 196 deletions(-) diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index ea87d16c4..1845dfe2e 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -16,6 +16,7 @@ import logging from typing import Optional, Union +from kubeflow.trainer.backends.base import CompatibleOption from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.backend import ( @@ -96,7 +97,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list] = None, + options: Optional[list[CompatibleOption]] = None, ) -> str: """Create a TrainJob. You can configure the TrainJob using one of these trainers: diff --git a/kubeflow/trainer/api/trainer_client_test.py b/kubeflow/trainer/api/trainer_client_test.py index 85e3709fa..423be3ca0 100644 --- a/kubeflow/trainer/api/trainer_client_test.py +++ b/kubeflow/trainer/api/trainer_client_test.py @@ -46,7 +46,7 @@ def simple_func(): error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels (labels)" in error_msg + assert "WithLabels" in error_msg @patch("kubernetes.config.load_kube_config") @patch("kubernetes.client.CustomObjectsApi") @@ -130,8 +130,8 @@ def simple_func(): error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels (labels)" in error_msg - assert "WithAnnotations (annotations)" in error_msg + assert "WithLabels" in error_msg + assert "WithAnnotations" in error_msg assert "The following options are not compatible with this backend" in error_msg def test_error_message_does_not_contain_runtime_help_for_option_errors(self): @@ -262,8 +262,8 @@ def simple_func(): error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels (labels)" in error_msg - assert "WithAnnotations (annotations)" in error_msg + assert "WithLabels" in error_msg + assert "WithAnnotations" in error_msg def test_none_options_handling(self): """Test that None options are handled correctly.""" diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index b2821c9ce..de7c2a3eb 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -14,88 +14,56 @@ import abc from collections.abc import Iterator -from enum import Enum from typing import Optional, Union from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types -class OptionType(Enum): - """Enumeration of available option types for backends.""" +# Backend compatibility mixins +class KubernetesCompatible: + """Mixin for options compatible with Kubernetes backend.""" - LABELS = "labels" - ANNOTATIONS = "annotations" - POD_SPEC_OVERRIDES = "pod_spec_overrides" - NAME = "name" - TRAINER_IMAGE = "trainer_image" - TRAINER_COMMAND = "trainer_command" - TRAINER_ARGS = "trainer_args" + pass -class BackendCapabilities: - """Backend capabilities for validating option compatibility.""" +class LocalProcessCompatible: + """Mixin for options compatible with LocalProcess backend.""" - def __init__(self, supported_options: set[OptionType]): - self.supported_options = supported_options - self._supported_types = frozenset(opt.value for opt in supported_options) + pass - def supports(self, option_type: OptionType) -> bool: - """Check if option type is supported.""" - return option_type.value in self._supported_types - def supports_option(self, option) -> bool: - """Check if option is supported.""" - return hasattr(option, "option_type") and self.supports(option.option_type) +class UniversalCompatible(KubernetesCompatible, LocalProcessCompatible): + """Mixin for options compatible with all backends.""" - def check_compatibility(self, options: list) -> tuple[bool, list[str]]: - """Check compatibility of options with this backend.""" - if not options: - return True, [] - - for option in options: - if not self.supports_option(option): - unsupported = [ - f"{opt.__class__.__name__} ({opt.option_type.value})" - for opt in options - if not self.supports_option(opt) - ] - return False, unsupported - - return True, [] + pass -# Backend capability definitions -KUBERNETES_CAPABILITIES = BackendCapabilities( - { - OptionType.LABELS, - OptionType.ANNOTATIONS, - OptionType.POD_SPEC_OVERRIDES, - OptionType.NAME, - OptionType.TRAINER_IMAGE, - OptionType.TRAINER_COMMAND, - OptionType.TRAINER_ARGS, - } -) - -LOCAL_PROCESS_CAPABILITIES = BackendCapabilities(set()) +# Type alias for all compatible options +CompatibleOption = Union[KubernetesCompatible, LocalProcessCompatible] class ExecutionBackend(abc.ABC): @property @abc.abstractmethod - def capabilities(self) -> BackendCapabilities: - """Return the capabilities of this backend.""" + def compatibility_mixin(self) -> type: + """Return the compatibility mixin class for this backend.""" pass - def validate_options(self, options: Optional[list] = None) -> None: + def validate_options(self, options: Optional[list[CompatibleOption]] = None) -> None: + """Validate that all options are compatible with this backend.""" if not options: return - is_compatible, incompatible_options = self.capabilities.check_compatibility(options) - if not is_compatible: + + incompatible = [ + f"{opt.__class__.__name__}" + for opt in options + if not isinstance(opt, self.compatibility_mixin) + ] + + if incompatible: raise ValueError( - f"The following options are not compatible with this backend: " - f"{incompatible_options}" + f"The following options are not compatible with this backend: {incompatible}" ) @abc.abstractmethod @@ -116,7 +84,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list] = None, + options: Optional[list[CompatibleOption]] = None, ) -> str: raise NotImplementedError() diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 2f24a72c8..c256d363a 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -26,7 +26,7 @@ from kubeflow_trainer_api import models from kubernetes import client, config, watch -from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES, ExecutionBackend +from kubeflow.trainer.backends.base import CompatibleOption, ExecutionBackend, KubernetesCompatible from kubeflow.trainer.backends.kubernetes import types as k8s_types from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -58,9 +58,9 @@ def __init__( self.namespace = cfg.namespace @property - def capabilities(self): - """Return the capabilities of this backend.""" - return KUBERNETES_CAPABILITIES + def compatibility_mixin(self) -> type: + """Return the compatibility mixin class for this backend.""" + return KubernetesCompatible def list_runtimes(self) -> list[types.Runtime]: result = [] @@ -186,7 +186,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list] = None, + options: Optional[list[CompatibleOption]] = None, ) -> str: self.validate_options(options) if runtime is None: @@ -201,7 +201,7 @@ def train( if options: for option in options: - option(job_spec) + option(job_spec, trainer) metadata_section = job_spec.get("metadata", {}) labels = metadata_section.get("labels") diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 522549c29..4a97bc290 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -857,10 +857,11 @@ def test_backend_validation(kubernetes_backend, test_case): def test_backend_capabilities(kubernetes_backend): - """Test KubernetesBackend capabilities property.""" - from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES + """Test KubernetesBackend compatibility mixin.""" + from kubeflow.trainer.backends.base import KubernetesCompatible - assert kubernetes_backend.capabilities == KUBERNETES_CAPABILITIES + mixin = kubernetes_backend.compatibility_mixin + assert mixin is KubernetesCompatible # Test cases for new trainer container options diff --git a/kubeflow/trainer/backends/kubernetes/options.py b/kubeflow/trainer/backends/kubernetes/options.py index c43b04fac..22ef54ed7 100644 --- a/kubeflow/trainer/backends/kubernetes/options.py +++ b/kubeflow/trainer/backends/kubernetes/options.py @@ -17,52 +17,40 @@ from dataclasses import dataclass from typing import Optional -from kubeflow.trainer.backends.base import OptionType +from kubeflow.trainer.backends.base import KubernetesCompatible @dataclass -class WithLabels: +class WithLabels(KubernetesCompatible): """Add labels to the TrainJob resource metadata (.metadata.labels).""" labels: dict[str, str] - @property - def option_type(self) -> OptionType: - return OptionType.LABELS - - def __call__(self, job_spec: dict) -> None: + def __call__(self, job_spec: dict, trainer=None) -> None: """Apply labels to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["labels"] = self.labels @dataclass -class WithAnnotations: +class WithAnnotations(KubernetesCompatible): """Add annotations to the TrainJob resource metadata (.metadata.annotations).""" annotations: dict[str, str] - @property - def option_type(self) -> OptionType: - return OptionType.ANNOTATIONS - - def __call__(self, job_spec: dict) -> None: + def __call__(self, job_spec: dict, trainer=None) -> None: """Apply annotations to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["annotations"] = self.annotations @dataclass -class WithName: +class WithName(KubernetesCompatible): """Set a custom name for the TrainJob resource (.metadata.name).""" name: str - @property - def option_type(self) -> OptionType: - return OptionType.NAME - - def __call__(self, job_spec: dict) -> None: + def __call__(self, job_spec: dict, trainer=None) -> None: """Apply custom name to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["name"] = self.name @@ -82,16 +70,12 @@ class PodSpecOverride: @dataclass -class WithPodSpecOverrides: +class WithPodSpecOverrides(KubernetesCompatible): """Add pod specification overrides to the TrainJob (.spec.podSpecOverrides).""" overrides: list[PodSpecOverride] - @property - def option_type(self) -> OptionType: - return OptionType.POD_SPEC_OVERRIDES - - def __call__(self, job_spec: dict) -> None: + def __call__(self, job_spec: dict, trainer=None) -> None: """Apply pod spec overrides to the job specification.""" spec = job_spec.setdefault("spec", {}) spec["podSpecOverrides"] = [] @@ -116,51 +100,76 @@ def __call__(self, job_spec: dict) -> None: @dataclass -class WithTrainerImage: +class WithTrainerImage(KubernetesCompatible): """Override the trainer container image (.spec.trainer.image).""" image: str - @property - def option_type(self) -> OptionType: - return OptionType.TRAINER_IMAGE + def __call__(self, job_spec: dict, trainer=None) -> None: + """Apply trainer image override to the job specification. - def __call__(self, job_spec: dict) -> None: - """Apply trainer image override to the job specification.""" + Args: + job_spec: The job specification to modify + trainer: Optional trainer context for validation + """ spec = job_spec.setdefault("spec", {}) trainer_spec = spec.setdefault("trainer", {}) trainer_spec["image"] = self.image @dataclass -class WithTrainerCommand: +class WithTrainerCommand(KubernetesCompatible): """Override the trainer container command (.spec.trainer.command).""" command: list[str] - @property - def option_type(self) -> OptionType: - return OptionType.TRAINER_COMMAND + def __call__(self, job_spec: dict, trainer=None) -> None: + """Apply trainer command override to the job specification. + + Args: + job_spec: The job specification to modify + trainer: Optional trainer context for validation + + Raises: + ValueError: If there's a conflict with the trainer configuration + """ + # Validate conflicts with trainer + if trainer and hasattr(trainer, "func") and trainer.func is not None: + raise ValueError( + "Cannot specify WithTrainerCommand when CustomTrainer.func is provided. " + "The func generates its own command. Use container-only training " + "(CustomTrainer without func) or remove WithTrainerCommand." + ) - def __call__(self, job_spec: dict) -> None: - """Apply trainer command override to the job specification.""" spec = job_spec.setdefault("spec", {}) trainer_spec = spec.setdefault("trainer", {}) trainer_spec["command"] = self.command @dataclass -class WithTrainerArgs: +class WithTrainerArgs(KubernetesCompatible): """Override the trainer container arguments (.spec.trainer.args).""" args: list[str] - @property - def option_type(self) -> OptionType: - return OptionType.TRAINER_ARGS + def __call__(self, job_spec: dict, trainer=None) -> None: + """Apply trainer args override to the job specification. + + Args: + job_spec: The job specification to modify + trainer: Optional trainer context for validation + + Raises: + ValueError: If there's a conflict with the trainer configuration + """ + # Validate conflicts with trainer + if trainer and hasattr(trainer, "func") and trainer.func is not None: + raise ValueError( + "Cannot specify WithTrainerArgs when CustomTrainer.func is provided. " + "The func generates its own arguments. Use container-only training " + "(CustomTrainer without func) or remove WithTrainerArgs." + ) - def __call__(self, job_spec: dict) -> None: - """Apply trainer args override to the job specification.""" spec = job_spec.setdefault("spec", {}) trainer_spec = spec.setdefault("trainer", {}) trainer_spec["args"] = self.args diff --git a/kubeflow/trainer/backends/kubernetes/options_test.py b/kubeflow/trainer/backends/kubernetes/options_test.py index be96f9635..8eb2fc826 100644 --- a/kubeflow/trainer/backends/kubernetes/options_test.py +++ b/kubeflow/trainer/backends/kubernetes/options_test.py @@ -14,7 +14,7 @@ """Unit tests for Kubernetes backend options.""" -from kubeflow.trainer.backends.base import KUBERNETES_CAPABILITIES, OptionType +from kubeflow.trainer.backends.base import KubernetesCompatible from kubeflow.trainer.backends.kubernetes.options import ( PodSpecOverride, WithAnnotations, @@ -31,35 +31,34 @@ class TestKubernetesOptionTypes: """Test Kubernetes option types.""" def test_with_labels_option_type(self): - """Test WithLabels has correct option type.""" + """Test WithLabels inherits from KubernetesCompatible.""" option = WithLabels({"app": "test", "version": "v1"}) - assert option.option_type == OptionType.LABELS + assert isinstance(option, KubernetesCompatible) def test_with_annotations_option_type(self): - """Test WithAnnotations has correct option type.""" + """Test WithAnnotations inherits from KubernetesCompatible.""" option = WithAnnotations({"description": "test job"}) - assert option.option_type == OptionType.ANNOTATIONS + assert isinstance(option, KubernetesCompatible) def test_with_name_option_type(self): - """Test WithName has correct option type.""" + """Test WithName inherits from KubernetesCompatible.""" option = WithName("test-job") - assert option.option_type == OptionType.NAME + assert isinstance(option, KubernetesCompatible) def test_with_pod_spec_overrides_option_type(self): - """Test WithPodSpecOverrides has correct option type.""" + """Test WithPodSpecOverrides inherits from KubernetesCompatible.""" overrides = [PodSpecOverride(target_jobs=["node"])] option = WithPodSpecOverrides(overrides) - assert option.option_type == OptionType.POD_SPEC_OVERRIDES + assert isinstance(option, KubernetesCompatible) def test_trainer_options_types(self): - """Test trainer options have correct option types.""" + """Test trainer options inherit from KubernetesCompatible.""" image_option = WithTrainerImage("custom:latest") command_option = WithTrainerCommand(["python", "train.py"]) args_option = WithTrainerArgs(["--epochs", "10"]) - - assert image_option.option_type == OptionType.TRAINER_IMAGE - assert command_option.option_type == OptionType.TRAINER_COMMAND - assert args_option.option_type == OptionType.TRAINER_ARGS + assert isinstance(image_option, KubernetesCompatible) + assert isinstance(command_option, KubernetesCompatible) + assert isinstance(args_option, KubernetesCompatible) class TestKubernetesOptionApplication: @@ -125,6 +124,51 @@ def test_trainer_args_application(self): expected = {"spec": {"trainer": {"args": ["--epochs", "10", "--lr", "0.001"]}}} assert job_spec == expected + def test_trainer_command_validation_with_func(self): + """Test WithTrainerCommand validates conflicts with CustomTrainer.func.""" + from kubeflow.trainer.types.types import CustomTrainer + + option = WithTrainerCommand(["python", "custom_train.py"]) + trainer_with_func = CustomTrainer(func=lambda: print("training")) + + job_spec = {} + + # Should raise ValueError when trainer has func + try: + option(job_spec, trainer_with_func) + raise AssertionError("Expected ValueError for func conflict") + except ValueError as e: + assert "Cannot specify WithTrainerCommand when CustomTrainer.func is provided" in str(e) + + def test_trainer_command_validation_without_func(self): + """Test WithTrainerCommand works with container-only training.""" + from kubeflow.trainer.types.types import CustomTrainer + + option = WithTrainerCommand(["python", "custom_train.py"]) + trainer_without_func = CustomTrainer(func=None) # Container-only training + + job_spec = {} + option(job_spec, trainer_without_func) # Should not raise + + expected = {"spec": {"trainer": {"command": ["python", "custom_train.py"]}}} + assert job_spec == expected + + def test_trainer_args_validation_with_func(self): + """Test WithTrainerArgs validates conflicts with CustomTrainer.func.""" + from kubeflow.trainer.types.types import CustomTrainer + + option = WithTrainerArgs(["--epochs", "10"]) + trainer_with_func = CustomTrainer(func=lambda: print("training")) + + job_spec = {} + + # Should raise ValueError when trainer has func + try: + option(job_spec, trainer_with_func) + raise AssertionError("Expected ValueError for func conflict") + except ValueError as e: + assert "Cannot specify WithTrainerArgs when CustomTrainer.func is provided" in str(e) + def test_multiple_options_override_behavior(self): """Test multiple options with override semantics.""" job_spec = {} @@ -149,23 +193,26 @@ class TestKubernetesCapabilities: """Test Kubernetes backend capabilities.""" def test_kubernetes_capabilities_support(self): - """Test Kubernetes backend capabilities.""" - caps = KUBERNETES_CAPABILITIES - - # Test supported options - assert caps.supports(OptionType.LABELS) - assert caps.supports(OptionType.ANNOTATIONS) - assert caps.supports(OptionType.NAME) - assert caps.supports(OptionType.POD_SPEC_OVERRIDES) - assert caps.supports(OptionType.TRAINER_IMAGE) - assert caps.supports(OptionType.TRAINER_COMMAND) - assert caps.supports(OptionType.TRAINER_ARGS) + """Test Kubernetes backend mixin compatibility.""" + # Test that all Kubernetes options inherit from KubernetesCompatible + labels_option = WithLabels({"app": "test"}) + annotations_option = WithAnnotations({"desc": "test"}) + name_option = WithName("test-job") + pod_spec_option = WithPodSpecOverrides([PodSpecOverride(target_jobs=["node"])]) + image_option = WithTrainerImage("custom:latest") + command_option = WithTrainerCommand(["python", "train.py"]) + args_option = WithTrainerArgs(["--epochs", "10"]) + assert isinstance(labels_option, KubernetesCompatible) + assert isinstance(annotations_option, KubernetesCompatible) + assert isinstance(name_option, KubernetesCompatible) + assert isinstance(pod_spec_option, KubernetesCompatible) + assert isinstance(image_option, KubernetesCompatible) + assert isinstance(command_option, KubernetesCompatible) + assert isinstance(args_option, KubernetesCompatible) def test_kubernetes_option_compatibility(self): - """Test Kubernetes option compatibility checking.""" - caps = KUBERNETES_CAPABILITIES - - # Test all Kubernetes options are supported + """Test Kubernetes option compatibility checking with mixin approach.""" + # Test that we can identify Kubernetes-compatible options labels_opt = WithLabels({"app": "test"}) annotations_opt = WithAnnotations({"desc": "test"}) name_opt = WithName("test-job") @@ -173,14 +220,14 @@ def test_kubernetes_option_compatibility(self): image_opt = WithTrainerImage("custom:latest") command_opt = WithTrainerCommand(["python", "train.py"]) args_opt = WithTrainerArgs(["--epochs", "10"]) - - assert caps.supports_option(labels_opt) - assert caps.supports_option(annotations_opt) - assert caps.supports_option(name_opt) - assert caps.supports_option(pod_opt) - assert caps.supports_option(image_opt) - assert caps.supports_option(command_opt) - assert caps.supports_option(args_opt) + # All options should be KubernetesCompatible + assert isinstance(labels_opt, KubernetesCompatible) + assert isinstance(annotations_opt, KubernetesCompatible) + assert isinstance(name_opt, KubernetesCompatible) + assert isinstance(pod_opt, KubernetesCompatible) + assert isinstance(image_opt, KubernetesCompatible) + assert isinstance(command_opt, KubernetesCompatible) + assert isinstance(args_opt, KubernetesCompatible) class TestPodSpecOverride: diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py index aac33159a..186f18c56 100644 --- a/kubeflow/trainer/backends/localprocess/backend.py +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -20,7 +20,11 @@ from typing import Optional, Union import uuid -from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES, ExecutionBackend +from kubeflow.trainer.backends.base import ( + CompatibleOption, + ExecutionBackend, + LocalProcessCompatible, +) from kubeflow.trainer.backends.localprocess import utils as local_utils from kubeflow.trainer.backends.localprocess.constants import local_runtimes from kubeflow.trainer.backends.localprocess.job import LocalJob @@ -45,9 +49,9 @@ def __init__( self.cfg = cfg @property - def capabilities(self): - """Return the capabilities of this backend.""" - return LOCAL_PROCESS_CAPABILITIES + def compatibility_mixin(self) -> type: + """Return the compatibility mixin class for this backend.""" + return LocalProcessCompatible def list_runtimes(self) -> list[types.Runtime]: return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes] @@ -78,7 +82,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list] = None, + options: Optional[list[CompatibleOption]] = None, ) -> str: self.validate_options(options) diff --git a/kubeflow/trainer/backends/localprocess/backend_test.py b/kubeflow/trainer/backends/localprocess/backend_test.py index 89f6bacbe..e3a6d587a 100644 --- a/kubeflow/trainer/backends/localprocess/backend_test.py +++ b/kubeflow/trainer/backends/localprocess/backend_test.py @@ -82,21 +82,22 @@ def test_local_backend_validation(local_backend, test_case): if test_case.name == "reject_incompatible_labels": error_msg = str(e) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels (labels)" in error_msg + assert "WithLabels" in error_msg elif test_case.name == "reject_multiple_incompatible_options": error_msg = str(e) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels (labels)" in error_msg - assert "WithAnnotations (annotations)" in error_msg - assert "WithPodSpecOverrides (pod_spec_overrides)" in error_msg + assert "WithLabels" in error_msg + assert "WithAnnotations" in error_msg + assert "WithPodSpecOverrides" in error_msg print("test execution complete") def test_local_backend_capabilities(local_backend): - """Test LocalProcessBackend capabilities property.""" - from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES + """Test LocalProcessBackend compatibility mixin.""" + from kubeflow.trainer.backends.base import LocalProcessCompatible - assert local_backend.capabilities == LOCAL_PROCESS_CAPABILITIES + mixin = local_backend.compatibility_mixin + assert mixin is LocalProcessCompatible class TestLocalBackendValidationFlow: @@ -122,5 +123,5 @@ def test_validation_early_exit_behavior(self, local_backend): local_backend.validate_options(mixed_options) error_msg = str(exc_info.value) - assert "WithLabels (labels)" in error_msg - assert "WithAnnotations (annotations)" in error_msg + assert "WithLabels" in error_msg + assert "WithAnnotations" in error_msg diff --git a/kubeflow/trainer/backends/localprocess/options.py b/kubeflow/trainer/backends/localprocess/options.py index 2a8da3200..9d022053c 100644 --- a/kubeflow/trainer/backends/localprocess/options.py +++ b/kubeflow/trainer/backends/localprocess/options.py @@ -16,34 +16,26 @@ from dataclasses import dataclass -from kubeflow.trainer.backends.base import OptionType +from kubeflow.trainer.backends.base import LocalProcessCompatible @dataclass -class WithProcessTimeout: +class WithProcessTimeout(LocalProcessCompatible): """Set a timeout for the local training process.""" timeout_seconds: int - @property - def option_type(self) -> OptionType: - return OptionType.NAME # Placeholder - would need new enum value - def __call__(self, config: dict) -> None: """Apply timeout to local process configuration.""" config["timeout_seconds"] = self.timeout_seconds @dataclass -class WithWorkingDirectory: +class WithWorkingDirectory(LocalProcessCompatible): """Set the working directory for the local training process.""" working_dir: str - @property - def option_type(self) -> OptionType: - return OptionType.NAME # Placeholder - would need new enum value - def __call__(self, config: dict) -> None: """Apply working directory to local process configuration.""" config["working_dir"] = self.working_dir diff --git a/kubeflow/trainer/backends/localprocess/options_test.py b/kubeflow/trainer/backends/localprocess/options_test.py index e7d725517..e4f2b55f1 100644 --- a/kubeflow/trainer/backends/localprocess/options_test.py +++ b/kubeflow/trainer/backends/localprocess/options_test.py @@ -14,7 +14,7 @@ """Unit tests for LocalProcess backend options.""" -from kubeflow.trainer.backends.base import LOCAL_PROCESS_CAPABILITIES, OptionType +from kubeflow.trainer.backends.base import LocalProcessCompatible from kubeflow.trainer.backends.localprocess.options import WithProcessTimeout, WithWorkingDirectory @@ -22,14 +22,14 @@ class TestLocalProcessOptionTypes: """Test LocalProcess option types.""" def test_with_process_timeout_type(self): - """Test WithProcessTimeout has correct option type.""" + """Test WithProcessTimeout inherits from LocalProcessCompatible.""" option = WithProcessTimeout(timeout_seconds=300) - assert option.option_type == OptionType.NAME # Placeholder + assert isinstance(option, LocalProcessCompatible) def test_with_working_directory_type(self): - """Test WithWorkingDirectory has correct option type.""" + """Test WithWorkingDirectory inherits from LocalProcessCompatible.""" option = WithWorkingDirectory(working_dir="/tmp/training") - assert option.option_type == OptionType.NAME # Placeholder + assert isinstance(option, LocalProcessCompatible) class TestLocalProcessOptionApplication: @@ -70,24 +70,23 @@ class TestLocalProcessCapabilities: """Test LocalProcess backend capabilities.""" def test_local_process_capabilities_minimal(self): - """Test LocalProcess backend has minimal capabilities.""" - caps = LOCAL_PROCESS_CAPABILITIES + """Test LocalProcess backend mixin compatibility.""" + # Test that LocalProcess options inherit from LocalProcessCompatible + timeout_option = WithProcessTimeout(timeout_seconds=300) + working_dir_option = WithWorkingDirectory(working_dir="/tmp/training") - # LocalProcess backend currently supports no standard options - assert not caps.supports(OptionType.LABELS) - assert not caps.supports(OptionType.ANNOTATIONS) - assert not caps.supports(OptionType.POD_SPEC_OVERRIDES) - assert not caps.supports(OptionType.TRAINER_IMAGE) - assert not caps.supports(OptionType.TRAINER_COMMAND) - assert not caps.supports(OptionType.TRAINER_ARGS) + assert isinstance(timeout_option, LocalProcessCompatible) + assert isinstance(working_dir_option, LocalProcessCompatible) def test_local_process_empty_options_compatibility(self): """Test LocalProcess compatibility with empty options.""" - caps = LOCAL_PROCESS_CAPABILITIES + # Empty options list should always be compatible + timeout_option = WithProcessTimeout(timeout_seconds=300) + working_dir_option = WithWorkingDirectory(working_dir="/tmp/training") - is_compatible, unsupported = caps.check_compatibility([]) - assert is_compatible - assert unsupported == [] + # Both should be LocalProcessCompatible + assert isinstance(timeout_option, LocalProcessCompatible) + assert isinstance(working_dir_option, LocalProcessCompatible) class TestLocalProcessOptionCreation: From b31bd37f555fc97b6cab2887b1cc3ebfabdc70af Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Sat, 11 Oct 2025 15:12:51 +0530 Subject: [PATCH 3/7] Replace compatibility_mixin property with abstract validate_options and Option protocol Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/api/trainer_client.py | 4 +- kubeflow/trainer/backends/__init__.py | 17 ++++++++ kubeflow/trainer/backends/base.py | 43 +++++++++++++------ .../trainer/backends/kubernetes/backend.py | 11 +++-- .../backends/kubernetes/backend_test.py | 20 ++++++--- .../trainer/backends/localprocess/backend.py | 11 +++-- .../backends/localprocess/backend_test.py | 20 ++++++--- 7 files changed, 89 insertions(+), 37 deletions(-) diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index 1845dfe2e..d8cb8ba1b 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -16,7 +16,7 @@ import logging from typing import Optional, Union -from kubeflow.trainer.backends.base import CompatibleOption +from kubeflow.trainer.backends.base import Option from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.backend import ( @@ -97,7 +97,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list[CompatibleOption]] = None, + options: Optional[list[Option]] = None, ) -> str: """Create a TrainJob. You can configure the TrainJob using one of these trainers: diff --git a/kubeflow/trainer/backends/__init__.py b/kubeflow/trainer/backends/__init__.py index e69de29bb..241cd02bb 100644 --- a/kubeflow/trainer/backends/__init__.py +++ b/kubeflow/trainer/backends/__init__.py @@ -0,0 +1,17 @@ +from kubeflow.trainer.backends.base import ( + ExecutionBackend, + KubernetesCompatible, + LocalProcessCompatible, + UniversalCompatible, + Option, + CompatibleOption, +) + +__all__ = [ + "ExecutionBackend", + "KubernetesCompatible", + "LocalProcessCompatible", + "UniversalCompatible", + "Option", + "CompatibleOption", +] diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index de7c2a3eb..ea91e1c70 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -14,7 +14,7 @@ import abc from collections.abc import Iterator -from typing import Optional, Union +from typing import Optional, Union, Protocol from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -39,26 +39,38 @@ class UniversalCompatible(KubernetesCompatible, LocalProcessCompatible): pass -# Type alias for all compatible options +class Option(Protocol): + """Protocol defining the contract for training options.""" + + def __call__( + self, + job_spec: dict, + trainer: Optional[Union["types.BuiltinTrainer", "types.CustomTrainer"]] = None + ) -> None: + """Apply the option to the job specification. + + Args: + job_spec: The job specification dictionary to modify + trainer: Optional trainer context for validation + """ + ... + + CompatibleOption = Union[KubernetesCompatible, LocalProcessCompatible] class ExecutionBackend(abc.ABC): - @property - @abc.abstractmethod - def compatibility_mixin(self) -> type: - """Return the compatibility mixin class for this backend.""" - pass - - def validate_options(self, options: Optional[list[CompatibleOption]] = None) -> None: - """Validate that all options are compatible with this backend.""" - if not options: + def _validate_options_with_mixin( + self, options: Optional[list[Option]] = None, mixin_class: Optional[type] = None + ) -> None: + """Helper method to validate options with a specific mixin class.""" + if not options or not mixin_class: return incompatible = [ f"{opt.__class__.__name__}" for opt in options - if not isinstance(opt, self.compatibility_mixin) + if not isinstance(opt, mixin_class) ] if incompatible: @@ -66,6 +78,11 @@ def validate_options(self, options: Optional[list[CompatibleOption]] = None) -> f"The following options are not compatible with this backend: {incompatible}" ) + @abc.abstractmethod + def validate_options(self, options: Optional[list[Option]] = None) -> None: + """Validate that all options are compatible with this backend.""" + pass + @abc.abstractmethod def list_runtimes(self) -> list[types.Runtime]: raise NotImplementedError() @@ -84,7 +101,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list[CompatibleOption]] = None, + options: Optional[list[Option]] = None, ) -> str: raise NotImplementedError() diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index c256d363a..b5d7a1e57 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -26,7 +26,7 @@ from kubeflow_trainer_api import models from kubernetes import client, config, watch -from kubeflow.trainer.backends.base import CompatibleOption, ExecutionBackend, KubernetesCompatible +from kubeflow.trainer.backends.base import ExecutionBackend, KubernetesCompatible, Option from kubeflow.trainer.backends.kubernetes import types as k8s_types from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -57,10 +57,9 @@ def __init__( self.namespace = cfg.namespace - @property - def compatibility_mixin(self) -> type: - """Return the compatibility mixin class for this backend.""" - return KubernetesCompatible + def validate_options(self, options: Optional[list[Option]] = None) -> None: + """Validate that all options are compatible with this backend.""" + super()._validate_options_with_mixin(options, KubernetesCompatible) def list_runtimes(self) -> list[types.Runtime]: result = [] @@ -186,7 +185,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list[CompatibleOption]] = None, + options: Optional[list[Option]] = None, ) -> str: self.validate_options(options) if runtime is None: diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 4a97bc290..fee0616c5 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -857,11 +857,21 @@ def test_backend_validation(kubernetes_backend, test_case): def test_backend_capabilities(kubernetes_backend): - """Test KubernetesBackend compatibility mixin.""" - from kubeflow.trainer.backends.base import KubernetesCompatible - - mixin = kubernetes_backend.compatibility_mixin - assert mixin is KubernetesCompatible + """Test KubernetesBackend validation with correct mixin.""" + from kubeflow.trainer.backends.kubernetes.options import WithLabels + from kubeflow.trainer.backends.localprocess.options import WithWorkingDirectory + + # Test that KubernetesCompatible options pass validation + kubernetes_options = [WithLabels({"app": "test"})] + kubernetes_backend.validate_options(kubernetes_options) # Should not raise + + # Test that LocalProcessCompatible options fail validation + local_options = [WithWorkingDirectory("/tmp")] + try: + kubernetes_backend.validate_options(local_options) + assert False, "Expected ValueError for incompatible options" + except ValueError as e: + assert "WithWorkingDirectory" in str(e) # Test cases for new trainer container options diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py index 186f18c56..e65920565 100644 --- a/kubeflow/trainer/backends/localprocess/backend.py +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -21,9 +21,9 @@ import uuid from kubeflow.trainer.backends.base import ( - CompatibleOption, ExecutionBackend, LocalProcessCompatible, + Option, ) from kubeflow.trainer.backends.localprocess import utils as local_utils from kubeflow.trainer.backends.localprocess.constants import local_runtimes @@ -48,10 +48,9 @@ def __init__( self.__local_jobs: list[LocalBackendJobs] = [] self.cfg = cfg - @property - def compatibility_mixin(self) -> type: - """Return the compatibility mixin class for this backend.""" - return LocalProcessCompatible + def validate_options(self, options: Optional[list[Option]] = None) -> None: + """Validate that all options are compatible with this backend.""" + super()._validate_options_with_mixin(options, LocalProcessCompatible) def list_runtimes(self) -> list[types.Runtime]: return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes] @@ -82,7 +81,7 @@ def train( runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, - options: Optional[list[CompatibleOption]] = None, + options: Optional[list[Option]] = None, ) -> str: self.validate_options(options) diff --git a/kubeflow/trainer/backends/localprocess/backend_test.py b/kubeflow/trainer/backends/localprocess/backend_test.py index e3a6d587a..1e4ce3e72 100644 --- a/kubeflow/trainer/backends/localprocess/backend_test.py +++ b/kubeflow/trainer/backends/localprocess/backend_test.py @@ -93,11 +93,21 @@ def test_local_backend_validation(local_backend, test_case): def test_local_backend_capabilities(local_backend): - """Test LocalProcessBackend compatibility mixin.""" - from kubeflow.trainer.backends.base import LocalProcessCompatible - - mixin = local_backend.compatibility_mixin - assert mixin is LocalProcessCompatible + """Test LocalProcessBackend validation with correct mixin.""" + from kubeflow.trainer.backends.localprocess.options import WithWorkingDirectory + from kubeflow.trainer.backends.kubernetes.options import WithLabels + + # Test that LocalProcessCompatible options pass validation + local_options = [WithWorkingDirectory("/tmp")] + local_backend.validate_options(local_options) # Should not raise + + # Test that KubernetesCompatible options fail validation + kubernetes_options = [WithLabels({"app": "test"})] + try: + local_backend.validate_options(kubernetes_options) + assert False, "Expected ValueError for incompatible options" + except ValueError as e: + assert "WithLabels" in str(e) class TestLocalBackendValidationFlow: From 15a7644e0fb1d4598dc4795f5f30dbfa949f8555 Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Thu, 23 Oct 2025 15:33:53 +0530 Subject: [PATCH 4/7] fix: Forward podSpecOverrides to TrainJob API in KubernetesBackend.train() Signed-off-by: Abhijeet Dhumal --- .../trainer/backends/kubernetes/backend.py | 5 ++- .../backends/kubernetes/backend_test.py | 34 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index b5d7a1e57..08e58f469 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -197,6 +197,7 @@ def train( annotations = None name = None trainer_overrides = {} + pod_spec_overrides = None if options: for option in options: @@ -207,11 +208,12 @@ def train( annotations = metadata_section.get("annotations") name = metadata_section.get("name") - # Extract trainer-specific overrides + # Extract trainer-specific overrides and pod spec overrides spec_section = job_spec.get("spec", {}) trainer_spec = spec_section.get("trainer", {}) if trainer_spec: trainer_overrides = trainer_spec + pod_spec_overrides = spec_section.get("podSpecOverrides") # Generate unique name for the TrainJob if not provided train_job_name = name or (random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]) @@ -265,6 +267,7 @@ def train( if isinstance(initializer, types.Initializer) else None ), + pod_spec_overrides=pod_spec_overrides, ), ) diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index fee0616c5..7f0a63873 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -33,8 +33,10 @@ from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.options import ( + PodSpecOverride, WithAnnotations, WithLabels, + WithPodSpecOverrides, WithTrainerArgs, WithTrainerCommand, WithTrainerImage, @@ -283,6 +285,7 @@ def get_train_job( train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, labels: Optional[dict[str, str]] = None, annotations: Optional[dict[str, str]] = None, + pod_spec_overrides: Optional[list] = None, ) -> models.TrainerV1alpha1TrainJob: """ Create a mock TrainJob object with optional trainer configurations. @@ -296,6 +299,7 @@ def get_train_job( spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, + pod_spec_overrides=pod_spec_overrides, ), ) @@ -1367,6 +1371,36 @@ def test_train_validation(kubernetes_backend, test_case): annotations={"created-by": "sdk"}, ), ), + TestCase( + name="train with pod spec overrides", + expected_status=SUCCESS, + config={ + "options": [ + WithPodSpecOverrides([ + PodSpecOverride( + target_jobs=["node"], + volumes=[{"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}], + containers=[{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], + node_selector={"gpu": "true"}, + tolerations=[{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + ) + ]) + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + pod_spec_overrides=[ + { + "targetJobs": [{"name": "node"}], + "volumes": [{"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}], + "containers": [{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], + "nodeSelector": {"gpu": "true"}, + "tolerations": [{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + } + ], + ), + ), ] From df59e2c25972da995b13c84072b0b8305dc0fafa Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Thu, 23 Oct 2025 16:58:43 +0530 Subject: [PATCH 5/7] fix: Replace PodSpecOverride with PodTemplateOverride API and add container overrides Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/__init__.py | 12 +- .../trainer/backends/kubernetes/backend.py | 13 +- .../backends/kubernetes/backend_test.py | 178 +++++++------ .../trainer/backends/kubernetes/options.py | 237 +++++++++++++++--- .../backends/kubernetes/options_test.py | 137 +++++++--- kubeflow/trainer/constants/constants.py | 5 + 6 files changed, 435 insertions(+), 147 deletions(-) diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py index 017f047b3..b65538013 100644 --- a/kubeflow/trainer/__init__.py +++ b/kubeflow/trainer/__init__.py @@ -18,11 +18,13 @@ # Import common training options (defaults to Kubernetes backend) from kubeflow.trainer.backends.kubernetes.options import ( - PodSpecOverride, + ContainerOverride, + PodTemplateOverride, + PodTemplateSpecOverride, WithAnnotations, WithLabels, WithName, - WithPodSpecOverrides, + WithPodTemplateOverrides, WithTrainerArgs, WithTrainerCommand, WithTrainerImage, @@ -56,6 +58,7 @@ __all__ = [ "BuiltinTrainer", + "ContainerOverride", "CustomTrainer", "DataCacheInitializer", "DataFormat", @@ -67,7 +70,8 @@ "LoraConfig", "Loss", "MODEL_PATH", - "PodSpecOverride", + "PodTemplateOverride", + "PodTemplateSpecOverride", "Runtime", "TorchTuneConfig", "TorchTuneInstructDataset", @@ -79,7 +83,7 @@ "WithAnnotations", "WithLabels", "WithName", - "WithPodSpecOverrides", + "WithPodTemplateOverrides", "WithTrainerArgs", "WithTrainerCommand", "WithTrainerImage", diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 08e58f469..5e0941ac3 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -197,7 +197,7 @@ def train( annotations = None name = None trainer_overrides = {} - pod_spec_overrides = None + pod_template_overrides = None if options: for option in options: @@ -208,15 +208,18 @@ def train( annotations = metadata_section.get("annotations") name = metadata_section.get("name") - # Extract trainer-specific overrides and pod spec overrides + # Extract trainer-specific overrides and pod template overrides spec_section = job_spec.get("spec", {}) trainer_spec = spec_section.get("trainer", {}) if trainer_spec: trainer_overrides = trainer_spec - pod_spec_overrides = spec_section.get("podSpecOverrides") + pod_template_overrides = spec_section.get("podTemplateOverrides") # Generate unique name for the TrainJob if not provided - train_job_name = name or (random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]) + train_job_name = name or ( + random.choice(string.ascii_lowercase) + + uuid.uuid4().hex[: constants.TRAINJOB_NAME_UUID_LENGTH] + ) # Build the Trainer. trainer_crd = models.TrainerV1alpha1Trainer() @@ -267,7 +270,7 @@ def train( if isinstance(initializer, types.Initializer) else None ), - pod_spec_overrides=pod_spec_overrides, + pod_template_overrides=pod_template_overrides, ), ) diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 7f0a63873..1d44261ee 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -33,10 +33,12 @@ from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.options import ( - PodSpecOverride, + ContainerOverride, + PodTemplateOverride, + PodTemplateSpecOverride, WithAnnotations, WithLabels, - WithPodSpecOverrides, + WithPodTemplateOverrides, WithTrainerArgs, WithTrainerCommand, WithTrainerImage, @@ -285,11 +287,9 @@ def get_train_job( train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, labels: Optional[dict[str, str]] = None, annotations: Optional[dict[str, str]] = None, - pod_spec_overrides: Optional[list] = None, + pod_template_overrides: Optional[list] = None, ) -> models.TrainerV1alpha1TrainJob: - """ - Create a mock TrainJob object with optional trainer configurations. - """ + """Create a mock TrainJob object with optional trainer configurations.""" train_job = models.TrainerV1alpha1TrainJob( apiVersion=constants.API_VERSION, kind=constants.TRAINJOB_KIND, @@ -299,7 +299,7 @@ def get_train_job( spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, - pod_spec_overrides=pod_spec_overrides, + pod_template_overrides=pod_template_overrides, ), ) @@ -850,30 +850,28 @@ def get_custom_trainer_with_overrides( @pytest.mark.parametrize("test_case", BACKEND_VALIDATION_TEST_CASES) def test_backend_validation(kubernetes_backend, test_case): """Test KubernetesBackend option validation.""" - print("Executing test:", test_case.name) try: options = test_case.config.get("options") kubernetes_backend.validate_options(options) assert test_case.expected_status == SUCCESS except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") def test_backend_capabilities(kubernetes_backend): """Test KubernetesBackend validation with correct mixin.""" from kubeflow.trainer.backends.kubernetes.options import WithLabels from kubeflow.trainer.backends.localprocess.options import WithWorkingDirectory - + # Test that KubernetesCompatible options pass validation kubernetes_options = [WithLabels({"app": "test"})] kubernetes_backend.validate_options(kubernetes_options) # Should not raise - + # Test that LocalProcessCompatible options fail validation local_options = [WithWorkingDirectory("/tmp")] try: kubernetes_backend.validate_options(local_options) - assert False, "Expected ValueError for incompatible options" + raise AssertionError("Expected ValueError for incompatible options") except ValueError as e: assert "WithWorkingDirectory" in str(e) @@ -981,32 +979,9 @@ def test_backend_capabilities(kubernetes_backend): ] -def get_custom_trainer_with_overrides( - image: Optional[str] = None, - command: Optional[list[str]] = None, - args: Optional[list[str]] = None, - resources_per_node: Optional[dict] = None, - **kwargs, -) -> models.TrainerV1alpha1Trainer: - """Helper to create trainer with container overrides.""" - trainer = get_custom_trainer(**kwargs) - - if image: - trainer.image = image - if command: - trainer.command = command - if args: - trainer.args = args - if resources_per_node: - trainer.resources_per_node = utils.get_resources_per_node(resources_per_node) - - return trainer - - @pytest.mark.parametrize("test_case", TRAINER_OPTIONS_TEST_CASES) def test_train_with_trainer_options(kubernetes_backend, test_case): """Test KubernetesBackend.train with new trainer container options.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME)) @@ -1035,7 +1010,6 @@ def test_train_with_trainer_options(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") # -------------------------- @@ -1046,7 +1020,6 @@ def test_train_with_trainer_options(kubernetes_backend, test_case): @pytest.mark.parametrize("test_case", GET_RUNTIME_TEST_CASES) def test_get_runtime(kubernetes_backend, test_case): """Test KubernetesBackend.get_runtime with basic success path.""" - print("Executing test:", test_case.name) try: runtime = kubernetes_backend.get_runtime(**test_case.config) @@ -1056,7 +1029,6 @@ def test_get_runtime(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1075,7 +1047,6 @@ def test_get_runtime(kubernetes_backend, test_case): ) def test_list_runtimes(kubernetes_backend, test_case): """Test KubernetesBackend.list_runtimes with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) runtimes = kubernetes_backend.list_runtimes() @@ -1087,7 +1058,6 @@ def test_list_runtimes(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1119,14 +1089,12 @@ def test_list_runtimes(kubernetes_backend, test_case): ) def test_get_runtime_packages(kubernetes_backend, test_case): """Test KubernetesBackend.get_runtime_packages with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.get_runtime_packages(**test_case.config) except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1312,7 +1280,6 @@ def test_get_runtime_packages(kubernetes_backend, test_case): ) def test_train_validation(kubernetes_backend, test_case): """Test KubernetesBackend.train validation with various scenarios.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) @@ -1349,9 +1316,7 @@ def test_train_validation(kubernetes_backend, test_case): options=test_case.config.get("options", []), ) except Exception as e: - print(f"Test failed with error: {e}") raise - print("test execution complete") TRAIN_TEST_CASES = [ @@ -1371,18 +1336,32 @@ def test_train_validation(kubernetes_backend, test_case): annotations={"created-by": "sdk"}, ), ), +] + +TRAIN_TEST_CASES.extend([ TestCase( - name="train with pod spec overrides", + name="train with pod template overrides - basic", expected_status=SUCCESS, config={ "options": [ - WithPodSpecOverrides([ - PodSpecOverride( + WithPodTemplateOverrides([ + PodTemplateOverride( target_jobs=["node"], - volumes=[{"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}], - containers=[{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], - node_selector={"gpu": "true"}, - tolerations=[{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + spec=PodTemplateSpecOverride( + volumes=[ + {"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}} + ], + containers=[ + ContainerOverride( + name="node", + volume_mounts=[{"name": "data", "mountPath": "/data"}], + ) + ], + node_selector={"gpu": "true"}, + tolerations=[ + {"key": "gpu", "operator": "Exists", "effect": "NoSchedule"} + ], + ) ) ]) ], @@ -1390,24 +1369,92 @@ def test_train_validation(kubernetes_backend, test_case): expected_output=get_train_job( runtime_name=TORCH_RUNTIME, train_job_name=BASIC_TRAIN_JOB_NAME, - pod_spec_overrides=[ + pod_template_overrides=[ { "targetJobs": [{"name": "node"}], - "volumes": [{"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}], - "containers": [{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], - "nodeSelector": {"gpu": "true"}, - "tolerations": [{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + "spec": { + "volumes": [ + {"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}} + ], + "containers": [ + { + "name": "node", + "volumeMounts": [{"name": "data", "mountPath": "/data"}], + } + ], + "nodeSelector": {"gpu": "true"}, + "tolerations": [ + {"key": "gpu", "operator": "Exists", "effect": "NoSchedule"} + ], + } } ], ), ), -] + TestCase( + name="train with pod template overrides - with metadata and affinity", + expected_status=SUCCESS, + config={ + "options": [ + WithPodTemplateOverrides([ + PodTemplateOverride( + target_jobs=["node"], + metadata={"labels": {"custom-label": "custom-value"}}, + spec=PodTemplateSpecOverride( + service_account_name="training-sa", + affinity={ + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [{ + "matchExpressions": [{ + "key": "gpu-type", + "operator": "In", + "values": ["nvidia-a100"] + }] + }] + } + } + }, + scheduling_gates=[{"name": "wait-for-resources"}], + ) + ) + ]) + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + pod_template_overrides=[ + { + "targetJobs": [{"name": "node"}], + "metadata": {"labels": {"custom-label": "custom-value"}}, + "spec": { + "serviceAccountName": "training-sa", + "affinity": { + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [{ + "matchExpressions": [{ + "key": "gpu-type", + "operator": "In", + "values": ["nvidia-a100"] + }] + }] + } + } + }, + "schedulingGates": [{"name": "wait-for-resources"}], + } + } + ], + ), + ), +]) @pytest.mark.parametrize("test_case", TRAIN_TEST_CASES) def test_train(kubernetes_backend, test_case): """Test KubernetesBackend.train with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) runtime = kubernetes_backend.get_runtime(test_case.config.get("runtime", TORCH_RUNTIME)) @@ -1446,7 +1493,6 @@ def test_train(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1477,7 +1523,6 @@ def test_train(kubernetes_backend, test_case): ) def test_get_job(kubernetes_backend, test_case): """Test KubernetesBackend.get_job with basic success path.""" - print("Executing test:", test_case.name) try: job = kubernetes_backend.get_job(**test_case.config) @@ -1486,7 +1531,6 @@ def test_get_job(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1523,7 +1567,6 @@ def test_get_job(kubernetes_backend, test_case): ) def test_list_jobs(kubernetes_backend, test_case): """Test KubernetesBackend.list_jobs with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) jobs = kubernetes_backend.list_jobs() @@ -1535,7 +1578,6 @@ def test_list_jobs(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1557,7 +1599,6 @@ def test_list_jobs(kubernetes_backend, test_case): ) def test_get_job_logs(kubernetes_backend, test_case): """Test KubernetesBackend.get_job_logs with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) logs = kubernetes_backend.get_job_logs(test_case.config.get("name")) @@ -1567,7 +1608,6 @@ def test_get_job_logs(kubernetes_backend, test_case): assert logs_list == test_case.expected_output except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") @pytest.mark.parametrize( @@ -1637,8 +1677,6 @@ def test_get_job_logs(kubernetes_backend, test_case): ) def test_wait_for_job_status(kubernetes_backend, test_case): """Test KubernetesBackend.wait_for_job_status with various scenarios.""" - print("Executing test:", test_case.name) - original_get_job = kubernetes_backend.get_job # TrainJob has unexpected failed status. @@ -1661,8 +1699,6 @@ def mock_get_job(name): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") - @pytest.mark.parametrize( "test_case", @@ -1689,7 +1725,6 @@ def mock_get_job(name): ) def test_delete_job(kubernetes_backend, test_case): """Test KubernetesBackend.delete_job with basic success path.""" - print("Executing test:", test_case.name) try: kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) kubernetes_backend.delete_job(test_case.config.get("name")) @@ -1705,4 +1740,3 @@ def test_delete_job(kubernetes_backend, test_case): except Exception as e: assert type(e) is test_case.expected_error - print("test execution complete") diff --git a/kubeflow/trainer/backends/kubernetes/options.py b/kubeflow/trainer/backends/kubernetes/options.py index 22ef54ed7..96cf7e76b 100644 --- a/kubeflow/trainer/backends/kubernetes/options.py +++ b/kubeflow/trainer/backends/kubernetes/options.py @@ -15,10 +15,13 @@ """Kubernetes-specific training options for the Kubeflow Trainer SDK.""" from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Any, Optional, Union from kubeflow.trainer.backends.base import KubernetesCompatible +if TYPE_CHECKING: + from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer + @dataclass class WithLabels(KubernetesCompatible): @@ -26,7 +29,11 @@ class WithLabels(KubernetesCompatible): labels: dict[str, str] - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply labels to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["labels"] = self.labels @@ -38,7 +45,11 @@ class WithAnnotations(KubernetesCompatible): annotations: dict[str, str] - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply annotations to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["annotations"] = self.annotations @@ -50,53 +61,191 @@ class WithName(KubernetesCompatible): name: str - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply custom name to the job specification.""" metadata = job_spec.setdefault("metadata", {}) metadata["name"] = self.name @dataclass -class PodSpecOverride: - """Configuration for overriding pod specifications for specific job types.""" +class ContainerOverride: + """Configuration for overriding a specific container in a pod. + + Args: + name: Name of the container to override (must exist in TrainingRuntime). + env: Environment variables to add/merge with the container. + Each dict should have 'name' and 'value' or 'valueFrom' keys. + volume_mounts: Volume mounts to add/merge with the container. + Each dict should have 'name' and 'mountPath' keys at minimum. + """ + + name: str + env: Optional[list[dict]] = None + volume_mounts: Optional[list[dict]] = None + + def __post_init__(self): + """Validate the container override configuration.""" + # Validate container name + if not self.name or not self.name.strip(): + raise ValueError("Container name must be a non-empty string") + + if self.env is not None: + if not isinstance(self.env, list): + raise ValueError("env must be a list of dictionaries") + for env_var in self.env: + if not isinstance(env_var, dict): + raise ValueError("Each env entry must be a dictionary") + if "name" not in env_var: + raise ValueError("Each env entry must have a 'name' key") + if not env_var.get("name"): + raise ValueError("env 'name' must be a non-empty string") + if "value" not in env_var and "valueFrom" not in env_var: + raise ValueError( + "Each env entry must have either 'value' or 'valueFrom' key" + ) + # Validate valueFrom structure if present + if "valueFrom" in env_var: + value_from = env_var["valueFrom"] + if not isinstance(value_from, dict): + raise ValueError("env 'valueFrom' must be a dictionary") + # valueFrom must have one of these keys + valid_keys = {"configMapKeyRef", "secretKeyRef", "fieldRef", "resourceFieldRef"} + if not any(key in value_from for key in valid_keys): + raise ValueError( + f"env 'valueFrom' must contain one of: {', '.join(valid_keys)}" + ) + + if self.volume_mounts is not None: + if not isinstance(self.volume_mounts, list): + raise ValueError("volume_mounts must be a list of dictionaries") + for mount in self.volume_mounts: + if not isinstance(mount, dict): + raise ValueError("Each volume_mounts entry must be a dictionary") + if "name" not in mount: + raise ValueError("Each volume_mounts entry must have a 'name' key") + if not mount.get("name"): + raise ValueError("volume_mounts 'name' must be a non-empty string") + if "mountPath" not in mount: + raise ValueError("Each volume_mounts entry must have a 'mountPath' key") + mount_path = mount.get("mountPath") + if not mount_path or not isinstance(mount_path, str): + raise ValueError("volume_mounts 'mountPath' must be a non-empty string") + if not mount_path.startswith("/"): + raise ValueError( + f"volume_mounts 'mountPath' must be an absolute path (start with /): {mount_path}" + ) + + +@dataclass +class PodTemplateSpecOverride: + """Configuration for overriding pod template specifications. + + Args: + service_account_name: Service account to use for the pods. + node_selector: Node selector to place pods on specific nodes. + affinity: Affinity rules for pod scheduling. + tolerations: Tolerations for pod scheduling. + volumes: Volumes to add/merge with the pod. + init_containers: Init containers to add/merge with the pod. + containers: Containers to add/merge with the pod. + scheduling_gates: Scheduling gates for the pods. + image_pull_secrets: Image pull secrets for the pods. + """ - target_jobs: list[str] - volumes: Optional[list[dict]] = None - containers: Optional[list[dict]] = None - init_containers: Optional[list[dict]] = None - node_selector: Optional[dict[str, str]] = None service_account_name: Optional[str] = None + node_selector: Optional[dict[str, str]] = None + affinity: Optional[dict] = None tolerations: Optional[list[dict]] = None + volumes: Optional[list[dict]] = None + init_containers: Optional[list[ContainerOverride]] = None + containers: Optional[list[ContainerOverride]] = None + scheduling_gates: Optional[list[dict]] = None + image_pull_secrets: Optional[list[dict]] = None + + +@dataclass +class PodTemplateOverride: + """Configuration for overriding pod templates for specific job types. + + Args: + target_jobs: List of job names to apply the overrides to (e.g., ["node", "launcher"]). + metadata: Metadata overrides for the pod template (labels, annotations). + spec: Spec overrides for the pod template. + """ + + target_jobs: list[str] + metadata: Optional[dict] = None + spec: Optional[PodTemplateSpecOverride] = None @dataclass -class WithPodSpecOverrides(KubernetesCompatible): - """Add pod specification overrides to the TrainJob (.spec.podSpecOverrides).""" +class WithPodTemplateOverrides(KubernetesCompatible): + """Add pod template overrides to the TrainJob (.spec.podTemplateOverrides).""" - overrides: list[PodSpecOverride] + overrides: list[PodTemplateOverride] - def __call__(self, job_spec: dict, trainer=None) -> None: - """Apply pod spec overrides to the job specification.""" + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: + """Apply pod template overrides to the job specification.""" spec = job_spec.setdefault("spec", {}) - spec["podSpecOverrides"] = [] + spec["podTemplateOverrides"] = [] for override in self.overrides: api_override = {"targetJobs": [{"name": job} for job in override.target_jobs]} - if override.volumes: - api_override["volumes"] = override.volumes - if override.containers: - api_override["containers"] = override.containers - if override.init_containers: - api_override["initContainers"] = override.init_containers - if override.node_selector: - api_override["nodeSelector"] = override.node_selector - if override.service_account_name: - api_override["serviceAccountName"] = override.service_account_name - if override.tolerations: - api_override["tolerations"] = override.tolerations - - spec["podSpecOverrides"].append(api_override) + if override.metadata: + api_override["metadata"] = override.metadata + + if override.spec: + spec_dict = {} + + if override.spec.service_account_name: + spec_dict["serviceAccountName"] = override.spec.service_account_name + if override.spec.node_selector: + spec_dict["nodeSelector"] = override.spec.node_selector + if override.spec.affinity: + spec_dict["affinity"] = override.spec.affinity + if override.spec.tolerations: + spec_dict["tolerations"] = override.spec.tolerations + if override.spec.volumes: + spec_dict["volumes"] = override.spec.volumes + if override.spec.scheduling_gates: + spec_dict["schedulingGates"] = override.spec.scheduling_gates + if override.spec.image_pull_secrets: + spec_dict["imagePullSecrets"] = override.spec.image_pull_secrets + + # Handle container overrides + if override.spec.init_containers: + spec_dict["initContainers"] = [] + for container in override.spec.init_containers: + container_dict = {"name": container.name} + if container.env: + container_dict["env"] = container.env + if container.volume_mounts: + container_dict["volumeMounts"] = container.volume_mounts + spec_dict["initContainers"].append(container_dict) + + if override.spec.containers: + spec_dict["containers"] = [] + for container in override.spec.containers: + container_dict = {"name": container.name} + if container.env: + container_dict["env"] = container.env + if container.volume_mounts: + container_dict["volumeMounts"] = container.volume_mounts + spec_dict["containers"].append(container_dict) + + if spec_dict: + api_override["spec"] = spec_dict + + spec["podTemplateOverrides"].append(api_override) @dataclass @@ -105,7 +254,11 @@ class WithTrainerImage(KubernetesCompatible): image: str - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply trainer image override to the job specification. Args: @@ -123,7 +276,11 @@ class WithTrainerCommand(KubernetesCompatible): command: list[str] - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply trainer command override to the job specification. Args: @@ -134,7 +291,9 @@ def __call__(self, job_spec: dict, trainer=None) -> None: ValueError: If there's a conflict with the trainer configuration """ # Validate conflicts with trainer - if trainer and hasattr(trainer, "func") and trainer.func is not None: + from kubeflow.trainer.types.types import CustomTrainer + + if isinstance(trainer, CustomTrainer) and trainer.func is not None: raise ValueError( "Cannot specify WithTrainerCommand when CustomTrainer.func is provided. " "The func generates its own command. Use container-only training " @@ -152,7 +311,11 @@ class WithTrainerArgs(KubernetesCompatible): args: list[str] - def __call__(self, job_spec: dict, trainer=None) -> None: + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: """Apply trainer args override to the job specification. Args: @@ -163,7 +326,9 @@ def __call__(self, job_spec: dict, trainer=None) -> None: ValueError: If there's a conflict with the trainer configuration """ # Validate conflicts with trainer - if trainer and hasattr(trainer, "func") and trainer.func is not None: + from kubeflow.trainer.types.types import CustomTrainer + + if isinstance(trainer, CustomTrainer) and trainer.func is not None: raise ValueError( "Cannot specify WithTrainerArgs when CustomTrainer.func is provided. " "The func generates its own arguments. Use container-only training " diff --git a/kubeflow/trainer/backends/kubernetes/options_test.py b/kubeflow/trainer/backends/kubernetes/options_test.py index 8eb2fc826..6936c4cca 100644 --- a/kubeflow/trainer/backends/kubernetes/options_test.py +++ b/kubeflow/trainer/backends/kubernetes/options_test.py @@ -16,11 +16,11 @@ from kubeflow.trainer.backends.base import KubernetesCompatible from kubeflow.trainer.backends.kubernetes.options import ( - PodSpecOverride, + PodTemplateOverride, WithAnnotations, WithLabels, WithName, - WithPodSpecOverrides, + WithPodTemplateOverrides, WithTrainerArgs, WithTrainerCommand, WithTrainerImage, @@ -45,10 +45,10 @@ def test_with_name_option_type(self): option = WithName("test-job") assert isinstance(option, KubernetesCompatible) - def test_with_pod_spec_overrides_option_type(self): - """Test WithPodSpecOverrides inherits from KubernetesCompatible.""" - overrides = [PodSpecOverride(target_jobs=["node"])] - option = WithPodSpecOverrides(overrides) + def test_with_pod_template_overrides_option_type(self): + """Test WithPodTemplateOverrides inherits from KubernetesCompatible.""" + overrides = [PodTemplateOverride(target_jobs=["node"])] + option = WithPodTemplateOverrides(overrides) assert isinstance(option, KubernetesCompatible) def test_trainer_options_types(self): @@ -198,14 +198,14 @@ def test_kubernetes_capabilities_support(self): labels_option = WithLabels({"app": "test"}) annotations_option = WithAnnotations({"desc": "test"}) name_option = WithName("test-job") - pod_spec_option = WithPodSpecOverrides([PodSpecOverride(target_jobs=["node"])]) + pod_template_option = WithPodTemplateOverrides([PodTemplateOverride(target_jobs=["node"])]) image_option = WithTrainerImage("custom:latest") command_option = WithTrainerCommand(["python", "train.py"]) args_option = WithTrainerArgs(["--epochs", "10"]) assert isinstance(labels_option, KubernetesCompatible) assert isinstance(annotations_option, KubernetesCompatible) assert isinstance(name_option, KubernetesCompatible) - assert isinstance(pod_spec_option, KubernetesCompatible) + assert isinstance(pod_template_option, KubernetesCompatible) assert isinstance(image_option, KubernetesCompatible) assert isinstance(command_option, KubernetesCompatible) assert isinstance(args_option, KubernetesCompatible) @@ -216,7 +216,7 @@ def test_kubernetes_option_compatibility(self): labels_opt = WithLabels({"app": "test"}) annotations_opt = WithAnnotations({"desc": "test"}) name_opt = WithName("test-job") - pod_opt = WithPodSpecOverrides([PodSpecOverride(target_jobs=["node"])]) + pod_opt = WithPodTemplateOverrides([PodTemplateOverride(target_jobs=["node"])]) image_opt = WithTrainerImage("custom:latest") command_opt = WithTrainerCommand(["python", "train.py"]) args_opt = WithTrainerArgs(["--epochs", "10"]) @@ -230,33 +230,110 @@ def test_kubernetes_option_compatibility(self): assert isinstance(args_opt, KubernetesCompatible) -class TestPodSpecOverride: - """Test PodSpecOverride dataclass.""" +class TestContainerOverride: + """Test ContainerOverride dataclass.""" - def test_pod_spec_override_creation(self): - """Test PodSpecOverride creation with various fields.""" - override = PodSpecOverride( - target_jobs=["node", "worker"], - volumes=[{"name": "data", "emptyDir": {}}], - containers=[{"name": "node", "volumeMounts": [{"name": "data", "mountPath": "/data"}]}], - node_selector={"gpu": "true"}, + def test_container_override_full(self): + """Test ContainerOverride with all fields.""" + from kubeflow.trainer.backends.kubernetes.options import ContainerOverride + + override = ContainerOverride( + name="node", + env=[{"name": "MY_VAR", "value": "my_value"}], + volume_mounts=[{"name": "data", "mountPath": "/data"}], + ) + + assert override.name == "node" + assert override.env == [{"name": "MY_VAR", "value": "my_value"}] + assert override.volume_mounts == [{"name": "data", "mountPath": "/data"}] + + def test_container_override_minimal(self): + """Test ContainerOverride with minimal fields.""" + from kubeflow.trainer.backends.kubernetes.options import ContainerOverride + + override = ContainerOverride(name="node") + + assert override.name == "node" + assert override.env is None + assert override.volume_mounts is None + + +class TestPodTemplateSpecOverride: + """Test PodTemplateSpecOverride dataclass.""" + + def test_pod_template_spec_override_full(self): + """Test PodTemplateSpecOverride with all fields.""" + from kubeflow.trainer.backends.kubernetes.options import ( + ContainerOverride, + PodTemplateSpecOverride, + ) + + override = PodTemplateSpecOverride( service_account_name="training-sa", - tolerations=[{"key": "gpu", "operator": "Exists", "effect": "NoSchedule"}], + node_selector={"gpu": "true"}, + affinity={"nodeAffinity": {}}, + tolerations=[{"key": "gpu", "operator": "Exists"}], + volumes=[{"name": "data", "emptyDir": {}}], + init_containers=[ContainerOverride(name="init")], + containers=[ContainerOverride(name="node")], + scheduling_gates=[{"name": "wait"}], + image_pull_secrets=[{"name": "my-secret"}], ) - assert override.target_jobs == ["node", "worker"] - assert override.volumes == [{"name": "data", "emptyDir": {}}] - assert override.node_selector == {"gpu": "true"} assert override.service_account_name == "training-sa" + assert override.node_selector == {"gpu": "true"} + assert override.affinity == {"nodeAffinity": {}} + assert len(override.containers) == 1 + assert override.containers[0].name == "node" - def test_pod_spec_override_minimal(self): - """Test PodSpecOverride with minimal required fields.""" - override = PodSpecOverride(target_jobs=["node"]) + def test_pod_template_spec_override_minimal(self): + """Test PodTemplateSpecOverride with no fields.""" + from kubeflow.trainer.backends.kubernetes.options import PodTemplateSpecOverride + + override = PodTemplateSpecOverride() - assert override.target_jobs == ["node"] - assert override.volumes is None - assert override.containers is None - assert override.init_containers is None - assert override.node_selector is None assert override.service_account_name is None + assert override.node_selector is None + assert override.affinity is None assert override.tolerations is None + assert override.volumes is None + assert override.init_containers is None + assert override.containers is None + assert override.scheduling_gates is None + assert override.image_pull_secrets is None + + +class TestPodTemplateOverride: + """Test PodTemplateOverride dataclass.""" + + def test_pod_template_override_full(self): + """Test PodTemplateOverride with all fields.""" + from kubeflow.trainer.backends.kubernetes.options import ( + ContainerOverride, + PodTemplateOverride, + PodTemplateSpecOverride, + ) + + override = PodTemplateOverride( + target_jobs=["node", "launcher"], + metadata={"labels": {"custom": "label"}}, + spec=PodTemplateSpecOverride( + node_selector={"gpu": "true"}, + containers=[ContainerOverride(name="node")], + ), + ) + + assert override.target_jobs == ["node", "launcher"] + assert override.metadata == {"labels": {"custom": "label"}} + assert override.spec is not None + assert override.spec.node_selector == {"gpu": "true"} + + def test_pod_template_override_minimal(self): + """Test PodTemplateOverride with minimal fields.""" + from kubeflow.trainer.backends.kubernetes.options import PodTemplateOverride + + override = PodTemplateOverride(target_jobs=["node"]) + + assert override.target_jobs == ["node"] + assert override.metadata is None + assert override.spec is None diff --git a/kubeflow/trainer/constants/constants.py b/kubeflow/trainer/constants/constants.py index 015498f9a..1f3fd5c21 100644 --- a/kubeflow/trainer/constants/constants.py +++ b/kubeflow/trainer/constants/constants.py @@ -171,3 +171,8 @@ # The Instruct Datasets class in torchtune TORCH_TUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset" + +# The length of the UUID suffix for auto-generated TrainJob names. +# Total name length = 1 (random letter) + 11 (UUID hex) = 12 characters +# This keeps the name well under Kubernetes' 63 character limit for resource names. +TRAINJOB_NAME_UUID_LENGTH = 11 From fd89e68eabbdec08b0e6730bd0ddfa9ef98e27f5 Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Thu, 23 Oct 2025 17:47:03 +0530 Subject: [PATCH 6/7] refactor: consolidate training options into centralized package Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/__init__.py | 48 +-- kubeflow/trainer/api/trainer_client_test.py | 26 +- kubeflow/trainer/backends/__init__.py | 6 +- kubeflow/trainer/backends/base.py | 16 +- .../backends/kubernetes/backend_test.py | 345 +++++++++--------- .../backends/kubernetes/options_test.py | 339 ----------------- .../backends/localprocess/backend_test.py | 48 ++- .../trainer/backends/localprocess/options.py | 41 --- kubeflow/trainer/options/__init__.py | 53 +++ kubeflow/trainer/options/common.py | 134 +++++++ .../options.py => options/kubernetes.py} | 144 +------- kubeflow/trainer/options/kubernetes_test.py | 248 +++++++++++++ kubeflow/trainer/options/localprocess.py | 63 ++++ .../localprocess_test.py} | 64 ++-- 14 files changed, 795 insertions(+), 780 deletions(-) delete mode 100644 kubeflow/trainer/backends/kubernetes/options_test.py delete mode 100644 kubeflow/trainer/backends/localprocess/options.py create mode 100644 kubeflow/trainer/options/__init__.py create mode 100644 kubeflow/trainer/options/common.py rename kubeflow/trainer/{backends/kubernetes/options.py => options/kubernetes.py} (55%) create mode 100644 kubeflow/trainer/options/kubernetes_test.py create mode 100644 kubeflow/trainer/options/localprocess.py rename kubeflow/trainer/{backends/localprocess/options_test.py => options/localprocess_test.py} (58%) diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py index b65538013..cd001db8f 100644 --- a/kubeflow/trainer/__init__.py +++ b/kubeflow/trainer/__init__.py @@ -16,20 +16,6 @@ # Import the Kubeflow Trainer client. from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401 -# Import common training options (defaults to Kubernetes backend) -from kubeflow.trainer.backends.kubernetes.options import ( - ContainerOverride, - PodTemplateOverride, - PodTemplateSpecOverride, - WithAnnotations, - WithLabels, - WithName, - WithPodTemplateOverrides, - WithTrainerArgs, - WithTrainerCommand, - WithTrainerImage, -) - # import backends and its associated configs from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig @@ -37,6 +23,20 @@ # Import the Kubeflow Trainer constants. from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401 +# Import training options +from kubeflow.trainer.options import ( + Annotations, + ContainerOverride, + Labels, + Name, + PodTemplateOverride, + PodTemplateOverrides, + PodTemplateSpecOverride, + TrainerArgs, + TrainerCommand, + TrainerImage, +) + # Import the Kubeflow Trainer types. from kubeflow.trainer.types.types import ( BuiltinTrainer, @@ -57,6 +57,7 @@ ) __all__ = [ + "Annotations", "BuiltinTrainer", "ContainerOverride", "CustomTrainer", @@ -67,24 +68,23 @@ "HuggingFaceDatasetInitializer", "HuggingFaceModelInitializer", "Initializer", + "KubernetesBackendConfig", + "Labels", + "LocalProcessBackendConfig", "LoraConfig", "Loss", "MODEL_PATH", + "Name", "PodTemplateOverride", + "PodTemplateOverrides", "PodTemplateSpecOverride", "Runtime", + "RuntimeTrainer", "TorchTuneConfig", "TorchTuneInstructDataset", - "RuntimeTrainer", + "TrainerArgs", "TrainerClient", + "TrainerCommand", + "TrainerImage", "TrainerType", - "LocalProcessBackendConfig", - "KubernetesBackendConfig", - "WithAnnotations", - "WithLabels", - "WithName", - "WithPodTemplateOverrides", - "WithTrainerArgs", - "WithTrainerCommand", - "WithTrainerImage", ] diff --git a/kubeflow/trainer/api/trainer_client_test.py b/kubeflow/trainer/api/trainer_client_test.py index 423be3ca0..04040bd07 100644 --- a/kubeflow/trainer/api/trainer_client_test.py +++ b/kubeflow/trainer/api/trainer_client_test.py @@ -21,9 +21,9 @@ import pytest from kubeflow.trainer.api.trainer_client import TrainerClient -from kubeflow.trainer.backends.kubernetes.options import WithAnnotations, WithLabels from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig +from kubeflow.trainer.options import Annotations, Labels from kubeflow.trainer.types import types @@ -39,14 +39,14 @@ def simple_func(): return "test" trainer = types.CustomTrainer(func=simple_func) - options = [WithLabels({"app": "test"})] + options = [Labels({"app": "test"})] with pytest.raises(ValueError) as exc_info: client.train(trainer=trainer, options=options) error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels" in error_msg + assert "Labels" in error_msg @patch("kubernetes.config.load_kube_config") @patch("kubernetes.client.CustomObjectsApi") @@ -65,7 +65,7 @@ def simple_func(): return "test" trainer = types.CustomTrainer(func=simple_func) - options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})] + options = [Labels({"app": "test"}), Annotations({"desc": "test"})] with pytest.raises((ValueError, RuntimeError)) as exc_info: client.train(trainer=trainer, options=options) @@ -123,15 +123,15 @@ def simple_func(): return "test" trainer = types.CustomTrainer(func=simple_func) - options = [WithLabels({"app": "test"}), WithAnnotations({"desc": "test"})] + options = [Labels({"app": "test"}), Annotations({"desc": "test"})] with pytest.raises(ValueError) as exc_info: client.train(trainer=trainer, options=options) error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels" in error_msg - assert "WithAnnotations" in error_msg + assert "Labels" in error_msg + assert "Annotations" in error_msg assert "The following options are not compatible with this backend" in error_msg def test_error_message_does_not_contain_runtime_help_for_option_errors(self): @@ -143,7 +143,7 @@ def simple_func(): return "test" trainer = types.CustomTrainer(func=simple_func) - options = [WithLabels({"app": "test"})] + options = [Labels({"app": "test"})] with pytest.raises(ValueError) as exc_info: client.train(trainer=trainer, options=options) @@ -235,7 +235,7 @@ def simple_func(): return "test" trainer = types.CustomTrainer(func=simple_func) - options = [WithLabels({"app": "test"})] + options = [Labels({"app": "test"})] with pytest.raises(ValueError) as exc_info: client.train(trainer=trainer, options=options) @@ -253,8 +253,8 @@ def simple_func(): trainer = types.CustomTrainer(func=simple_func) options = [ - WithLabels({"app": "test"}), - WithAnnotations({"desc": "test"}), + Labels({"app": "test"}), + Annotations({"desc": "test"}), ] with pytest.raises(ValueError) as exc_info: @@ -262,8 +262,8 @@ def simple_func(): error_msg = str(exc_info.value) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels" in error_msg - assert "WithAnnotations" in error_msg + assert "Labels" in error_msg + assert "Annotations" in error_msg def test_none_options_handling(self): """Test that None options are handled correctly.""" diff --git a/kubeflow/trainer/backends/__init__.py b/kubeflow/trainer/backends/__init__.py index 241cd02bb..7c1cd8811 100644 --- a/kubeflow/trainer/backends/__init__.py +++ b/kubeflow/trainer/backends/__init__.py @@ -1,15 +1,15 @@ from kubeflow.trainer.backends.base import ( + CompatibleOption, ExecutionBackend, KubernetesCompatible, LocalProcessCompatible, - UniversalCompatible, Option, - CompatibleOption, + UniversalCompatible, ) __all__ = [ "ExecutionBackend", - "KubernetesCompatible", + "KubernetesCompatible", "LocalProcessCompatible", "UniversalCompatible", "Option", diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index ea91e1c70..2f86c5286 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -14,7 +14,7 @@ import abc from collections.abc import Iterator -from typing import Optional, Union, Protocol +from typing import Optional, Protocol, Union from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -41,14 +41,14 @@ class UniversalCompatible(KubernetesCompatible, LocalProcessCompatible): class Option(Protocol): """Protocol defining the contract for training options.""" - + def __call__( - self, - job_spec: dict, - trainer: Optional[Union["types.BuiltinTrainer", "types.CustomTrainer"]] = None + self, + job_spec: dict, + trainer: Optional[Union["types.BuiltinTrainer", "types.CustomTrainer"]] = None, ) -> None: """Apply the option to the job specification. - + Args: job_spec: The job specification dictionary to modify trainer: Optional trainer context for validation @@ -68,9 +68,7 @@ def _validate_options_with_mixin( return incompatible = [ - f"{opt.__class__.__name__}" - for opt in options - if not isinstance(opt, mixin_class) + f"{opt.__class__.__name__}" for opt in options if not isinstance(opt, mixin_class) ] if incompatible: diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index 1d44261ee..a9c9e71ec 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -32,19 +32,19 @@ import pytest from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend -from kubeflow.trainer.backends.kubernetes.options import ( +from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig +from kubeflow.trainer.constants import constants +from kubeflow.trainer.options import ( + Annotations, ContainerOverride, + Labels, PodTemplateOverride, + PodTemplateOverrides, PodTemplateSpecOverride, - WithAnnotations, - WithLabels, - WithPodTemplateOverrides, - WithTrainerArgs, - WithTrainerCommand, - WithTrainerImage, + TrainerArgs, + TrainerCommand, + TrainerImage, ) -from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig -from kubeflow.trainer.constants import constants from kubeflow.trainer.test.common import ( DEFAULT_NAMESPACE, FAILED, @@ -663,7 +663,9 @@ def get_train_job_data_type( expected_output=get_train_job( runtime_name=TORCH_TUNE_RUNTIME, train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER, - train_job_trainer=get_builtin_trainer(), + train_job_trainer=get_builtin_trainer( + args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"], + ), ), ), TestCase( @@ -826,8 +828,8 @@ def get_custom_trainer_with_overrides( expected_status=SUCCESS, config={ "options": [ - WithLabels({"app": "test"}), - WithAnnotations({"desc": "test"}), + Labels({"app": "test"}), + Annotations({"desc": "test"}), ] }, expected_output=None, # No exception expected @@ -860,20 +862,20 @@ def test_backend_validation(kubernetes_backend, test_case): def test_backend_capabilities(kubernetes_backend): """Test KubernetesBackend validation with correct mixin.""" - from kubeflow.trainer.backends.kubernetes.options import WithLabels - from kubeflow.trainer.backends.localprocess.options import WithWorkingDirectory + from kubeflow.trainer.options import Labels + from kubeflow.trainer.options.localprocess import WorkingDirectory # Test that KubernetesCompatible options pass validation - kubernetes_options = [WithLabels({"app": "test"})] + kubernetes_options = [Labels({"app": "test"})] kubernetes_backend.validate_options(kubernetes_options) # Should not raise # Test that LocalProcessCompatible options fail validation - local_options = [WithWorkingDirectory("/tmp")] + local_options = [WorkingDirectory("/tmp")] try: kubernetes_backend.validate_options(local_options) raise AssertionError("Expected ValueError for incompatible options") except ValueError as e: - assert "WithWorkingDirectory" in str(e) + assert "WorkingDirectory" in str(e) # Test cases for new trainer container options @@ -884,7 +886,7 @@ def test_backend_capabilities(kubernetes_backend): config={ "trainer": types.CustomTrainer(num_nodes=2), "options": [ - WithTrainerImage("custom/pytorch:latest"), + TrainerImage("custom/pytorch:latest"), ], }, expected_output=get_train_job( @@ -902,7 +904,7 @@ def test_backend_capabilities(kubernetes_backend): config={ "trainer": types.CustomTrainer(num_nodes=2), "options": [ - WithTrainerCommand(["python", "train.py"]), + TrainerCommand(["python", "train.py"]), ], }, expected_output=get_train_job( @@ -917,7 +919,7 @@ def test_backend_capabilities(kubernetes_backend): config={ "trainer": types.CustomTrainer(num_nodes=2), "options": [ - WithTrainerArgs(["--epochs", "10", "--lr", "0.001"]), + TrainerArgs(["--epochs", "10", "--lr", "0.001"]), ], }, expected_output=get_train_job( @@ -935,9 +937,9 @@ def test_backend_capabilities(kubernetes_backend): config={ "trainer": types.CustomTrainer(num_nodes=2), "options": [ - WithTrainerImage("custom/pytorch:2.0"), - WithTrainerCommand(["python", "-m", "torch.distributed.run"]), - WithTrainerArgs(["train.py", "--epochs", "5"]), + TrainerImage("custom/pytorch:2.0"), + TrainerCommand(["python", "-m", "torch.distributed.run"]), + TrainerArgs(["train.py", "--epochs", "5"]), ], }, expected_output=get_train_job( @@ -960,9 +962,9 @@ def test_backend_capabilities(kubernetes_backend): resources_per_node={"cpu": "2", "memory": "4Gi"}, ), "options": [ - WithTrainerImage("python:3.11"), - WithTrainerCommand(["python", "-c"]), - WithTrainerArgs(["print('Container-only training!')"]), + TrainerImage("python:3.11"), + TrainerCommand(["python", "-c"]), + TrainerArgs(["print('Container-only training!')"]), ], }, expected_output=get_train_job( @@ -1096,7 +1098,6 @@ def test_get_runtime_packages(kubernetes_backend, test_case): assert type(e) is test_case.expected_error - @pytest.mark.parametrize( "test_case", [ @@ -1149,7 +1150,9 @@ def test_get_runtime_packages(kubernetes_backend, test_case): ), ), ), - "runtime": TORCH_TUNE_RUNTIME, + "runtime": create_runtime_type( + name=TORCH_TUNE_RUNTIME, trainer_type=types.TrainerType.BUILTIN_TRAINER + ), }, expected_output=get_train_job( runtime_name=TORCH_TUNE_RUNTIME, @@ -1247,10 +1250,10 @@ def test_get_runtime_packages(kubernetes_backend, test_case): ), # Test cases using the new Options pattern TestCase( - name="valid flow with WithLabels option", + name="valid flow with Labels option", expected_status=SUCCESS, config={ - "options": [WithLabels({"team": "ml-platform", "project": "training"})], + "options": [Labels({"team": "ml-platform", "project": "training"})], }, expected_output=get_train_job( runtime_name=TORCH_RUNTIME, @@ -1264,8 +1267,8 @@ def test_get_runtime_packages(kubernetes_backend, test_case): expected_status=SUCCESS, config={ "options": [ - WithLabels({"team": "ml-platform"}), - WithAnnotations({"created-by": "sdk"}), + Labels({"team": "ml-platform"}), + Annotations({"created-by": "sdk"}), ], }, expected_output=get_train_job( @@ -1280,11 +1283,32 @@ def test_get_runtime_packages(kubernetes_backend, test_case): ) def test_train_validation(kubernetes_backend, test_case): """Test KubernetesBackend.train validation with various scenarios.""" - try: - kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + kubernetes_backend.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - if test_case.expected_status == SUCCESS: - train_job_name = kubernetes_backend.train( + if test_case.expected_status == SUCCESS: + train_job_name = kubernetes_backend.train( + trainer=test_case.config.get( + "trainer", types.CustomTrainer(func=lambda: print("Hello World")) + ), + runtime=test_case.config.get("runtime", create_runtime_type(name=TORCH_RUNTIME)), + options=test_case.config.get("options", []), + ) + + # Set the expected output's name to the actual job name + expected_output = test_case.expected_output + expected_output.metadata.name = train_job_name + + # Verify the mock was called with the expected output + kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with( + constants.GROUP, + constants.VERSION, + DEFAULT_NAMESPACE, + constants.TRAINJOB_PLURAL, + expected_output.to_dict(), + ) + else: + with pytest.raises(test_case.expected_error): + kubernetes_backend.train( trainer=test_case.config.get( "trainer", types.CustomTrainer(func=lambda: print("Hello World")) ), @@ -1292,32 +1316,6 @@ def test_train_validation(kubernetes_backend, test_case): options=test_case.config.get("options", []), ) - # Set the expected output's name to the actual job name - expected_output = test_case.expected_output - expected_output.metadata.name = train_job_name - - # Verify the mock was called with the expected output - kubernetes_backend.custom_api.create_namespaced_custom_object.assert_called_with( - constants.GROUP, - constants.VERSION, - DEFAULT_NAMESPACE, - constants.TRAINJOB_PLURAL, - expected_output.to_dict(), - ) - else: - with pytest.raises(test_case.expected_error): - kubernetes_backend.train( - trainer=test_case.config.get( - "trainer", types.CustomTrainer(func=lambda: print("Hello World")) - ), - runtime=test_case.config.get( - "runtime", create_runtime_type(name=TORCH_RUNTIME) - ), - options=test_case.config.get("options", []), - ) - except Exception as e: - raise - TRAIN_TEST_CASES = [ TestCase( @@ -1325,8 +1323,8 @@ def test_train_validation(kubernetes_backend, test_case): expected_status=SUCCESS, config={ "options": [ - WithLabels({"team": "ml-platform"}), - WithAnnotations({"created-by": "sdk"}), + Labels({"team": "ml-platform"}), + Annotations({"created-by": "sdk"}), ], }, expected_output=get_train_job( @@ -1338,118 +1336,135 @@ def test_train_validation(kubernetes_backend, test_case): ), ] -TRAIN_TEST_CASES.extend([ - TestCase( - name="train with pod template overrides - basic", - expected_status=SUCCESS, - config={ - "options": [ - WithPodTemplateOverrides([ - PodTemplateOverride( - target_jobs=["node"], - spec=PodTemplateSpecOverride( - volumes=[ +TRAIN_TEST_CASES.extend( + [ + TestCase( + name="train with pod template overrides - basic", + expected_status=SUCCESS, + config={ + "options": [ + PodTemplateOverrides( + pod_template_overrides=[ + PodTemplateOverride( + target_jobs=["node"], + spec=PodTemplateSpecOverride( + volumes=[ + { + "name": "data", + "persistentVolumeClaim": {"claimName": "my-pvc"}, + } + ], + containers=[ + ContainerOverride( + name="node", + volume_mounts=[{"name": "data", "mountPath": "/data"}], + ) + ], + node_selector={"gpu": "true"}, + tolerations=[ + {"key": "gpu", "operator": "Exists", "effect": "NoSchedule"} + ], + ), + ) + ] + ) + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + pod_template_overrides=[ + { + "targetJobs": [{"name": "node"}], + "spec": { + "volumes": [ {"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}} ], - containers=[ - ContainerOverride( - name="node", - volume_mounts=[{"name": "data", "mountPath": "/data"}], - ) + "containers": [ + { + "name": "node", + "volumeMounts": [{"name": "data", "mountPath": "/data"}], + } ], - node_selector={"gpu": "true"}, - tolerations=[ + "nodeSelector": {"gpu": "true"}, + "tolerations": [ {"key": "gpu", "operator": "Exists", "effect": "NoSchedule"} ], - ) - ) - ]) - ], - }, - expected_output=get_train_job( - runtime_name=TORCH_RUNTIME, - train_job_name=BASIC_TRAIN_JOB_NAME, - pod_template_overrides=[ - { - "targetJobs": [{"name": "node"}], - "spec": { - "volumes": [ - {"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}} - ], - "containers": [ - { - "name": "node", - "volumeMounts": [{"name": "data", "mountPath": "/data"}], - } - ], - "nodeSelector": {"gpu": "true"}, - "tolerations": [ - {"key": "gpu", "operator": "Exists", "effect": "NoSchedule"} - ], + }, } - } - ], + ], + ), ), - ), - TestCase( - name="train with pod template overrides - with metadata and affinity", - expected_status=SUCCESS, - config={ - "options": [ - WithPodTemplateOverrides([ - PodTemplateOverride( - target_jobs=["node"], - metadata={"labels": {"custom-label": "custom-value"}}, - spec=PodTemplateSpecOverride( - service_account_name="training-sa", - affinity={ + TestCase( + name="train with pod template overrides - with metadata and affinity", + expected_status=SUCCESS, + config={ + "options": [ + PodTemplateOverrides( + pod_template_overrides=[ + PodTemplateOverride( + target_jobs=["node"], + metadata={"labels": {"custom-label": "custom-value"}}, + spec=PodTemplateSpecOverride( + service_account_name="training-sa", + affinity={ + "nodeAffinity": { + "requiredDuringSchedulingIgnoredDuringExecution": { + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "gpu-type", + "operator": "In", + "values": ["nvidia-a100"], + } + ] + } + ] + } + } + }, + scheduling_gates=[{"name": "wait-for-resources"}], + ), + ) + ] + ) + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + pod_template_overrides=[ + { + "targetJobs": [{"name": "node"}], + "metadata": {"labels": {"custom-label": "custom-value"}}, + "spec": { + "serviceAccountName": "training-sa", + "affinity": { "nodeAffinity": { "requiredDuringSchedulingIgnoredDuringExecution": { - "nodeSelectorTerms": [{ - "matchExpressions": [{ - "key": "gpu-type", - "operator": "In", - "values": ["nvidia-a100"] - }] - }] + "nodeSelectorTerms": [ + { + "matchExpressions": [ + { + "key": "gpu-type", + "operator": "In", + "values": ["nvidia-a100"], + } + ] + } + ] } } }, - scheduling_gates=[{"name": "wait-for-resources"}], - ) - ) - ]) - ], - }, - expected_output=get_train_job( - runtime_name=TORCH_RUNTIME, - train_job_name=BASIC_TRAIN_JOB_NAME, - pod_template_overrides=[ - { - "targetJobs": [{"name": "node"}], - "metadata": {"labels": {"custom-label": "custom-value"}}, - "spec": { - "serviceAccountName": "training-sa", - "affinity": { - "nodeAffinity": { - "requiredDuringSchedulingIgnoredDuringExecution": { - "nodeSelectorTerms": [{ - "matchExpressions": [{ - "key": "gpu-type", - "operator": "In", - "values": ["nvidia-a100"] - }] - }] - } - } + "schedulingGates": [{"name": "wait-for-resources"}], }, - "schedulingGates": [{"name": "wait-for-resources"}], } - } - ], + ], + ), ), - ), -]) + ] +) @pytest.mark.parametrize("test_case", TRAIN_TEST_CASES) @@ -1461,14 +1476,14 @@ def test_train(kubernetes_backend, test_case): options = test_case.config.get("options", []) if test_case.config.get("labels"): - from kubeflow.trainer.backends.kubernetes.options import WithLabels + from kubeflow.trainer.options import Labels - options.append(WithLabels(test_case.config["labels"])) + options.append(Labels(test_case.config["labels"])) if test_case.config.get("annotations"): - from kubeflow.trainer.backends.kubernetes.options import WithAnnotations + from kubeflow.trainer.options import Annotations - options.append(WithAnnotations(test_case.config["annotations"])) + options.append(Annotations(test_case.config["annotations"])) train_job_name = kubernetes_backend.train( runtime=runtime, diff --git a/kubeflow/trainer/backends/kubernetes/options_test.py b/kubeflow/trainer/backends/kubernetes/options_test.py deleted file mode 100644 index 6936c4cca..000000000 --- a/kubeflow/trainer/backends/kubernetes/options_test.py +++ /dev/null @@ -1,339 +0,0 @@ -# Copyright 2025 The Kubeflow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for Kubernetes backend options.""" - -from kubeflow.trainer.backends.base import KubernetesCompatible -from kubeflow.trainer.backends.kubernetes.options import ( - PodTemplateOverride, - WithAnnotations, - WithLabels, - WithName, - WithPodTemplateOverrides, - WithTrainerArgs, - WithTrainerCommand, - WithTrainerImage, -) - - -class TestKubernetesOptionTypes: - """Test Kubernetes option types.""" - - def test_with_labels_option_type(self): - """Test WithLabels inherits from KubernetesCompatible.""" - option = WithLabels({"app": "test", "version": "v1"}) - assert isinstance(option, KubernetesCompatible) - - def test_with_annotations_option_type(self): - """Test WithAnnotations inherits from KubernetesCompatible.""" - option = WithAnnotations({"description": "test job"}) - assert isinstance(option, KubernetesCompatible) - - def test_with_name_option_type(self): - """Test WithName inherits from KubernetesCompatible.""" - option = WithName("test-job") - assert isinstance(option, KubernetesCompatible) - - def test_with_pod_template_overrides_option_type(self): - """Test WithPodTemplateOverrides inherits from KubernetesCompatible.""" - overrides = [PodTemplateOverride(target_jobs=["node"])] - option = WithPodTemplateOverrides(overrides) - assert isinstance(option, KubernetesCompatible) - - def test_trainer_options_types(self): - """Test trainer options inherit from KubernetesCompatible.""" - image_option = WithTrainerImage("custom:latest") - command_option = WithTrainerCommand(["python", "train.py"]) - args_option = WithTrainerArgs(["--epochs", "10"]) - assert isinstance(image_option, KubernetesCompatible) - assert isinstance(command_option, KubernetesCompatible) - assert isinstance(args_option, KubernetesCompatible) - - -class TestKubernetesOptionApplication: - """Test Kubernetes option application behavior.""" - - def test_labels_application(self): - """Test WithLabels applies correctly to job spec.""" - option = WithLabels({"app": "test", "version": "v1"}) - - job_spec = {} - option(job_spec) - - expected = {"metadata": {"labels": {"app": "test", "version": "v1"}}} - assert job_spec == expected - - def test_annotations_application(self): - """Test WithAnnotations applies correctly to job spec.""" - option = WithAnnotations({"description": "test job", "owner": "team"}) - - job_spec = {} - option(job_spec) - - expected = {"metadata": {"annotations": {"description": "test job", "owner": "team"}}} - assert job_spec == expected - - def test_name_application(self): - """Test WithName applies correctly to job spec.""" - option = WithName("my-training-job") - - job_spec = {} - option(job_spec) - - expected = {"metadata": {"name": "my-training-job"}} - assert job_spec == expected - - def test_trainer_image_application(self): - """Test WithTrainerImage applies correctly to job spec.""" - option = WithTrainerImage("custom/pytorch:latest") - - job_spec = {} - option(job_spec) - - expected = {"spec": {"trainer": {"image": "custom/pytorch:latest"}}} - assert job_spec == expected - - def test_trainer_command_application(self): - """Test WithTrainerCommand applies correctly to job spec.""" - option = WithTrainerCommand(["python", "train.py"]) - - job_spec = {} - option(job_spec) - - expected = {"spec": {"trainer": {"command": ["python", "train.py"]}}} - assert job_spec == expected - - def test_trainer_args_application(self): - """Test WithTrainerArgs applies correctly to job spec.""" - option = WithTrainerArgs(["--epochs", "10", "--lr", "0.001"]) - - job_spec = {} - option(job_spec) - - expected = {"spec": {"trainer": {"args": ["--epochs", "10", "--lr", "0.001"]}}} - assert job_spec == expected - - def test_trainer_command_validation_with_func(self): - """Test WithTrainerCommand validates conflicts with CustomTrainer.func.""" - from kubeflow.trainer.types.types import CustomTrainer - - option = WithTrainerCommand(["python", "custom_train.py"]) - trainer_with_func = CustomTrainer(func=lambda: print("training")) - - job_spec = {} - - # Should raise ValueError when trainer has func - try: - option(job_spec, trainer_with_func) - raise AssertionError("Expected ValueError for func conflict") - except ValueError as e: - assert "Cannot specify WithTrainerCommand when CustomTrainer.func is provided" in str(e) - - def test_trainer_command_validation_without_func(self): - """Test WithTrainerCommand works with container-only training.""" - from kubeflow.trainer.types.types import CustomTrainer - - option = WithTrainerCommand(["python", "custom_train.py"]) - trainer_without_func = CustomTrainer(func=None) # Container-only training - - job_spec = {} - option(job_spec, trainer_without_func) # Should not raise - - expected = {"spec": {"trainer": {"command": ["python", "custom_train.py"]}}} - assert job_spec == expected - - def test_trainer_args_validation_with_func(self): - """Test WithTrainerArgs validates conflicts with CustomTrainer.func.""" - from kubeflow.trainer.types.types import CustomTrainer - - option = WithTrainerArgs(["--epochs", "10"]) - trainer_with_func = CustomTrainer(func=lambda: print("training")) - - job_spec = {} - - # Should raise ValueError when trainer has func - try: - option(job_spec, trainer_with_func) - raise AssertionError("Expected ValueError for func conflict") - except ValueError as e: - assert "Cannot specify WithTrainerArgs when CustomTrainer.func is provided" in str(e) - - def test_multiple_options_override_behavior(self): - """Test multiple options with override semantics.""" - job_spec = {} - - # Apply first set of labels - WithLabels({"app": "trainer", "env": "dev"})(job_spec) - # Apply second set of labels (should override) - WithLabels({"app": "ml-trainer", "version": "v1.0"})(job_spec) - # Apply annotations - WithAnnotations({"description": "test"})(job_spec) - - expected = { - "metadata": { - "labels": {"app": "ml-trainer", "version": "v1.0"}, # Override behavior - "annotations": {"description": "test"}, - } - } - assert job_spec == expected - - -class TestKubernetesCapabilities: - """Test Kubernetes backend capabilities.""" - - def test_kubernetes_capabilities_support(self): - """Test Kubernetes backend mixin compatibility.""" - # Test that all Kubernetes options inherit from KubernetesCompatible - labels_option = WithLabels({"app": "test"}) - annotations_option = WithAnnotations({"desc": "test"}) - name_option = WithName("test-job") - pod_template_option = WithPodTemplateOverrides([PodTemplateOverride(target_jobs=["node"])]) - image_option = WithTrainerImage("custom:latest") - command_option = WithTrainerCommand(["python", "train.py"]) - args_option = WithTrainerArgs(["--epochs", "10"]) - assert isinstance(labels_option, KubernetesCompatible) - assert isinstance(annotations_option, KubernetesCompatible) - assert isinstance(name_option, KubernetesCompatible) - assert isinstance(pod_template_option, KubernetesCompatible) - assert isinstance(image_option, KubernetesCompatible) - assert isinstance(command_option, KubernetesCompatible) - assert isinstance(args_option, KubernetesCompatible) - - def test_kubernetes_option_compatibility(self): - """Test Kubernetes option compatibility checking with mixin approach.""" - # Test that we can identify Kubernetes-compatible options - labels_opt = WithLabels({"app": "test"}) - annotations_opt = WithAnnotations({"desc": "test"}) - name_opt = WithName("test-job") - pod_opt = WithPodTemplateOverrides([PodTemplateOverride(target_jobs=["node"])]) - image_opt = WithTrainerImage("custom:latest") - command_opt = WithTrainerCommand(["python", "train.py"]) - args_opt = WithTrainerArgs(["--epochs", "10"]) - # All options should be KubernetesCompatible - assert isinstance(labels_opt, KubernetesCompatible) - assert isinstance(annotations_opt, KubernetesCompatible) - assert isinstance(name_opt, KubernetesCompatible) - assert isinstance(pod_opt, KubernetesCompatible) - assert isinstance(image_opt, KubernetesCompatible) - assert isinstance(command_opt, KubernetesCompatible) - assert isinstance(args_opt, KubernetesCompatible) - - -class TestContainerOverride: - """Test ContainerOverride dataclass.""" - - def test_container_override_full(self): - """Test ContainerOverride with all fields.""" - from kubeflow.trainer.backends.kubernetes.options import ContainerOverride - - override = ContainerOverride( - name="node", - env=[{"name": "MY_VAR", "value": "my_value"}], - volume_mounts=[{"name": "data", "mountPath": "/data"}], - ) - - assert override.name == "node" - assert override.env == [{"name": "MY_VAR", "value": "my_value"}] - assert override.volume_mounts == [{"name": "data", "mountPath": "/data"}] - - def test_container_override_minimal(self): - """Test ContainerOverride with minimal fields.""" - from kubeflow.trainer.backends.kubernetes.options import ContainerOverride - - override = ContainerOverride(name="node") - - assert override.name == "node" - assert override.env is None - assert override.volume_mounts is None - - -class TestPodTemplateSpecOverride: - """Test PodTemplateSpecOverride dataclass.""" - - def test_pod_template_spec_override_full(self): - """Test PodTemplateSpecOverride with all fields.""" - from kubeflow.trainer.backends.kubernetes.options import ( - ContainerOverride, - PodTemplateSpecOverride, - ) - - override = PodTemplateSpecOverride( - service_account_name="training-sa", - node_selector={"gpu": "true"}, - affinity={"nodeAffinity": {}}, - tolerations=[{"key": "gpu", "operator": "Exists"}], - volumes=[{"name": "data", "emptyDir": {}}], - init_containers=[ContainerOverride(name="init")], - containers=[ContainerOverride(name="node")], - scheduling_gates=[{"name": "wait"}], - image_pull_secrets=[{"name": "my-secret"}], - ) - - assert override.service_account_name == "training-sa" - assert override.node_selector == {"gpu": "true"} - assert override.affinity == {"nodeAffinity": {}} - assert len(override.containers) == 1 - assert override.containers[0].name == "node" - - def test_pod_template_spec_override_minimal(self): - """Test PodTemplateSpecOverride with no fields.""" - from kubeflow.trainer.backends.kubernetes.options import PodTemplateSpecOverride - - override = PodTemplateSpecOverride() - - assert override.service_account_name is None - assert override.node_selector is None - assert override.affinity is None - assert override.tolerations is None - assert override.volumes is None - assert override.init_containers is None - assert override.containers is None - assert override.scheduling_gates is None - assert override.image_pull_secrets is None - - -class TestPodTemplateOverride: - """Test PodTemplateOverride dataclass.""" - - def test_pod_template_override_full(self): - """Test PodTemplateOverride with all fields.""" - from kubeflow.trainer.backends.kubernetes.options import ( - ContainerOverride, - PodTemplateOverride, - PodTemplateSpecOverride, - ) - - override = PodTemplateOverride( - target_jobs=["node", "launcher"], - metadata={"labels": {"custom": "label"}}, - spec=PodTemplateSpecOverride( - node_selector={"gpu": "true"}, - containers=[ContainerOverride(name="node")], - ), - ) - - assert override.target_jobs == ["node", "launcher"] - assert override.metadata == {"labels": {"custom": "label"}} - assert override.spec is not None - assert override.spec.node_selector == {"gpu": "true"} - - def test_pod_template_override_minimal(self): - """Test PodTemplateOverride with minimal fields.""" - from kubeflow.trainer.backends.kubernetes.options import PodTemplateOverride - - override = PodTemplateOverride(target_jobs=["node"]) - - assert override.target_jobs == ["node"] - assert override.metadata is None - assert override.spec is None diff --git a/kubeflow/trainer/backends/localprocess/backend_test.py b/kubeflow/trainer/backends/localprocess/backend_test.py index 1e4ce3e72..3cc4d21bd 100644 --- a/kubeflow/trainer/backends/localprocess/backend_test.py +++ b/kubeflow/trainer/backends/localprocess/backend_test.py @@ -18,13 +18,9 @@ import pytest -from kubeflow.trainer.backends.kubernetes.options import ( - WithAnnotations, - WithLabels, - WithPodSpecOverrides, -) from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig +from kubeflow.trainer.options import Annotations, Labels, PodTemplateOverrides from kubeflow.trainer.test.common import FAILED, SUCCESS, TestCase # Test cases for LocalProcess backend validation @@ -44,7 +40,7 @@ TestCase( name="reject_incompatible_labels", expected_status=FAILED, - config={"options": [WithLabels({"app": "test"})]}, + config={"options": [Labels({"app": "test"})]}, expected_error=ValueError, ), TestCase( @@ -52,9 +48,9 @@ expected_status=FAILED, config={ "options": [ - WithLabels({"app": "test"}), - WithAnnotations({"desc": "test"}), - WithPodSpecOverrides([{}]), + Labels({"app": "test"}), + Annotations({"desc": "test"}), + PodTemplateOverrides(pod_template_overrides=[]), ] }, expected_error=ValueError, @@ -82,32 +78,32 @@ def test_local_backend_validation(local_backend, test_case): if test_case.name == "reject_incompatible_labels": error_msg = str(e) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels" in error_msg + assert "Labels" in error_msg elif test_case.name == "reject_multiple_incompatible_options": error_msg = str(e) assert "The following options are not compatible with this backend" in error_msg - assert "WithLabels" in error_msg - assert "WithAnnotations" in error_msg - assert "WithPodSpecOverrides" in error_msg + assert "Labels" in error_msg + assert "Annotations" in error_msg + assert "PodTemplateOverrides" in error_msg print("test execution complete") def test_local_backend_capabilities(local_backend): """Test LocalProcessBackend validation with correct mixin.""" - from kubeflow.trainer.backends.localprocess.options import WithWorkingDirectory - from kubeflow.trainer.backends.kubernetes.options import WithLabels - + from kubeflow.trainer.options import Labels + from kubeflow.trainer.options.localprocess import WorkingDirectory + # Test that LocalProcessCompatible options pass validation - local_options = [WithWorkingDirectory("/tmp")] + local_options = [WorkingDirectory("/tmp")] local_backend.validate_options(local_options) # Should not raise - + # Test that KubernetesCompatible options fail validation - kubernetes_options = [WithLabels({"app": "test"})] + kubernetes_options = [Labels({"app": "test"})] try: local_backend.validate_options(kubernetes_options) - assert False, "Expected ValueError for incompatible options" + raise AssertionError("Expected ValueError for incompatible options") except ValueError as e: - assert "WithLabels" in str(e) + assert "Labels" in str(e) class TestLocalBackendValidationFlow: @@ -115,7 +111,7 @@ class TestLocalBackendValidationFlow: def test_validation_happens_before_processing(self, local_backend): """Test that validation happens before any processing.""" - incompatible_options = [WithLabels({"app": "test"})] + incompatible_options = [Labels({"app": "test"})] with pytest.raises(ValueError) as exc_info: local_backend.validate_options(incompatible_options) @@ -125,13 +121,13 @@ def test_validation_happens_before_processing(self, local_backend): def test_validation_early_exit_behavior(self, local_backend): """Test that validation reports all incompatible options.""" mixed_options = [ - WithLabels({"app": "test"}), - WithAnnotations({"desc": "test"}), + Labels({"app": "test"}), + Annotations({"desc": "test"}), ] with pytest.raises(ValueError) as exc_info: local_backend.validate_options(mixed_options) error_msg = str(exc_info.value) - assert "WithLabels" in error_msg - assert "WithAnnotations" in error_msg + assert "Labels" in error_msg + assert "Annotations" in error_msg diff --git a/kubeflow/trainer/backends/localprocess/options.py b/kubeflow/trainer/backends/localprocess/options.py deleted file mode 100644 index 9d022053c..000000000 --- a/kubeflow/trainer/backends/localprocess/options.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2025 The Kubeflow Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""LocalProcess-specific training options for the Kubeflow Trainer SDK.""" - -from dataclasses import dataclass - -from kubeflow.trainer.backends.base import LocalProcessCompatible - - -@dataclass -class WithProcessTimeout(LocalProcessCompatible): - """Set a timeout for the local training process.""" - - timeout_seconds: int - - def __call__(self, config: dict) -> None: - """Apply timeout to local process configuration.""" - config["timeout_seconds"] = self.timeout_seconds - - -@dataclass -class WithWorkingDirectory(LocalProcessCompatible): - """Set the working directory for the local training process.""" - - working_dir: str - - def __call__(self, config: dict) -> None: - """Apply working directory to local process configuration.""" - config["working_dir"] = self.working_dir diff --git a/kubeflow/trainer/options/__init__.py b/kubeflow/trainer/options/__init__.py new file mode 100644 index 000000000..2197b28de --- /dev/null +++ b/kubeflow/trainer/options/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Training options for the Kubeflow Trainer SDK. + +This package provides backend-specific and common training options +that can be used to customize job execution. +""" + +from kubeflow.trainer.options.common import ( + ContainerOverride, + PodTemplateOverride, + PodTemplateSpecOverride, +) +from kubeflow.trainer.options.kubernetes import ( + Annotations, + Labels, + Name, + PodTemplateOverrides, + TrainerArgs, + TrainerCommand, + TrainerImage, +) +from kubeflow.trainer.options.localprocess import ProcessTimeout, WorkingDirectory + +__all__ = [ + # Common options + "ContainerOverride", + "PodTemplateOverride", + "PodTemplateSpecOverride", + # Kubernetes options + "Annotations", + "Labels", + "Name", + "PodTemplateOverrides", + "TrainerArgs", + "TrainerCommand", + "TrainerImage", + # LocalProcess options + "ProcessTimeout", + "WorkingDirectory", +] diff --git a/kubeflow/trainer/options/common.py b/kubeflow/trainer/options/common.py new file mode 100644 index 000000000..b2eb077b7 --- /dev/null +++ b/kubeflow/trainer/options/common.py @@ -0,0 +1,134 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common options and helper classes used across multiple backends.""" + +from dataclasses import dataclass +from typing import Optional + +__all__ = [ + "ContainerOverride", + "PodTemplateOverride", + "PodTemplateSpecOverride", +] + + +@dataclass +class ContainerOverride: + """Configuration for overriding a specific container in a pod. + + Args: + name: Name of the container to override (must exist in TrainingRuntime). + env: Environment variables to add/merge with the container. + Each dict should have 'name' and 'value' or 'valueFrom' keys. + volume_mounts: Volume mounts to add/merge with the container. + Each dict should have 'name' and 'mountPath' keys at minimum. + """ + + name: str + env: Optional[list[dict]] = None + volume_mounts: Optional[list[dict]] = None + + def __post_init__(self): + """Validate the container override configuration.""" + # Validate container name + if not self.name or not self.name.strip(): + raise ValueError("Container name must be a non-empty string") + + if self.env is not None: + if not isinstance(self.env, list): + raise ValueError("env must be a list of dictionaries") + for env_var in self.env: + if not isinstance(env_var, dict): + raise ValueError("Each env entry must be a dictionary") + if "name" not in env_var: + raise ValueError("Each env entry must have a 'name' key") + if not env_var.get("name"): + raise ValueError("env 'name' must be a non-empty string") + if "value" not in env_var and "valueFrom" not in env_var: + raise ValueError("Each env entry must have either 'value' or 'valueFrom' key") + # Validate valueFrom structure if present + if "valueFrom" in env_var: + value_from = env_var["valueFrom"] + if not isinstance(value_from, dict): + raise ValueError("env 'valueFrom' must be a dictionary") + # valueFrom must have one of these keys + valid_keys = {"configMapKeyRef", "secretKeyRef", "fieldRef", "resourceFieldRef"} + if not any(key in value_from for key in valid_keys): + raise ValueError( + f"env 'valueFrom' must contain one of: {', '.join(valid_keys)}" + ) + + if self.volume_mounts is not None: + if not isinstance(self.volume_mounts, list): + raise ValueError("volume_mounts must be a list of dictionaries") + for mount in self.volume_mounts: + if not isinstance(mount, dict): + raise ValueError("Each volume_mounts entry must be a dictionary") + if "name" not in mount: + raise ValueError("Each volume_mounts entry must have a 'name' key") + if not mount.get("name"): + raise ValueError("volume_mounts 'name' must be a non-empty string") + if "mountPath" not in mount: + raise ValueError("Each volume_mounts entry must have a 'mountPath' key") + mount_path = mount.get("mountPath") + if not mount_path or not isinstance(mount_path, str): + raise ValueError("volume_mounts 'mountPath' must be a non-empty string") + if not mount_path.startswith("/"): + raise ValueError( + f"volume_mounts 'mountPath' must be an absolute path " + f"(start with /): {mount_path}" + ) + + +@dataclass +class PodTemplateSpecOverride: + """Configuration for overriding pod template specifications. + + Args: + service_account_name: Service account to use for the pods. + node_selector: Node selector to place pods on specific nodes. + affinity: Affinity rules for pod scheduling. + tolerations: Tolerations for pod scheduling. + volumes: Volumes to add/merge with the pod. + init_containers: Init containers to add/merge with the pod. + containers: Containers to add/merge with the pod. + scheduling_gates: Scheduling gates for the pods. + image_pull_secrets: Image pull secrets for the pods. + """ + + service_account_name: Optional[str] = None + node_selector: Optional[dict[str, str]] = None + affinity: Optional[dict] = None + tolerations: Optional[list[dict]] = None + volumes: Optional[list[dict]] = None + init_containers: Optional[list[ContainerOverride]] = None + containers: Optional[list[ContainerOverride]] = None + scheduling_gates: Optional[list[dict]] = None + image_pull_secrets: Optional[list[dict]] = None + + +@dataclass +class PodTemplateOverride: + """Configuration for overriding pod templates for specific job types. + + Args: + target_jobs: List of job names to apply the overrides to (e.g., ["node", "launcher"]). + metadata: Metadata overrides for the pod template (labels, annotations). + spec: Spec overrides for the pod template. + """ + + target_jobs: list[str] + metadata: Optional[dict] = None + spec: Optional[PodTemplateSpecOverride] = None diff --git a/kubeflow/trainer/backends/kubernetes/options.py b/kubeflow/trainer/options/kubernetes.py similarity index 55% rename from kubeflow/trainer/backends/kubernetes/options.py rename to kubeflow/trainer/options/kubernetes.py index 96cf7e76b..a216cbaad 100644 --- a/kubeflow/trainer/backends/kubernetes/options.py +++ b/kubeflow/trainer/options/kubernetes.py @@ -18,13 +18,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union from kubeflow.trainer.backends.base import KubernetesCompatible +from kubeflow.trainer.options.common import PodTemplateOverride if TYPE_CHECKING: from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer @dataclass -class WithLabels(KubernetesCompatible): +class Labels(KubernetesCompatible): """Add labels to the TrainJob resource metadata (.metadata.labels).""" labels: dict[str, str] @@ -40,7 +41,7 @@ def __call__( @dataclass -class WithAnnotations(KubernetesCompatible): +class Annotations(KubernetesCompatible): """Add annotations to the TrainJob resource metadata (.metadata.annotations).""" annotations: dict[str, str] @@ -56,7 +57,7 @@ def __call__( @dataclass -class WithName(KubernetesCompatible): +class Name(KubernetesCompatible): """Set a custom name for the TrainJob resource (.metadata.name).""" name: str @@ -72,121 +73,10 @@ def __call__( @dataclass -class ContainerOverride: - """Configuration for overriding a specific container in a pod. - - Args: - name: Name of the container to override (must exist in TrainingRuntime). - env: Environment variables to add/merge with the container. - Each dict should have 'name' and 'value' or 'valueFrom' keys. - volume_mounts: Volume mounts to add/merge with the container. - Each dict should have 'name' and 'mountPath' keys at minimum. - """ - - name: str - env: Optional[list[dict]] = None - volume_mounts: Optional[list[dict]] = None - - def __post_init__(self): - """Validate the container override configuration.""" - # Validate container name - if not self.name or not self.name.strip(): - raise ValueError("Container name must be a non-empty string") - - if self.env is not None: - if not isinstance(self.env, list): - raise ValueError("env must be a list of dictionaries") - for env_var in self.env: - if not isinstance(env_var, dict): - raise ValueError("Each env entry must be a dictionary") - if "name" not in env_var: - raise ValueError("Each env entry must have a 'name' key") - if not env_var.get("name"): - raise ValueError("env 'name' must be a non-empty string") - if "value" not in env_var and "valueFrom" not in env_var: - raise ValueError( - "Each env entry must have either 'value' or 'valueFrom' key" - ) - # Validate valueFrom structure if present - if "valueFrom" in env_var: - value_from = env_var["valueFrom"] - if not isinstance(value_from, dict): - raise ValueError("env 'valueFrom' must be a dictionary") - # valueFrom must have one of these keys - valid_keys = {"configMapKeyRef", "secretKeyRef", "fieldRef", "resourceFieldRef"} - if not any(key in value_from for key in valid_keys): - raise ValueError( - f"env 'valueFrom' must contain one of: {', '.join(valid_keys)}" - ) - - if self.volume_mounts is not None: - if not isinstance(self.volume_mounts, list): - raise ValueError("volume_mounts must be a list of dictionaries") - for mount in self.volume_mounts: - if not isinstance(mount, dict): - raise ValueError("Each volume_mounts entry must be a dictionary") - if "name" not in mount: - raise ValueError("Each volume_mounts entry must have a 'name' key") - if not mount.get("name"): - raise ValueError("volume_mounts 'name' must be a non-empty string") - if "mountPath" not in mount: - raise ValueError("Each volume_mounts entry must have a 'mountPath' key") - mount_path = mount.get("mountPath") - if not mount_path or not isinstance(mount_path, str): - raise ValueError("volume_mounts 'mountPath' must be a non-empty string") - if not mount_path.startswith("/"): - raise ValueError( - f"volume_mounts 'mountPath' must be an absolute path (start with /): {mount_path}" - ) - - -@dataclass -class PodTemplateSpecOverride: - """Configuration for overriding pod template specifications. - - Args: - service_account_name: Service account to use for the pods. - node_selector: Node selector to place pods on specific nodes. - affinity: Affinity rules for pod scheduling. - tolerations: Tolerations for pod scheduling. - volumes: Volumes to add/merge with the pod. - init_containers: Init containers to add/merge with the pod. - containers: Containers to add/merge with the pod. - scheduling_gates: Scheduling gates for the pods. - image_pull_secrets: Image pull secrets for the pods. - """ - - service_account_name: Optional[str] = None - node_selector: Optional[dict[str, str]] = None - affinity: Optional[dict] = None - tolerations: Optional[list[dict]] = None - volumes: Optional[list[dict]] = None - init_containers: Optional[list[ContainerOverride]] = None - containers: Optional[list[ContainerOverride]] = None - scheduling_gates: Optional[list[dict]] = None - image_pull_secrets: Optional[list[dict]] = None - - -@dataclass -class PodTemplateOverride: - """Configuration for overriding pod templates for specific job types. - - Args: - target_jobs: List of job names to apply the overrides to (e.g., ["node", "launcher"]). - metadata: Metadata overrides for the pod template (labels, annotations). - spec: Spec overrides for the pod template. - """ - - target_jobs: list[str] - metadata: Optional[dict] = None - spec: Optional[PodTemplateSpecOverride] = None - - -@dataclass -class WithPodTemplateOverrides(KubernetesCompatible): +class PodTemplateOverrides(KubernetesCompatible): """Add pod template overrides to the TrainJob (.spec.podTemplateOverrides).""" - overrides: list[PodTemplateOverride] + pod_template_overrides: list[PodTemplateOverride] def __call__( self, @@ -195,9 +85,9 @@ def __call__( ) -> None: """Apply pod template overrides to the job specification.""" spec = job_spec.setdefault("spec", {}) - spec["podTemplateOverrides"] = [] + pod_overrides = spec.setdefault("podTemplateOverrides", []) - for override in self.overrides: + for override in self.pod_template_overrides: api_override = {"targetJobs": [{"name": job} for job in override.target_jobs]} if override.metadata: @@ -245,11 +135,11 @@ def __call__( if spec_dict: api_override["spec"] = spec_dict - spec["podTemplateOverrides"].append(api_override) + pod_overrides.append(api_override) @dataclass -class WithTrainerImage(KubernetesCompatible): +class TrainerImage(KubernetesCompatible): """Override the trainer container image (.spec.trainer.image).""" image: str @@ -271,7 +161,7 @@ def __call__( @dataclass -class WithTrainerCommand(KubernetesCompatible): +class TrainerCommand(KubernetesCompatible): """Override the trainer container command (.spec.trainer.command).""" command: list[str] @@ -290,14 +180,13 @@ def __call__( Raises: ValueError: If there's a conflict with the trainer configuration """ - # Validate conflicts with trainer from kubeflow.trainer.types.types import CustomTrainer if isinstance(trainer, CustomTrainer) and trainer.func is not None: raise ValueError( - "Cannot specify WithTrainerCommand when CustomTrainer.func is provided. " + "Cannot specify TrainerCommand when CustomTrainer.func is provided. " "The func generates its own command. Use container-only training " - "(CustomTrainer without func) or remove WithTrainerCommand." + "(CustomTrainer without func) or remove TrainerCommand." ) spec = job_spec.setdefault("spec", {}) @@ -306,7 +195,7 @@ def __call__( @dataclass -class WithTrainerArgs(KubernetesCompatible): +class TrainerArgs(KubernetesCompatible): """Override the trainer container arguments (.spec.trainer.args).""" args: list[str] @@ -325,14 +214,13 @@ def __call__( Raises: ValueError: If there's a conflict with the trainer configuration """ - # Validate conflicts with trainer from kubeflow.trainer.types.types import CustomTrainer if isinstance(trainer, CustomTrainer) and trainer.func is not None: raise ValueError( - "Cannot specify WithTrainerArgs when CustomTrainer.func is provided. " + "Cannot specify TrainerArgs when CustomTrainer.func is provided. " "The func generates its own arguments. Use container-only training " - "(CustomTrainer without func) or remove WithTrainerArgs." + "(CustomTrainer without func) or remove TrainerArgs." ) spec = job_spec.setdefault("spec", {}) diff --git a/kubeflow/trainer/options/kubernetes_test.py b/kubeflow/trainer/options/kubernetes_test.py new file mode 100644 index 000000000..c6669eba9 --- /dev/null +++ b/kubeflow/trainer/options/kubernetes_test.py @@ -0,0 +1,248 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Kubernetes options.""" + +from kubeflow.trainer.backends.base import KubernetesCompatible +from kubeflow.trainer.options import ( + Annotations, + Labels, + Name, + PodTemplateOverride, + PodTemplateOverrides, + TrainerArgs, + TrainerCommand, + TrainerImage, +) + + +class TestKubernetesOptionTypes: + """Test Kubernetes option types.""" + + def test_labels_option_type(self): + """Test Labels inherits from KubernetesCompatible.""" + option = Labels({"app": "test", "version": "v1"}) + assert isinstance(option, KubernetesCompatible) + + def test_annotations_option_type(self): + """Test Annotations inherits from KubernetesCompatible.""" + option = Annotations({"description": "test job"}) + assert isinstance(option, KubernetesCompatible) + + def test_name_option_type(self): + """Test Name inherits from KubernetesCompatible.""" + option = Name("test-job") + assert isinstance(option, KubernetesCompatible) + + def test_pod_template_overrides_option_type(self): + """Test PodTemplateOverrides inherits from KubernetesCompatible.""" + overrides = [PodTemplateOverride(target_jobs=["node"])] + option = PodTemplateOverrides(pod_template_overrides=overrides) + assert isinstance(option, KubernetesCompatible) + + def test_trainer_image_option_type(self): + """Test TrainerImage inherits from KubernetesCompatible.""" + option = TrainerImage(image="custom:latest") + assert isinstance(option, KubernetesCompatible) + + def test_trainer_command_option_type(self): + """Test TrainerCommand inherits from KubernetesCompatible.""" + option = TrainerCommand(command=["python", "train.py"]) + assert isinstance(option, KubernetesCompatible) + + def test_trainer_args_option_type(self): + """Test TrainerArgs inherits from KubernetesCompatible.""" + option = TrainerArgs(args=["--epochs", "10"]) + assert isinstance(option, KubernetesCompatible) + + +class TestKubernetesOptionApplication: + """Test Kubernetes option application behavior.""" + + def test_labels_application(self): + """Test Labels applies correctly to job spec.""" + option = Labels({"app": "test", "version": "v1"}) + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"labels": {"app": "test", "version": "v1"}}} + assert job_spec == expected + + def test_annotations_application(self): + """Test Annotations applies correctly to job spec.""" + option = Annotations({"description": "test job"}) + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"annotations": {"description": "test job"}}} + assert job_spec == expected + + def test_name_application(self): + """Test Name applies correctly to job spec.""" + option = Name("custom-job-name") + + job_spec = {} + option(job_spec) + + expected = {"metadata": {"name": "custom-job-name"}} + assert job_spec == expected + + def test_trainer_image_application(self): + """Test TrainerImage applies correctly to job spec.""" + option = TrainerImage(image="custom:latest") + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"image": "custom:latest"}}} + assert job_spec == expected + + def test_trainer_command_application(self): + """Test TrainerCommand applies correctly to job spec.""" + option = TrainerCommand(command=["python", "train.py"]) + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"command": ["python", "train.py"]}}} + assert job_spec == expected + + def test_trainer_args_application(self): + """Test TrainerArgs applies correctly to job spec.""" + option = TrainerArgs(args=["--epochs", "10"]) + + job_spec = {} + option(job_spec) + + expected = {"spec": {"trainer": {"args": ["--epochs", "10"]}}} + assert job_spec == expected + + +class TestTrainerOptionValidation: + """Test validation of trainer-specific options.""" + + def test_trainer_command_validates_against_custom_trainer_with_func(self): + """Test TrainerCommand validation when CustomTrainer.func is provided.""" + from kubeflow.trainer.types.types import CustomTrainer + + def dummy_func(): + pass + + trainer = CustomTrainer(func=dummy_func) + option = TrainerCommand(command=["python", "train.py"]) + + job_spec = {} + + try: + option(job_spec, trainer) + raise AssertionError("Expected ValueError") + except ValueError as e: + assert "Cannot specify TrainerCommand when CustomTrainer.func is provided" in str(e) + + def test_trainer_args_validates_against_custom_trainer_with_func(self): + """Test TrainerArgs validation when CustomTrainer.func is provided.""" + from kubeflow.trainer.types.types import CustomTrainer + + def dummy_func(): + pass + + trainer = CustomTrainer(func=dummy_func) + option = TrainerArgs(args=["--epochs", "10"]) + + job_spec = {} + + try: + option(job_spec, trainer) + raise AssertionError("Expected ValueError") + except ValueError as e: + assert "Cannot specify TrainerArgs when CustomTrainer.func is provided" in str(e) + + def test_trainer_command_allows_container_only_training(self): + """Test TrainerCommand works with container-only CustomTrainer (no func).""" + from kubeflow.trainer.types.types import CustomTrainer + + trainer = CustomTrainer(func=None) + option = TrainerCommand(command=["python", "train.py"]) + + job_spec = {} + option(job_spec, trainer) + + assert job_spec["spec"]["trainer"]["command"] == ["python", "train.py"] + + def test_trainer_args_allows_container_only_training(self): + """Test TrainerArgs works with container-only CustomTrainer (no func).""" + from kubeflow.trainer.types.types import CustomTrainer + + trainer = CustomTrainer(func=None) + option = TrainerArgs(args=["--epochs", "10"]) + + job_spec = {} + option(job_spec, trainer) + + assert job_spec["spec"]["trainer"]["args"] == ["--epochs", "10"] + + +class TestContainerOverride: + """Test ContainerOverride validation.""" + + def test_container_override_validates_name(self): + """Test ContainerOverride validates container name.""" + import pytest + + from kubeflow.trainer.options import ContainerOverride + + with pytest.raises(ValueError) as exc_info: + ContainerOverride(name="") + assert "Container name must be a non-empty string" in str(exc_info.value) + + def test_container_override_validates_env(self): + """Test ContainerOverride validates env structure.""" + import pytest + + from kubeflow.trainer.options import ContainerOverride + + with pytest.raises(ValueError) as exc_info: + ContainerOverride(name="trainer", env=[{"invalid": "structure"}]) + assert "Each env entry must have a 'name' key" in str(exc_info.value) + + def test_container_override_validates_volume_mounts(self): + """Test ContainerOverride validates volume mount structure.""" + import pytest + + from kubeflow.trainer.options import ContainerOverride + + with pytest.raises(ValueError) as exc_info: + ContainerOverride(name="trainer", volume_mounts=[{"name": "vol"}]) + assert "Each volume_mounts entry must have a 'mountPath' key" in str(exc_info.value) + + +class TestPodTemplateOverrides: + """Test PodTemplateOverrides functionality.""" + + def test_pod_template_overrides_basic(self): + """Test basic PodTemplateOverrides application.""" + from kubeflow.trainer.options import PodTemplateOverride + + override = PodTemplateOverride(target_jobs=["node"]) + option = PodTemplateOverrides(pod_template_overrides=[override]) + + job_spec = {} + option(job_spec) + + assert "spec" in job_spec + assert "podTemplateOverrides" in job_spec["spec"] + assert len(job_spec["spec"]["podTemplateOverrides"]) == 1 + assert job_spec["spec"]["podTemplateOverrides"][0]["targetJobs"] == [{"name": "node"}] diff --git a/kubeflow/trainer/options/localprocess.py b/kubeflow/trainer/options/localprocess.py new file mode 100644 index 000000000..53b7e9075 --- /dev/null +++ b/kubeflow/trainer/options/localprocess.py @@ -0,0 +1,63 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LocalProcess-specific training options for the Kubeflow Trainer SDK.""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Union + +from kubeflow.trainer.backends.base import LocalProcessCompatible + +if TYPE_CHECKING: + from kubeflow.trainer.types.types import BuiltinTrainer, CustomTrainer + + +@dataclass +class ProcessTimeout(LocalProcessCompatible): + """Set a timeout for the local training process.""" + + timeout_seconds: int + + def __call__( + self, + job_spec: dict, + trainer: Optional[Union["BuiltinTrainer", "CustomTrainer"]] = None, + ) -> None: + """Apply timeout to local process configuration. + + Args: + job_spec: The job specification dictionary to modify + trainer: Optional trainer context (unused for local process) + """ + job_spec["timeout_seconds"] = self.timeout_seconds + + +@dataclass +class WorkingDirectory(LocalProcessCompatible): + """Set the working directory for the local training process.""" + + working_dir: str + + def __call__( + self, + job_spec: dict, + trainer: Optional[Union["BuiltinTrainer", "CustomTrainer"]] = None, + ) -> None: + """Apply working directory to local process configuration. + + Args: + job_spec: The job specification dictionary to modify + trainer: Optional trainer context (unused for local process) + """ + job_spec["working_dir"] = self.working_dir diff --git a/kubeflow/trainer/backends/localprocess/options_test.py b/kubeflow/trainer/options/localprocess_test.py similarity index 58% rename from kubeflow/trainer/backends/localprocess/options_test.py rename to kubeflow/trainer/options/localprocess_test.py index e4f2b55f1..05d0ea21a 100644 --- a/kubeflow/trainer/backends/localprocess/options_test.py +++ b/kubeflow/trainer/options/localprocess_test.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unit tests for LocalProcess backend options.""" +"""Unit tests for LocalProcess options.""" from kubeflow.trainer.backends.base import LocalProcessCompatible -from kubeflow.trainer.backends.localprocess.options import WithProcessTimeout, WithWorkingDirectory +from kubeflow.trainer.options.localprocess import ProcessTimeout, WorkingDirectory class TestLocalProcessOptionTypes: """Test LocalProcess option types.""" - def test_with_process_timeout_type(self): - """Test WithProcessTimeout inherits from LocalProcessCompatible.""" - option = WithProcessTimeout(timeout_seconds=300) + def test_process_timeout_type(self): + """Test ProcessTimeout inherits from LocalProcessCompatible.""" + option = ProcessTimeout(timeout_seconds=300) assert isinstance(option, LocalProcessCompatible) - def test_with_working_directory_type(self): - """Test WithWorkingDirectory inherits from LocalProcessCompatible.""" - option = WithWorkingDirectory(working_dir="/tmp/training") + def test_working_directory_type(self): + """Test WorkingDirectory inherits from LocalProcessCompatible.""" + option = WorkingDirectory(working_dir="/tmp/training") assert isinstance(option, LocalProcessCompatible) @@ -36,34 +36,34 @@ class TestLocalProcessOptionApplication: """Test LocalProcess option application behavior.""" def test_process_timeout_application(self): - """Test WithProcessTimeout applies correctly to local config.""" - option = WithProcessTimeout(timeout_seconds=600) + """Test ProcessTimeout applies correctly to local config.""" + option = ProcessTimeout(timeout_seconds=600) - config = {} - option(config) + job_spec = {} + option(job_spec) expected = {"timeout_seconds": 600} - assert config == expected + assert job_spec == expected def test_working_directory_application(self): - """Test WithWorkingDirectory applies correctly to local config.""" - option = WithWorkingDirectory(working_dir="/home/user/training") + """Test WorkingDirectory applies correctly to local config.""" + option = WorkingDirectory(working_dir="/home/user/training") - config = {} - option(config) + job_spec = {} + option(job_spec) expected = {"working_dir": "/home/user/training"} - assert config == expected + assert job_spec == expected def test_multiple_local_options(self): """Test multiple LocalProcess options together.""" - config = {} + job_spec = {} - WithProcessTimeout(timeout_seconds=300)(config) - WithWorkingDirectory(working_dir="/tmp/work")(config) + ProcessTimeout(timeout_seconds=300)(job_spec) + WorkingDirectory(working_dir="/tmp/work")(job_spec) expected = {"timeout_seconds": 300, "working_dir": "/tmp/work"} - assert config == expected + assert job_spec == expected class TestLocalProcessCapabilities: @@ -72,8 +72,8 @@ class TestLocalProcessCapabilities: def test_local_process_capabilities_minimal(self): """Test LocalProcess backend mixin compatibility.""" # Test that LocalProcess options inherit from LocalProcessCompatible - timeout_option = WithProcessTimeout(timeout_seconds=300) - working_dir_option = WithWorkingDirectory(working_dir="/tmp/training") + timeout_option = ProcessTimeout(timeout_seconds=300) + working_dir_option = WorkingDirectory(working_dir="/tmp/training") assert isinstance(timeout_option, LocalProcessCompatible) assert isinstance(working_dir_option, LocalProcessCompatible) @@ -81,8 +81,8 @@ def test_local_process_capabilities_minimal(self): def test_local_process_empty_options_compatibility(self): """Test LocalProcess compatibility with empty options.""" # Empty options list should always be compatible - timeout_option = WithProcessTimeout(timeout_seconds=300) - working_dir_option = WithWorkingDirectory(working_dir="/tmp/training") + timeout_option = ProcessTimeout(timeout_seconds=300) + working_dir_option = WorkingDirectory(working_dir="/tmp/training") # Both should be LocalProcessCompatible assert isinstance(timeout_option, LocalProcessCompatible) @@ -93,17 +93,17 @@ class TestLocalProcessOptionCreation: """Test LocalProcess option creation and validation.""" def test_process_timeout_creation(self): - """Test WithProcessTimeout creation with various values.""" - option = WithProcessTimeout(timeout_seconds=300) + """Test ProcessTimeout creation with various values.""" + option = ProcessTimeout(timeout_seconds=300) assert option.timeout_seconds == 300 - option = WithProcessTimeout(timeout_seconds=3600) + option = ProcessTimeout(timeout_seconds=3600) assert option.timeout_seconds == 3600 def test_working_directory_creation(self): - """Test WithWorkingDirectory creation with various paths.""" - option = WithWorkingDirectory(working_dir="/home/user/training") + """Test WorkingDirectory creation with various paths.""" + option = WorkingDirectory(working_dir="/home/user/training") assert option.working_dir == "/home/user/training" - option = WithWorkingDirectory(working_dir="./training") + option = WorkingDirectory(working_dir="./training") assert option.working_dir == "./training" From 5a7c563ba3dd22364fcad821b3511f5677c857c5 Mon Sep 17 00:00:00 2001 From: Abhijeet Dhumal Date: Thu, 23 Oct 2025 19:04:40 +0530 Subject: [PATCH 7/7] feat: add spec-level labels and annotations options for derivative JobSet/Jobs labels Signed-off-by: Abhijeet Dhumal --- kubeflow/trainer/api/trainer_client.py | 3 +- kubeflow/trainer/backends/__init__.py | 2 +- kubeflow/trainer/backends/base.py | 21 +--------- .../trainer/backends/kubernetes/backend.py | 12 ++++-- .../backends/kubernetes/backend_test.py | 42 +++++++++++++++++++ .../trainer/backends/localprocess/backend.py | 3 +- kubeflow/trainer/options/__init__.py | 4 ++ kubeflow/trainer/options/kubernetes.py | 40 ++++++++++++++++++ kubeflow/trainer/types/__init__.py | 2 + kubeflow/trainer/types/types.py | 19 ++++++++- 10 files changed, 120 insertions(+), 28 deletions(-) diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index d8cb8ba1b..8fb45c9f9 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -16,7 +16,6 @@ import logging from typing import Optional, Union -from kubeflow.trainer.backends.base import Option from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig from kubeflow.trainer.backends.localprocess.backend import ( @@ -24,7 +23,7 @@ LocalProcessBackendConfig, ) from kubeflow.trainer.constants import constants -from kubeflow.trainer.types import types +from kubeflow.trainer.types import Option, types logger = logging.getLogger(__name__) diff --git a/kubeflow/trainer/backends/__init__.py b/kubeflow/trainer/backends/__init__.py index 7c1cd8811..86637cc10 100644 --- a/kubeflow/trainer/backends/__init__.py +++ b/kubeflow/trainer/backends/__init__.py @@ -3,9 +3,9 @@ ExecutionBackend, KubernetesCompatible, LocalProcessCompatible, - Option, UniversalCompatible, ) +from kubeflow.trainer.types import Option __all__ = [ "ExecutionBackend", diff --git a/kubeflow/trainer/backends/base.py b/kubeflow/trainer/backends/base.py index 2f86c5286..fe804edf6 100644 --- a/kubeflow/trainer/backends/base.py +++ b/kubeflow/trainer/backends/base.py @@ -14,10 +14,10 @@ import abc from collections.abc import Iterator -from typing import Optional, Protocol, Union +from typing import Optional, Union from kubeflow.trainer.constants import constants -from kubeflow.trainer.types import types +from kubeflow.trainer.types import Option, types # Backend compatibility mixins @@ -39,23 +39,6 @@ class UniversalCompatible(KubernetesCompatible, LocalProcessCompatible): pass -class Option(Protocol): - """Protocol defining the contract for training options.""" - - def __call__( - self, - job_spec: dict, - trainer: Optional[Union["types.BuiltinTrainer", "types.CustomTrainer"]] = None, - ) -> None: - """Apply the option to the job specification. - - Args: - job_spec: The job specification dictionary to modify - trainer: Optional trainer context for validation - """ - ... - - CompatibleOption = Union[KubernetesCompatible, LocalProcessCompatible] diff --git a/kubeflow/trainer/backends/kubernetes/backend.py b/kubeflow/trainer/backends/kubernetes/backend.py index 5e0941ac3..c70c0a02d 100644 --- a/kubeflow/trainer/backends/kubernetes/backend.py +++ b/kubeflow/trainer/backends/kubernetes/backend.py @@ -26,10 +26,10 @@ from kubeflow_trainer_api import models from kubernetes import client, config, watch -from kubeflow.trainer.backends.base import ExecutionBackend, KubernetesCompatible, Option +from kubeflow.trainer.backends.base import ExecutionBackend, KubernetesCompatible from kubeflow.trainer.backends.kubernetes import types as k8s_types from kubeflow.trainer.constants import constants -from kubeflow.trainer.types import types +from kubeflow.trainer.types import Option, types from kubeflow.trainer.utils import utils logger = logging.getLogger(__name__) @@ -196,6 +196,8 @@ def train( labels = None annotations = None name = None + spec_labels = None + spec_annotations = None trainer_overrides = {} pod_template_overrides = None @@ -208,8 +210,10 @@ def train( annotations = metadata_section.get("annotations") name = metadata_section.get("name") - # Extract trainer-specific overrides and pod template overrides + # Extract spec-level labels/annotations and other spec configurations spec_section = job_spec.get("spec", {}) + spec_labels = spec_section.get("labels") + spec_annotations = spec_section.get("annotations") trainer_spec = spec_section.get("trainer", {}) if trainer_spec: trainer_overrides = trainer_spec @@ -270,6 +274,8 @@ def train( if isinstance(initializer, types.Initializer) else None ), + labels=spec_labels, + annotations=spec_annotations, pod_template_overrides=pod_template_overrides, ), ) diff --git a/kubeflow/trainer/backends/kubernetes/backend_test.py b/kubeflow/trainer/backends/kubernetes/backend_test.py index a9c9e71ec..0cf7b2327 100644 --- a/kubeflow/trainer/backends/kubernetes/backend_test.py +++ b/kubeflow/trainer/backends/kubernetes/backend_test.py @@ -41,6 +41,8 @@ PodTemplateOverride, PodTemplateOverrides, PodTemplateSpecOverride, + SpecAnnotations, + SpecLabels, TrainerArgs, TrainerCommand, TrainerImage, @@ -287,6 +289,8 @@ def get_train_job( train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, labels: Optional[dict[str, str]] = None, annotations: Optional[dict[str, str]] = None, + spec_labels: Optional[dict[str, str]] = None, + spec_annotations: Optional[dict[str, str]] = None, pod_template_overrides: Optional[list] = None, ) -> models.TrainerV1alpha1TrainJob: """Create a mock TrainJob object with optional trainer configurations.""" @@ -299,6 +303,8 @@ def get_train_job( spec=models.TrainerV1alpha1TrainJobSpec( runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, + labels=spec_labels, + annotations=spec_annotations, pod_template_overrides=pod_template_overrides, ), ) @@ -1334,6 +1340,42 @@ def test_train_validation(kubernetes_backend, test_case): annotations={"created-by": "sdk"}, ), ), + TestCase( + name="train with spec labels and annotations", + expected_status=SUCCESS, + config={ + "options": [ + SpecLabels({"app": "training", "version": "v1.0"}), + SpecAnnotations({"prometheus.io/scrape": "true"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + spec_labels={"app": "training", "version": "v1.0"}, + spec_annotations={"prometheus.io/scrape": "true"}, + ), + ), + TestCase( + name="train with both metadata and spec labels/annotations", + expected_status=SUCCESS, + config={ + "options": [ + Labels({"owner": "ml-team"}), + Annotations({"description": "Fine-tuning job"}), + SpecLabels({"app": "training", "version": "v1.0"}), + SpecAnnotations({"prometheus.io/scrape": "true"}), + ], + }, + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + labels={"owner": "ml-team"}, + annotations={"description": "Fine-tuning job"}, + spec_labels={"app": "training", "version": "v1.0"}, + spec_annotations={"prometheus.io/scrape": "true"}, + ), + ), ] TRAIN_TEST_CASES.extend( diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py index e65920565..50ae6245c 100644 --- a/kubeflow/trainer/backends/localprocess/backend.py +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -23,7 +23,6 @@ from kubeflow.trainer.backends.base import ( ExecutionBackend, LocalProcessCompatible, - Option, ) from kubeflow.trainer.backends.localprocess import utils as local_utils from kubeflow.trainer.backends.localprocess.constants import local_runtimes @@ -34,7 +33,7 @@ LocalProcessBackendConfig, ) from kubeflow.trainer.constants import constants -from kubeflow.trainer.types import types +from kubeflow.trainer.types import Option, types logger = logging.getLogger(__name__) diff --git a/kubeflow/trainer/options/__init__.py b/kubeflow/trainer/options/__init__.py index 2197b28de..48a4f46ed 100644 --- a/kubeflow/trainer/options/__init__.py +++ b/kubeflow/trainer/options/__init__.py @@ -28,6 +28,8 @@ Labels, Name, PodTemplateOverrides, + SpecAnnotations, + SpecLabels, TrainerArgs, TrainerCommand, TrainerImage, @@ -44,6 +46,8 @@ "Labels", "Name", "PodTemplateOverrides", + "SpecAnnotations", + "SpecLabels", "TrainerArgs", "TrainerCommand", "TrainerImage", diff --git a/kubeflow/trainer/options/kubernetes.py b/kubeflow/trainer/options/kubernetes.py index a216cbaad..117631574 100644 --- a/kubeflow/trainer/options/kubernetes.py +++ b/kubeflow/trainer/options/kubernetes.py @@ -56,6 +56,46 @@ def __call__( metadata["annotations"] = self.annotations +@dataclass +class SpecLabels(KubernetesCompatible): + """Add labels to derivative JobSet and Jobs (.spec.labels). + + These labels will be merged with the TrainingRuntime values and applied to + the JobSet and Jobs created by the TrainJob. + """ + + labels: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: + """Apply spec-level labels to the job specification.""" + spec = job_spec.setdefault("spec", {}) + spec["labels"] = self.labels + + +@dataclass +class SpecAnnotations(KubernetesCompatible): + """Add annotations to derivative JobSet and Jobs (.spec.annotations). + + These annotations will be merged with the TrainingRuntime values and applied to + the JobSet and Jobs created by the TrainJob. + """ + + annotations: dict[str, str] + + def __call__( + self, + job_spec: dict[str, Any], + trainer: Optional[Union["CustomTrainer", "BuiltinTrainer"]] = None, + ) -> None: + """Apply spec-level annotations to the job specification.""" + spec = job_spec.setdefault("spec", {}) + spec["annotations"] = self.annotations + + @dataclass class Name(KubernetesCompatible): """Set a custom name for the TrainJob resource (.metadata.name).""" diff --git a/kubeflow/trainer/types/__init__.py b/kubeflow/trainer/types/__init__.py index bdb8d563b..1e2c24715 100644 --- a/kubeflow/trainer/types/__init__.py +++ b/kubeflow/trainer/types/__init__.py @@ -25,6 +25,7 @@ HuggingFaceModelInitializer, Initializer, Loss, + Option, Runtime, RuntimeTrainer, Step, @@ -43,6 +44,7 @@ "HuggingFaceModelInitializer", "Initializer", "Loss", + "Option", "Runtime", "RuntimeTrainer", "Step", diff --git a/kubeflow/trainer/types/types.py b/kubeflow/trainer/types/types.py index 72c5f433a..47cb35822 100644 --- a/kubeflow/trainer/types/types.py +++ b/kubeflow/trainer/types/types.py @@ -16,11 +16,28 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Callable, Optional, Union +from typing import Callable, Optional, Protocol, Union from kubeflow.trainer.constants import constants +class Option(Protocol): + """Protocol defining the contract for training options.""" + + def __call__( + self, + job_spec: dict, + trainer: Optional[Union["BuiltinTrainer", "CustomTrainer"]] = None, + ) -> None: + """Apply the option to the job specification. + + Args: + job_spec: The job specification dictionary to modify + trainer: Optional trainer context for validation + """ + ... + + # Configuration for the Custom Trainer. @dataclass class CustomTrainer: