Skip to content
30 changes: 27 additions & 3 deletions kubeflow/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
# Import the Kubeflow Trainer constants.
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401

# Import training options
from kubeflow.trainer.options import (
Annotations,
ContainerOverride,
Labels,
Name,
PodTemplateOverride,
PodTemplateOverrides,
PodTemplateSpecOverride,
TrainerArgs,
TrainerCommand,
TrainerImage,
)

# Import the Kubeflow Trainer types.
from kubeflow.trainer.types.types import (
BuiltinTrainer,
Expand All @@ -43,7 +57,9 @@
)

__all__ = [
"Annotations",
"BuiltinTrainer",
"ContainerOverride",
"CustomTrainer",
"DataCacheInitializer",
"DataFormat",
Expand All @@ -52,15 +68,23 @@
"HuggingFaceDatasetInitializer",
"HuggingFaceModelInitializer",
"Initializer",
"KubernetesBackendConfig",
"Labels",
"LocalProcessBackendConfig",
"LoraConfig",
"Loss",
"MODEL_PATH",
"Name",
"PodTemplateOverride",
"PodTemplateOverrides",
"PodTemplateSpecOverride",
"Runtime",
"RuntimeTrainer",
"TorchTuneConfig",
"TorchTuneInstructDataset",
"RuntimeTrainer",
"TrainerArgs",
"TrainerClient",
"TrainerCommand",
"TrainerImage",
"TrainerType",
"LocalProcessBackendConfig",
"KubernetesBackendConfig",
]
16 changes: 14 additions & 2 deletions kubeflow/trainer/api/trainer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
LocalProcessBackendConfig,
)
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.types import Option, types

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -96,6 +96,7 @@ def train(
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
options: Optional[list[Option]] = None,
) -> str:
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:

