From 35da3851cab6390d241b195a21f04878f8e1e41a Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Mon, 28 Jul 2025 16:58:53 +0100 Subject: [PATCH 1/9] feat(trainer): Add wait_for_job_status() API Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/__init__.py | 25 +++- python/kubeflow/trainer/api/trainer_client.py | 131 ++++++++++++++---- .../kubeflow/trainer/constants/constants.py | 11 +- python/kubeflow/trainer/types/types.py | 18 +-- python/kubeflow/trainer/utils/utils.py | 7 +- 5 files changed, 148 insertions(+), 44 deletions(-) diff --git a/python/kubeflow/trainer/__init__.py b/python/kubeflow/trainer/__init__.py index 827f4933..b87715f7 100644 --- a/python/kubeflow/trainer/__init__.py +++ b/python/kubeflow/trainer/__init__.py @@ -19,8 +19,10 @@ # Import the Kubeflow Trainer client. from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401 + # Import the Kubeflow Trainer constants. from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401 + # Import the Kubeflow Trainer types. from kubeflow.trainer.types.types import ( BuiltinTrainer, @@ -35,13 +37,26 @@ Runtime, TorchTuneConfig, TorchTuneInstructDataset, - Trainer, + RuntimeTrainer, TrainerType, ) __all__ = [ - "BuiltinTrainer", "CustomTrainer", "DataFormat", "DATASET_PATH", "DataType", "Framework", - "HuggingFaceDatasetInitializer", "HuggingFaceModelInitializer", "Initializer", "Loss", - "MODEL_PATH", "Runtime", "TorchTuneConfig", "TorchTuneInstructDataset", "Trainer", - "TrainerClient", "TrainerType" + "BuiltinTrainer", + "CustomTrainer", + "DataFormat", + "DATASET_PATH", + "DataType", + "Framework", + "HuggingFaceDatasetInitializer", + "HuggingFaceModelInitializer", + "Initializer", + "Loss", + "MODEL_PATH", + "Runtime", + "TorchTuneConfig", + "TorchTuneInstructDataset", + "RuntimeTrainer", + "TrainerClient", + "TrainerType", ] diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 7f6e33fc..0ded85a5 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -18,7 +18,8 @@ import random import string import uuid -from typing import Dict, List, Optional, Union +import time +from typing import Dict, List, Optional, Union, Set from kubeflow.trainer.constants import constants from kubeflow.trainer.types import types @@ -433,6 +434,61 @@ def get_job_logs( return logs_dict + def wait_for_job_status( + self, + name: str, + status: Set = {constants.TRAINJOB_COMPLETE}, + timeout: int = 600, + polling_interval=5, + ): + """Wait for TrainJob to reach the desired status + + Args: + name: Name of the TrainJob. + status: Set of expected statuses. It must be subset of Running, Complete, and Failed + statuses. + timeout: How many seconds to wait until TrainJob reaches one of the expected conditions. + polling_interval: The polling interval in seconds to check TrainJob status. + + Returns: + TrainJob: The training job that reaches desired status. + + Raises: + + ValueError: The input values are incorrect. + RuntimeError: Failed to get TrainJob or TrainJob reaches unexpected Failed status. + TimeoutError: Timeout to wait for TrainJob status. + """ + + job_statuses = { + constants.TRAINJOB_RUNNING, + constants.TRAINJOB_COMPLETE, + constants.TRAINJOB_FAILED, + } + if not status.issubset(job_statuses): + raise ValueError( + f"Expected status {status} must be a subset of {job_statuses}" + ) + for _ in range(round(timeout / polling_interval)): + trainjob = self.get_job(name) + + # Raise an error if TrainJob is Failed and it is not the expected status. + if ( + constants.TRAINJOB_FAILED not in status + and trainjob.status == constants.TRAINJOB_FAILED + ): + raise RuntimeError(f"TrainJob {name} is Failed") + + # Return the TrainJob if it reaches the expected status. + if trainjob.status in status: + return trainjob + + time.sleep(polling_interval) + + raise TimeoutError( + f"Timeout waiting for TrainJob {name} to reach {status} Status" + ) + def delete_job(self, name: str): """Delete the TrainJob. @@ -485,7 +541,6 @@ def __get_runtime_from_crd( trainer=utils.get_runtime_trainer( runtime_crd.spec.template.spec.replicated_jobs, runtime_crd.spec.ml_policy, - runtime_crd.metadata, ), ) @@ -506,26 +561,22 @@ def __get_trainjob_from_crd( name = trainjob_crd.metadata.name namespace = trainjob_crd.metadata.namespace + runtime = self.get_runtime(trainjob_crd.spec.runtime_ref.name) + # Construct the TrainJob from the CRD. trainjob = types.TrainJob( name=name, creation_timestamp=trainjob_crd.metadata.creation_timestamp, - runtime=self.get_runtime(trainjob_crd.spec.runtime_ref.name), + runtime=runtime, steps=[], + # Number of nodes is taken from TrainJob or TrainingRuntime + num_nodes=( + trainjob_crd.spec.trainer.num_nodes + if trainjob_crd.spec.trainer and trainjob_crd.spec.trainer.num_nodes + else runtime.trainer.num_nodes + ), ) - # Add the TrainJob status. - # TODO (andreyvelich): Discuss how we should show TrainJob status to SDK users. - # The TrainJob exists at that stage so its status can safely default to Created - trainjob.status = constants.TRAINJOB_CREATED - # Then it can be read from the TrainJob conditions if any - if trainjob_crd.status and trainjob_crd.status.conditions: - for c in trainjob_crd.status.conditions: - if c.type == "Complete" and c.status == "True": - trainjob.status = "Succeeded" - elif c.type == "Failed" and c.status == "True": - trainjob.status = "Failed" - # Select Pods created by the appropriate JobSet. It checks the following ReplicatedJob.name: # dataset-initializer, model-initializer, launcher, node. label_selector = "{}={},{} in ({}, {}, {}, {})".format( @@ -567,26 +618,28 @@ def __get_trainjob_from_crd( constants.DATASET_INITIALIZER, constants.MODEL_INITIALIZER, }: - step = utils.get_trainjob_initializer_step( - pod.metadata.name, - pod.spec, - pod.status, + trainjob.steps.append( + utils.get_trainjob_initializer_step( + pod.metadata.name, + pod.spec, + pod.status, + ) ) # Get the Node step. elif pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL] in { constants.LAUNCHER, constants.NODE, }: - step = utils.get_trainjob_node_step( - pod.metadata.name, - pod.spec, - pod.status, - trainjob.runtime, - pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL], - int(pod.metadata.labels[constants.JOB_INDEX_LABEL]), + trainjob.steps.append( + utils.get_trainjob_node_step( + pod.metadata.name, + pod.spec, + pod.status, + trainjob.runtime, + pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL], + int(pod.metadata.labels[constants.JOB_INDEX_LABEL]), + ) ) - - trainjob.steps.append(step) except multiprocessing.TimeoutError: raise TimeoutError( f"Timeout to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}" @@ -596,4 +649,26 @@ def __get_trainjob_from_crd( f"Failed to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}" ) + # Add the TrainJob status. + # The TrainJob exists at that stage so its status can safely default to Created. + trainjob.status = constants.TRAINJOB_CREATED + # Otherwise, we read the TrainJob status from its conditions. + if trainjob_crd.status and trainjob_crd.status.conditions: + for c in trainjob_crd.status.conditions: + if c.type == constants.TRAINJOB_COMPLETE and c.status == "True": + trainjob.status = c.type + elif c.type == constants.TRAINJOB_FAILED and c.status == "True": + trainjob.status = c.type + else: + # The TrainJob running status is defined when all training node (e.g. Pods) are running. + num_running_nodes = sum( + 1 + for step in trainjob.steps + if step.name.startswith(constants.NODE) + and step.status == constants.TRAINJOB_RUNNING + ) + + if trainjob.num_nodes == num_running_nodes: + trainjob.status = constants.TRAINJOB_RUNNING + return trainjob diff --git a/python/kubeflow/trainer/constants/constants.py b/python/kubeflow/trainer/constants/constants.py index d21da279..293ce55a 100644 --- a/python/kubeflow/trainer/constants/constants.py +++ b/python/kubeflow/trainer/constants/constants.py @@ -37,9 +37,18 @@ # The plural for the TrainJob. TRAINJOB_PLURAL = "trainjobs" -# The default status for the TrainJob. +# The default status for the TrainJob once users create it. TRAINJOB_CREATED = "Created" +# The running status of the TrainJob, defined when all training node (e.g. Pods) are running. +TRAINJOB_RUNNING = "Running" + +# The complete status of the TrainJob, defined when TrainJob CR has complete condition. +TRAINJOB_COMPLETE = "Complete" + +# The failed status of the TrainJob, defined when TrainJob CR has failed condition. +TRAINJOB_FAILED = "Failed" + # The label key to identify the relationship between TrainJob and Pod template in the runtime. # For example, what PodTemplate must be overridden by TrainJob's .spec.trainer APIs. TRAINJOB_ANCESTOR_LABEL = "trainer.kubeflow.org/trainjob-ancestor-step" diff --git a/python/kubeflow/trainer/types/types.py b/python/kubeflow/trainer/types/types.py index 0d9c5ed9..2220b97c 100644 --- a/python/kubeflow/trainer/types/types.py +++ b/python/kubeflow/trainer/types/types.py @@ -163,9 +163,10 @@ class Framework(Enum): # Representation for the Trainer of the runtime. @dataclass -class Trainer: +class RuntimeTrainer: trainer_type: TrainerType framework: Framework + num_nodes: int = 1 # The default value is set in the APIs. entrypoint: Optional[List[str]] = None accelerator_count: Union[str, float, int] = constants.UNKNOWN @@ -174,7 +175,7 @@ class Trainer: @dataclass class Runtime: name: str - trainer: Optional[Trainer] = None + trainer: RuntimeTrainer pretrained_model: Optional[str] = None @@ -196,6 +197,7 @@ class TrainJob: creation_timestamp: datetime runtime: Runtime steps: List[Step] + num_nodes: int status: Optional[str] = constants.UNKNOWN @@ -232,14 +234,14 @@ class Initializer: # The dict where key is the container image and value its representation. # Each Trainer representation defines trainer parameters (e.g. type, framework, entrypoint). # TODO (andreyvelich): We should allow user to overrides the default image names. -ALL_TRAINERS: Dict[str, Trainer] = { +ALL_TRAINERS: Dict[str, RuntimeTrainer] = { # Custom Trainers. - "pytorch/pytorch": Trainer( + "pytorch/pytorch": RuntimeTrainer( trainer_type=TrainerType.CUSTOM_TRAINER, framework=Framework.TORCH, entrypoint=[constants.TORCH_ENTRYPOINT], ), - "ghcr.io/kubeflow/trainer/mlx-runtime": Trainer( + "ghcr.io/kubeflow/trainer/mlx-runtime": RuntimeTrainer( trainer_type=TrainerType.CUSTOM_TRAINER, framework=Framework.MLX, entrypoint=[ @@ -250,7 +252,7 @@ class Initializer: "-c", ], ), - "ghcr.io/kubeflow/trainer/deepspeed-runtime": Trainer( + "ghcr.io/kubeflow/trainer/deepspeed-runtime": RuntimeTrainer( trainer_type=TrainerType.CUSTOM_TRAINER, framework=Framework.DEEPSPEED, entrypoint=[ @@ -262,7 +264,7 @@ class Initializer: ], ), # Builtin Trainers. - "ghcr.io/kubeflow/trainer/torchtune-trainer": Trainer( + "ghcr.io/kubeflow/trainer/torchtune-trainer": RuntimeTrainer( trainer_type=TrainerType.BUILTIN_TRAINER, framework=Framework.TORCHTUNE, entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND, @@ -270,7 +272,7 @@ class Initializer: } # The default trainer configuration when runtime detection fails -DEFAULT_TRAINER = Trainer( +DEFAULT_TRAINER = RuntimeTrainer( trainer_type=TrainerType.CUSTOM_TRAINER, framework=Framework.TORCH, entrypoint=[constants.TORCH_ENTRYPOINT], diff --git a/python/kubeflow/trainer/utils/utils.py b/python/kubeflow/trainer/utils/utils.py index 85783732..8837c5fc 100644 --- a/python/kubeflow/trainer/utils/utils.py +++ b/python/kubeflow/trainer/utils/utils.py @@ -109,8 +109,7 @@ def get_runtime_trainer_container( def get_runtime_trainer( replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob], ml_policy: models.TrainerV1alpha1MLPolicy, - runtime_metadata: models.IoK8sApimachineryPkgApisMetaV1ObjectMeta, -) -> types.Trainer: +) -> types.RuntimeTrainer: """ Get the runtime trainer object. """ @@ -140,6 +139,10 @@ def get_runtime_trainer( if isinstance(trainer.accelerator_count, (int, float)) and ml_policy.num_nodes: trainer.accelerator_count *= ml_policy.num_nodes + # Add number of training nodes. + if ml_policy.num_nodes: + trainer.num_nodes = ml_policy.num_nodes + return trainer From 76fd9a2c272f163087e8d58cf974162b5d27a420 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Tue, 29 Jul 2025 02:45:39 +0100 Subject: [PATCH 2/9] Fix unit tests Signed-off-by: Andrey Velichkevich --- .../trainer/api/trainer_client_test.py | 126 ++++++++---------- 1 file changed, 53 insertions(+), 73 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index 9c3dbb39..b8e19b45 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -26,6 +26,7 @@ from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional, Type from unittest.mock import Mock, patch +from pydantic import BaseModel import pytest from kubeflow.trainer import TrainerClient @@ -216,13 +217,12 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements: ) -def add_custom_trainer_to_job( - train_job: models.TrainerV1alpha1TrainJob, -) -> models.TrainerV1alpha1TrainJob: +def get_custom_trainer() -> models.TrainerV1alpha1Trainer: """ - Add a custom trainer configuration to the train job. + Get the custom trainer for the TrainJob. """ - trainer_crd = models.TrainerV1alpha1Trainer( + + return models.TrainerV1alpha1Trainer( command=["bash", "-c"], args=[ '\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip ' @@ -234,33 +234,24 @@ def add_custom_trainer_to_job( "{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" " '"$SCRIPT" > "trainer_client_test.py"\ntorchrun "trainer_client_test.py"' ], - num_nodes=2, + numNodes=2, ) - train_job.spec.trainer = trainer_crd - return train_job - -def add_built_in_trainer_to_job( - train_job: models.TrainerV1alpha1TrainJob, -) -> models.TrainerV1alpha1TrainJob: +def get_builtin_trainer() -> models.TrainerV1alpha1Trainer: """ - Add a built-in trainer configuration to the train job. + Get the builtin trainer for the TrainJob. """ - trainer_crd = models.TrainerV1alpha1Trainer( + return models.TrainerV1alpha1Trainer( args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"], command=["tune", "run"], numNodes=2, ) - train_job.spec.trainer = trainer_crd - return train_job def get_train_job( train_job_name: str = BASIC_TRAIN_JOB_NAME, - runtime_name: str = TORCH_DISTRIBUTED, - add_built_in_trainer: bool = False, - add_custom_trainer: bool = False, + train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, ) -> models.TrainerV1alpha1TrainJob: """ Create a mock TrainJob object with optional trainer configurations. @@ -270,44 +261,42 @@ def get_train_job( kind=constants.TRAINJOB_KIND, metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), spec=models.TrainerV1alpha1TrainJobSpec( - runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), + runtimeRef=models.TrainerV1alpha1RuntimeRef(name=TORCH_DISTRIBUTED), + trainer=train_job_trainer, ), ) - if add_built_in_trainer: - train_job = add_built_in_trainer_to_job(train_job) - if add_custom_trainer: - train_job = add_custom_trainer_to_job(train_job) - return train_job def get_cluster_custom_object_response(*args, **kwargs): """Return a mocked ClusterTrainingRuntime object.""" + mock_thread = Mock() if args[3] == TIMEOUT: raise multiprocessing.TimeoutError() if args[3] == RUNTIME: raise RuntimeError() if args[2] == constants.CLUSTER_TRAINING_RUNTIME_PLURAL: - result = create_cluster_training_runtime() + mock_thread.get.return_value = normalize_model( + create_cluster_training_runtime(), + models.TrainerV1alpha1ClusterTrainingRuntime, + ) - result = normalize_model(result, models.TrainerV1alpha1ClusterTrainingRuntime) - mock_thread = Mock() - mock_thread.get.return_value = result return mock_thread def get_namespaced_custom_object_response(*args, **kwargs): """Return a mocked TrainJob object.""" + mock_thread = Mock() if args[2] == TIMEOUT or args[4] == TIMEOUT: raise multiprocessing.TimeoutError() if args[2] == RUNTIME or args[4] == RUNTIME: raise RuntimeError() if args[3] == TRAIN_JOBS: # TODO: review this. - job = add_status(create_train_job(train_job_name=args[4])) + mock_thread.get.return_value = add_status( + create_train_job(train_job_name=args[4]) + ) - mock_thread = Mock() - mock_thread.get.return_value = job return mock_thread @@ -336,37 +325,26 @@ def add_status( def list_namespaced_custom_object_response(*args, **kwargs): """Return a list of mocked TrainJob objects.""" mock_thread = Mock() - if args[2] == TIMEOUT: raise multiprocessing.TimeoutError() if args[2] == RUNTIME: raise RuntimeError() - if args[2] == BASIC_TRAIN_JOB_NAME: - return_value = { - "items": [create_train_job(train_job_name=BASIC_TRAIN_JOB_NAME)] - } - if args[3] == TRAIN_JOBS: - return_value = build_train_job_list() - - mock_thread.get.return_value = return_value - return mock_thread - - -def build_train_job_list() -> models.TrainerV1alpha1TrainJobList: - """Build a mock TrainJobList object with multiple TrainJob items.""" - train_job_list = models.TrainerV1alpha1TrainJobList( - apiVersion=constants.API_VERSION, - kind=constants.TRAINJOB_PLURAL, - items=[ + if args[3] == constants.TRAINJOB_PLURAL: + items = [ add_status(create_train_job(train_job_name="basic-job-1")), add_status(create_train_job(train_job_name="basic-job-2")), - ], - ) - return normalize_model(train_job_list, models.TrainerV1alpha1TrainJobList) + ] + mock_thread.get.return_value = normalize_model( + models.TrainerV1alpha1TrainJobList(items=items), + models.TrainerV1alpha1TrainJobList, + ) + + return mock_thread def list_cluster_custom_object(*args, **kwargs): """Return a generic mocked response for cluster object listing.""" + mock_thread = Mock() if args[2] == TIMEOUT: raise multiprocessing.TimeoutError() if args[2] == RUNTIME: @@ -376,14 +354,11 @@ def list_cluster_custom_object(*args, **kwargs): create_cluster_training_runtime(name="runtime-1"), create_cluster_training_runtime(name="runtime-2"), ] + mock_thread.get.return_value = normalize_model( + models.TrainerV1alpha1ClusterTrainingRuntimeList(items=items), + models.TrainerV1alpha1ClusterTrainingRuntimeList, + ) - runtime_list_obj = models.TrainerV1alpha1ClusterTrainingRuntimeList(items=items) - runtimes = normalize_model( - runtime_list_obj, models.TrainerV1alpha1ClusterTrainingRuntimeList - ) - - mock_thread = Mock() - mock_thread.get.return_value = runtimes return mock_thread @@ -394,7 +369,7 @@ def mock_read_namespaced_pod_log(*args, **kwargs): return "test log content" -def normalize_model(model_obj, model_class): +def normalize_model(model_obj, model_class) -> BaseModel: # Simulate real api behavior # Converts model to raw dictionary, like a real API response # Parses dict and ensures correct model instantiation and type validation @@ -412,8 +387,8 @@ def create_train_job( runtime: str = PYTORCH, image: str = "pytorch/pytorch:latest", initializer: Optional[types.Initializer] = None, - command: list = None, - args: list = None, + command: Optional[list] = None, + args: Optional[list] = None, ) -> models.TrainerV1alpha1TrainJob: """Create a mock TrainJob object.""" return models.TrainerV1alpha1TrainJob( @@ -454,9 +429,7 @@ def create_cluster_training_runtime( spec=models.TrainerV1alpha1TrainingRuntimeSpec( mlPolicy=models.TrainerV1alpha1MLPolicy( torch=models.TrainerV1alpha1TorchMLPolicySource( - num_proc_per_node=models.IoK8sApimachineryPkgUtilIntstrIntOrString( - 2 - ) + numProcPerNode=models.IoK8sApimachineryPkgUtilIntstrIntOrString(2) ), numNodes=2, ), @@ -466,7 +439,7 @@ def create_cluster_training_runtime( namespace=namespace, ), spec=models.JobsetV1alpha2JobSetSpec( - replicated_jobs=[get_replicated_job()] + replicatedJobs=[get_replicated_job()] ), ), ), @@ -474,7 +447,7 @@ def create_cluster_training_runtime( return runtime -def get_replicated_job() -> models.TrainerV1alpha1ClusterTrainingRuntime: +def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob: return models.JobsetV1alpha2ReplicatedJob( name="node", replicas=1, @@ -507,11 +480,12 @@ def create_runtime_type( return types.Runtime( name=name, pretrained_model=None, - trainer=types.Trainer( + trainer=types.RuntimeTrainer( trainer_type=types.TrainerType.CUSTOM_TRAINER, framework=types.Framework.TORCH, entrypoint=[constants.TORCH_ENTRYPOINT], accelerator_count=4, + num_nodes=2, ), ) @@ -533,11 +507,12 @@ def get_train_job_data_type( runtime=types.Runtime( name=TORCH_DISTRIBUTED, pretrained_model=None, - trainer=types.Trainer( + trainer=types.RuntimeTrainer( trainer_type=types.TrainerType.CUSTOM_TRAINER, framework=types.Framework.TORCH, entrypoint=["torchrun"], accelerator_count=4, + num_nodes=2, ), ), steps=[ @@ -563,7 +538,8 @@ def get_train_job_data_type( device_count="1", ), ], - status="Succeeded", + num_nodes=2, + status="Complete", ) @@ -669,7 +645,7 @@ def test_list_runtimes(training_client, test_case): }, expected_output=get_train_job( train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER, - add_built_in_trainer=True, + train_job_trainer=get_builtin_trainer(), ), ), TestCase( @@ -685,7 +661,8 @@ def test_list_runtimes(training_client, test_case): ) }, expected_output=get_train_job( - train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, add_custom_trainer=True + train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, + train_job_trainer=get_custom_trainer(), ), ), TestCase( @@ -811,6 +788,9 @@ def test_list_jobs(training_client, test_case): assert test_case.expected_status == SUCCESS assert isinstance(jobs, list) assert len(jobs) == 2 + assert [asdict(j) for j in jobs] == [ + asdict(r) for r in test_case.expected_output + ] except Exception as e: assert type(e) is test_case.expected_error From dc1b011f81ddc6da10992016e5becb705788a5d2 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Tue, 29 Jul 2025 03:36:06 +0100 Subject: [PATCH 3/9] Add tests for wait_for_job_status Signed-off-by: Andrey Velichkevich --- .../trainer/api/trainer_client_test.py | 142 ++++++++++++++---- 1 file changed, 115 insertions(+), 27 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index b8e19b45..da46bdfd 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -26,7 +26,6 @@ from dataclasses import asdict, dataclass, field from typing import Any, Dict, Optional, Type from unittest.mock import Mock, patch -from pydantic import BaseModel import pytest from kubeflow.trainer import TrainerClient @@ -369,7 +368,7 @@ def mock_read_namespaced_pod_log(*args, **kwargs): return "test log content" -def normalize_model(model_obj, model_class) -> BaseModel: +def normalize_model(model_obj, model_class): # Simulate real api behavior # Converts model to raw dictionary, like a real API response # Parses dict and ensures correct model instantiation and type validation @@ -804,40 +803,117 @@ def test_list_jobs(training_client, test_case): name="valid flow with all defaults", expected_status=SUCCESS, config={"name": BASIC_TRAIN_JOB_NAME}, - expected_output=None, + expected_output={ + "node-0": "test log content", + }, ), TestCase( - name="timeout error when deleting job", + name="runtime error when getting logs", expected_status=FAILED, - config={"namespace": TIMEOUT}, + config={"name": RUNTIME}, + expected_error=RuntimeError, + ), + ], +) +def test_get_job_logs(training_client, test_case): + """Test TrainerClient.get_job_logs with basic success path.""" + print("Executing test:", test_case.name) + try: + logs = training_client.get_job_logs(test_case.config.get("name")) + assert test_case.expected_status == SUCCESS + assert logs == test_case.expected_output + + except Exception as e: + assert type(e) is test_case.expected_error + print("test execution complete") + + +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + name="wait for complete status (default)", + expected_status=SUCCESS, + config={"name": BASIC_TRAIN_JOB_NAME}, + expected_output=get_train_job_data_type( + train_job_name=BASIC_TRAIN_JOB_NAME + ), + ), + TestCase( + name="wait for multiple statuses", + expected_status=SUCCESS, + config={ + "name": BASIC_TRAIN_JOB_NAME, + "status": {constants.TRAINJOB_RUNNING, constants.TRAINJOB_COMPLETE}, + }, + expected_output=get_train_job_data_type( + train_job_name=BASIC_TRAIN_JOB_NAME + ), + ), + TestCase( + name="timeout error when waiting for job", + expected_status=FAILED, + config={ + "name": TIMEOUT, + "timeout": 1, + "polling_interval": 0.5, + }, expected_error=TimeoutError, ), TestCase( - name="runtime error when deleting job", + name="runtime error when waiting for job", expected_status=FAILED, - config={"namespace": RUNTIME}, + config={"name": RUNTIME}, + expected_error=RuntimeError, + ), + TestCase( + name="invalid status set error", + expected_status=FAILED, + config={ + "name": BASIC_TRAIN_JOB_NAME, + "status": {"InvalidStatus"}, + }, + expected_error=ValueError, + ), + TestCase( + name="job failed when not expected", + expected_status=FAILED, + config={ + "name": "failed-job", + "status": {constants.TRAINJOB_RUNNING}, + }, expected_error=RuntimeError, ), ], ) -def test_delete_job(training_client, test_case): - """Test TrainerClient.delete_job with basic success path.""" +def test_wait_for_job_status(training_client, test_case): + """Test TrainerClient.wait_for_job_status with various scenarios.""" print("Executing test:", test_case.name) + + original_get_job = training_client.get_job + + # TrainJob has unexpected failed status. + def mock_get_job(name): + job = original_get_job(name) + if test_case.config.get("name") == "failed-job": + job.status = constants.TRAINJOB_FAILED + return job + + training_client.get_job = mock_get_job + try: - training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - training_client.delete_job(test_case.config.get("name")) - assert test_case.expected_status == SUCCESS + job = training_client.wait_for_job_status(**test_case.config) - training_client.custom_api.delete_namespaced_custom_object.assert_called_with( - constants.GROUP, - constants.VERSION, - test_case.config.get("namespace", DEFAULT_NAMESPACE), - constants.TRAINJOB_PLURAL, - name=test_case.config.get("name"), + assert test_case.expected_status == SUCCESS + assert isinstance(job, types.TrainJob) + # Job status should be in the expected set. + assert job.status in test_case.config.get( + "status", {constants.TRAINJOB_COMPLETE} ) except Exception as e: assert type(e) is test_case.expected_error + print("test execution complete") @@ -848,25 +924,37 @@ def test_delete_job(training_client, test_case): name="valid flow with all defaults", expected_status=SUCCESS, config={"name": BASIC_TRAIN_JOB_NAME}, - expected_output={ - "node-0": "test log content", - }, + expected_output=None, ), TestCase( - name="runtime error when getting logs", + name="timeout error when deleting job", expected_status=FAILED, - config={"name": RUNTIME}, + config={"namespace": TIMEOUT}, + expected_error=TimeoutError, + ), + TestCase( + name="runtime error when deleting job", + expected_status=FAILED, + config={"namespace": RUNTIME}, expected_error=RuntimeError, ), ], ) -def test_get_job_logs(training_client, test_case): - """Test TrainerClient.get_job_logs with basic success path.""" +def test_delete_job(training_client, test_case): + """Test TrainerClient.delete_job with basic success path.""" print("Executing test:", test_case.name) try: - logs = training_client.get_job_logs(test_case.config.get("name")) + training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + training_client.delete_job(test_case.config.get("name")) assert test_case.expected_status == SUCCESS - assert logs == test_case.expected_output + + training_client.custom_api.delete_namespaced_custom_object.assert_called_with( + constants.GROUP, + constants.VERSION, + test_case.config.get("namespace", DEFAULT_NAMESPACE), + constants.TRAINJOB_PLURAL, + name=test_case.config.get("name"), + ) except Exception as e: assert type(e) is test_case.expected_error From 476b1b2551f42c8b8589f1ace67ac4b17572ccbf Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Tue, 29 Jul 2025 03:41:09 +0100 Subject: [PATCH 4/9] Update test case for timeout Signed-off-by: Andrey Velichkevich --- .../trainer/api/trainer_client_test.py | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index da46bdfd..e05c209a 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -850,22 +850,6 @@ def test_get_job_logs(training_client, test_case): train_job_name=BASIC_TRAIN_JOB_NAME ), ), - TestCase( - name="timeout error when waiting for job", - expected_status=FAILED, - config={ - "name": TIMEOUT, - "timeout": 1, - "polling_interval": 0.5, - }, - expected_error=TimeoutError, - ), - TestCase( - name="runtime error when waiting for job", - expected_status=FAILED, - config={"name": RUNTIME}, - expected_error=RuntimeError, - ), TestCase( name="invalid status set error", expected_status=FAILED, @@ -884,6 +868,17 @@ def test_get_job_logs(training_client, test_case): }, expected_error=RuntimeError, ), + TestCase( + name="timeout error to wait for failed status", + expected_status=FAILED, + config={ + "name": BASIC_TRAIN_JOB_NAME, + "status": {constants.TRAINJOB_FAILED}, + "timeout": 1, + "polling_interval": 0.5, + }, + expected_error=TimeoutError, + ), ], ) def test_wait_for_job_status(training_client, test_case): From ddd265551174ac2c7ea789a48f06d48ed80f117c Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Thu, 31 Jul 2025 00:08:38 +0100 Subject: [PATCH 5/9] Update python/kubeflow/trainer/api/trainer_client.py Co-authored-by: Anya Kramar Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/api/trainer_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 0ded85a5..0ed59574 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -437,9 +437,9 @@ def get_job_logs( def wait_for_job_status( self, name: str, - status: Set = {constants.TRAINJOB_COMPLETE}, + status: Set[str] = {constants.TRAINJOB_COMPLETE}, timeout: int = 600, - polling_interval=5, + polling_interval: int = 5, ): """Wait for TrainJob to reach the desired status From 675f101807e6e3856aea89e7817628871cbbf48c Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Thu, 31 Jul 2025 00:28:47 +0100 Subject: [PATCH 6/9] Add Created status after TrainJob is created Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/api/trainer_client.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 0ed59574..226898a4 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -575,6 +575,7 @@ def __get_trainjob_from_crd( if trainjob_crd.spec.trainer and trainjob_crd.spec.trainer.num_nodes else runtime.trainer.num_nodes ), + status=constants.TRAINJOB_CREATED, # The default TrainJob status. ) # Select Pods created by the appropriate JobSet. It checks the following ReplicatedJob.name: @@ -649,10 +650,7 @@ def __get_trainjob_from_crd( f"Failed to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}" ) - # Add the TrainJob status. - # The TrainJob exists at that stage so its status can safely default to Created. - trainjob.status = constants.TRAINJOB_CREATED - # Otherwise, we read the TrainJob status from its conditions. + # Update the TrainJob status from its conditions. if trainjob_crd.status and trainjob_crd.status.conditions: for c in trainjob_crd.status.conditions: if c.type == constants.TRAINJOB_COMPLETE and c.status == "True": From 65b8a37305cdf8d7d0109504dea5b574dfd2ed66 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Thu, 31 Jul 2025 00:36:26 +0100 Subject: [PATCH 7/9] Wait for Created status Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/api/trainer_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 226898a4..b9dcd554 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -440,13 +440,13 @@ def wait_for_job_status( status: Set[str] = {constants.TRAINJOB_COMPLETE}, timeout: int = 600, polling_interval: int = 5, - ): + ) -> types.TrainJob: """Wait for TrainJob to reach the desired status Args: name: Name of the TrainJob. - status: Set of expected statuses. It must be subset of Running, Complete, and Failed - statuses. + status: Set of expected statuses. It must be subset of Created, Running, Complete, and + Failed statuses. timeout: How many seconds to wait until TrainJob reaches one of the expected conditions. polling_interval: The polling interval in seconds to check TrainJob status. @@ -454,13 +454,13 @@ def wait_for_job_status( TrainJob: The training job that reaches desired status. Raises: - ValueError: The input values are incorrect. RuntimeError: Failed to get TrainJob or TrainJob reaches unexpected Failed status. TimeoutError: Timeout to wait for TrainJob status. """ job_statuses = { + constants.TRAINJOB_CREATED, constants.TRAINJOB_RUNNING, constants.TRAINJOB_COMPLETE, constants.TRAINJOB_FAILED, From a4f0bcd0001b5a0e74f55ab8b7b6166fdf7234e2 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Thu, 31 Jul 2025 14:24:52 +0100 Subject: [PATCH 8/9] Use watch API Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/api/trainer_client.py | 58 ++++++++++--------- .../trainer/api/trainer_client_test.py | 32 +++++++++- .../kubeflow/trainer/constants/constants.py | 11 ++++ 3 files changed, 72 insertions(+), 29 deletions(-) diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index b9dcd554..9368b62f 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -439,7 +439,6 @@ def wait_for_job_status( name: str, status: Set[str] = {constants.TRAINJOB_COMPLETE}, timeout: int = 600, - polling_interval: int = 5, ) -> types.TrainJob: """Wait for TrainJob to reach the desired status @@ -448,10 +447,9 @@ def wait_for_job_status( status: Set of expected statuses. It must be subset of Created, Running, Complete, and Failed statuses. timeout: How many seconds to wait until TrainJob reaches one of the expected conditions. - polling_interval: The polling interval in seconds to check TrainJob status. Returns: - TrainJob: The training job that reaches desired status. + TrainJob: The training job that reaches the desired status. Raises: ValueError: The input values are incorrect. @@ -469,24 +467,40 @@ def wait_for_job_status( raise ValueError( f"Expected status {status} must be a subset of {job_statuses}" ) - for _ in range(round(timeout / polling_interval)): - trainjob = self.get_job(name) - # Raise an error if TrainJob is Failed and it is not the expected status. - if ( - constants.TRAINJOB_FAILED not in status - and trainjob.status == constants.TRAINJOB_FAILED + # Use Kubernetes watch API to monitor the TrainJob's Pods. + w = watch.Watch() + try: + for event in w.stream( + self.core_api.list_namespaced_pod, + self.namespace, + label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name), + timeout_seconds=timeout, ): - raise RuntimeError(f"TrainJob {name} is Failed") + # Check the status after event is generated for the TrainJob's Pods. + trainjob = self.get_job(name) + logger.debug(f"TrainJob {name}, status {trainjob.status}") - # Return the TrainJob if it reaches the expected status. - if trainjob.status in status: - return trainjob + # Raise an error if TrainJob is Failed and it is not the expected status. + if ( + constants.TRAINJOB_FAILED not in status + and trainjob.status == constants.TRAINJOB_FAILED + ): + raise RuntimeError(f"TrainJob {name} is Failed") - time.sleep(polling_interval) + # Return the TrainJob if it reaches the expected status. + if trainjob.status in status: + return trainjob + + except TimeoutError: + raise TimeoutError(f"Timeout to get the TrainJob {name}") + except Exception: + raise RuntimeError(f"Failed to watch Pods for TrainJob {name}") + finally: + w.stop() raise TimeoutError( - f"Timeout waiting for TrainJob {name} to reach {status} Status" + f"Timeout waiting for TrainJob {name} to reach status: {status} status" ) def delete_job(self, name: str): @@ -578,23 +592,11 @@ def __get_trainjob_from_crd( status=constants.TRAINJOB_CREATED, # The default TrainJob status. ) - # Select Pods created by the appropriate JobSet. It checks the following ReplicatedJob.name: - # dataset-initializer, model-initializer, launcher, node. - label_selector = "{}={},{} in ({}, {}, {}, {})".format( - constants.JOBSET_NAME_LABEL, - name, - constants.JOBSET_RJOB_NAME_LABEL, - constants.DATASET_INITIALIZER, - constants.MODEL_INITIALIZER, - constants.LAUNCHER, - constants.NODE, - ) - # Add the TrainJob components, e.g. trainer nodes and initializer. try: response = self.core_api.list_namespaced_pod( namespace, - label_selector=label_selector, + label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name), async_req=True, ).get(constants.DEFAULT_TIMEOUT) diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index e05c209a..410dbfaf 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -102,6 +102,11 @@ def training_client(request): list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response), read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log), ), + ), patch( + "kubernetes.watch.Watch", + return_value=Mock( + stream=Mock(side_effect=mock_watch), + ), ): yield TrainerClient() @@ -368,6 +373,32 @@ def mock_read_namespaced_pod_log(*args, **kwargs): return "test log content" +def mock_watch(*args, **kwargs): + """Simulate watch event""" + if kwargs.get("timeout_seconds") == 1: + raise TimeoutError("Watch timeout") + + events = [ + { + "type": "MODIFIED", + "object": { + "metadata": { + "name": f"{BASIC_TRAIN_JOB_NAME}-node-0", + "labels": { + constants.JOBSET_NAME_LABEL: BASIC_TRAIN_JOB_NAME, + constants.JOBSET_RJOB_NAME_LABEL: constants.NODE, + constants.JOB_INDEX_LABEL: "0", + }, + }, + "spec": {"containers": [{"name": constants.NODE}]}, + "status": {"phase": "Running"}, + }, + } + ] + + return iter(events) + + def normalize_model(model_obj, model_class): # Simulate real api behavior # Converts model to raw dictionary, like a real API response @@ -875,7 +906,6 @@ def test_get_job_logs(training_client, test_case): "name": BASIC_TRAIN_JOB_NAME, "status": {constants.TRAINJOB_FAILED}, "timeout": 1, - "polling_interval": 0.5, }, expected_error=TimeoutError, ), diff --git a/python/kubeflow/trainer/constants/constants.py b/python/kubeflow/trainer/constants/constants.py index 293ce55a..f4900ab1 100644 --- a/python/kubeflow/trainer/constants/constants.py +++ b/python/kubeflow/trainer/constants/constants.py @@ -106,6 +106,17 @@ # but one or more of the containers has not been made ready to run. POD_PENDING = "Pending" +# The label selector for Pods created by the TrainJob. +# It checks the following rJob.name: dataset-initializer, model-initializer, launcher, node. +POD_LABEL_SELECTOR = ("{}={{trainjob_name}},{} in ({}, {}, {}, {})").format( + JOBSET_NAME_LABEL, + JOBSET_RJOB_NAME_LABEL, + DATASET_INITIALIZER, + MODEL_INITIALIZER, + LAUNCHER, + NODE, +) + # The default PIP index URL to download Python packages. DEFAULT_PIP_INDEX_URL = os.getenv("DEFAULT_PIP_INDEX_URL", "https://pypi.org/simple") From 81e43fdd655c39ba6b16fc6e6a385ffbeb2753ac Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Thu, 31 Jul 2025 14:27:14 +0100 Subject: [PATCH 9/9] Remove time import Signed-off-by: Andrey Velichkevich --- python/kubeflow/trainer/api/trainer_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 9368b62f..6db05833 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -18,7 +18,6 @@ import random import string import uuid -import time from typing import Dict, List, Optional, Union, Set from kubeflow.trainer.constants import constants