Expand All @@ -110,6 +111,8 @@ def train(
initializer: Optional configuration for the dataset and model initializers.
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
the TrainJob will use the runtime's default values.
options: Optional list of configuration options to apply to the TrainJob. Use
WithLabels and WithAnnotations for basic metadata configuration.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
WithLabels and WithAnnotations for basic metadata configuration.
Labels and Annotations for basic metadata configuration.


Returns:
The unique name of the TrainJob that has been generated.
Expand All @@ -119,7 +122,16 @@ def train(
TimeoutError: Timeout to create TrainJobs.
RuntimeError: Failed to create TrainJobs.
"""
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)
# Validate options compatibility with backend
if options:
self.backend.validate_options(options)

return self.backend.train(
runtime=runtime,
initializer=initializer,
trainer=trainer,
options=options,
)

def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
Expand Down
282 changes: 282 additions & 0 deletions kubeflow/trainer/api/trainer_client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Unit tests for TrainerClient option handling and error messages.
"""

from unittest.mock import Mock, patch

import pytest

from kubeflow.trainer.api.trainer_client import TrainerClient
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig
from kubeflow.trainer.options import Annotations, Labels
from kubeflow.trainer.types import types


class TestTrainerClientOptionValidation:
"""Test TrainerClient option validation integration."""

def test_trainer_client_passes_options_to_backend(self):
"""Test that TrainerClient passes options to backend correctly."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [Labels({"app": "test"})]

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
assert "The following options are not compatible with this backend" in error_msg
assert "Labels" in error_msg

@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.CustomObjectsApi")
@patch("kubernetes.client.CoreV1Api")
def test_trainer_client_with_kubernetes_backend(
self, mock_core_api, mock_custom_api, mock_load_config
):
"""Test TrainerClient with KubernetesBackend and compatible options."""
mock_custom_api.return_value = Mock()
mock_core_api.return_value = Mock()

config = KubernetesBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [Labels({"app": "test"}), Annotations({"desc": "test"})]

with pytest.raises((ValueError, RuntimeError)) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
# Should either fail with runtime requirement or K8s connection error
assert (
"Runtime is required" in error_msg
or "Failed to get clustertrainingruntimes" in error_msg
)

def test_trainer_client_empty_options(self):
"""Test TrainerClient with empty options."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=[])

error_msg = str(exc_info.value)
assert "Runtime must be provided for LocalProcessBackend" in error_msg


class TestTrainerClientErrorHandling:
"""Test TrainerClient error handling improvements."""

def test_missing_runtime_error_message(self):
"""Test improved error message for missing runtime."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer)

error_msg = str(exc_info.value)
# The error message should contain the runtime requirement
assert "Runtime must be provided for LocalProcessBackend" in error_msg

def test_option_validation_error_propagation(self):
"""Test that option validation errors are properly propagated."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [Labels({"app": "test"}), Annotations({"desc": "test"})]

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
assert "The following options are not compatible with this backend" in error_msg
assert "Labels" in error_msg
assert "Annotations" in error_msg
assert "The following options are not compatible with this backend" in error_msg

def test_error_message_does_not_contain_runtime_help_for_option_errors(self):
"""Test that option validation errors don't get runtime help text."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [Labels({"app": "test"})]

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
assert "The following options are not compatible with this backend" in error_msg
assert "Example usage:" not in error_msg

@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.CustomObjectsApi")
@patch("kubernetes.client.CoreV1Api")
def test_kubernetes_backend_error_handling(
self, mock_core_api, mock_custom_api, mock_load_config
):
"""Test error handling with KubernetesBackend."""
mock_custom_api.return_value = Mock()
mock_core_api.return_value = Mock()

config = KubernetesBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)

with pytest.raises((ValueError, RuntimeError)) as exc_info:
client.train(trainer=trainer)

error_msg = str(exc_info.value)
# Should either fail with runtime requirement or K8s connection error
assert (
"Runtime is required" in error_msg
or "Failed to get clustertrainingruntimes" in error_msg
)


class TestTrainerClientBackendSelection:
"""Test TrainerClient backend selection and configuration."""

@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.CustomObjectsApi")
@patch("kubernetes.client.CoreV1Api")
def test_default_backend_is_kubernetes(self, mock_core_api, mock_custom_api, mock_load_config):
"""Test that default backend is Kubernetes."""
mock_custom_api.return_value = Mock()
mock_core_api.return_value = Mock()

client = TrainerClient()

from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

assert isinstance(client.backend, KubernetesBackend)

def test_local_process_backend_selection(self):
"""Test LocalProcess backend selection."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend

assert isinstance(client.backend, LocalProcessBackend)

@patch("kubernetes.config.load_kube_config")
@patch("kubernetes.client.CustomObjectsApi")
@patch("kubernetes.client.CoreV1Api")
def test_kubernetes_backend_selection(self, mock_core_api, mock_custom_api, mock_load_config):
"""Test Kubernetes backend selection."""
mock_custom_api.return_value = Mock()
mock_core_api.return_value = Mock()

config = KubernetesBackendConfig()
client = TrainerClient(backend_config=config)

from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend

assert isinstance(client.backend, KubernetesBackend)


class TestTrainerClientOptionFlow:
"""Test the complete option flow through TrainerClient."""

def test_option_validation_happens_early(self):
"""Test that option validation happens before other validations."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [Labels({"app": "test"})]

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
assert "The following options are not compatible with this backend" in error_msg

def test_multiple_option_types_validation(self):
"""Test validation with multiple different option types."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)
options = [
Labels({"app": "test"}),
Annotations({"desc": "test"}),
]

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=options)

error_msg = str(exc_info.value)
assert "The following options are not compatible with this backend" in error_msg
assert "Labels" in error_msg
assert "Annotations" in error_msg

def test_none_options_handling(self):
"""Test that None options are handled correctly."""
config = LocalProcessBackendConfig()
client = TrainerClient(backend_config=config)

def simple_func():
return "test"

trainer = types.CustomTrainer(func=simple_func)

with pytest.raises(ValueError) as exc_info:
client.train(trainer=trainer, options=None)

error_msg = str(exc_info.value)
assert "Runtime must be provided for LocalProcessBackend" in error_msg
Loading