Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 58 additions & 35 deletions kubeflow/trainer/api/trainer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Set, Union
from typing import Optional, Union

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
Expand All @@ -35,42 +35,45 @@ def __init__(
backend_config: Backend configuration. Either KubernetesBackendConfig or
LocalProcessBackendConfig, or None to use the backend's
default config class. Defaults to KubernetesBackendConfig.

Raises:
ValueError: Invalid backend configuration.

"""
# initialize training backend
if isinstance(backend_config, KubernetesBackendConfig):
self.backend = KubernetesBackend(backend_config)
else:
raise ValueError("Invalid backend config '{}'".format(backend_config))

def list_runtimes(self) -> types.Runtime:
"""List of the available Runtimes.
def list_runtimes(self) -> list[types.Runtime]:
"""List of the available runtimes.

Returns:
List[Runtime]: List of available training runtimes.
If no runtimes exist, an empty list is returned.
A list of available training runtimes. If no runtimes exist, an empty list is returned.

Raises:
TimeoutError: Timeout to list Runtimes.
RuntimeError: Failed to list Runtimes.
TimeoutError: Timeout to list runtimes.
RuntimeError: Failed to list runtimes.
"""
return self.backend.list_runtimes()

def get_runtime(self, name: str) -> types.Runtime:
"""Get the Runtime object
"""Get the runtime object
Args:
name: Name of the runtime.
Returns:
types.TrainingRuntime: Runtime object.

Returns:
A runtime object.
"""
return self.backend.get_runtime(name=name)

def get_runtime_packages(self, runtime: types.Runtime):
"""
Print the installed Python packages for the given Runtime. If Runtime has GPUs it also
"""Print the installed Python packages for the given runtime. If a runtime has GPUs it also
prints available GPUs on the single training node.

Args:
runtime: Reference to one of existing Runtimes.
runtime: Reference to one of existing runtimes.

Raises:
ValueError: Input arguments are invalid.
Expand All @@ -81,33 +84,40 @@ def get_runtime_packages(self, runtime: types.Runtime):

def train(
self,
runtime: types.Runtime = None,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
) -> str:
"""
Create the TrainJob. You can configure these types of training task:
- Custom Training Task: Training with a self-contained function that encapsulates
the entire model training process, e.g. `CustomTrainer`.
- Config-driven Task with Existing Trainer: Training with a trainer that already includes
the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.
"""Create a TrainJob. You can configure the TrainJob using one of these trainers:

- CustomTrainer: Runs training with a user-defined function that fully encapsulates the
training process.
- BuiltinTrainer: Uses a predefined trainer with built-in post-training logic, requiring
only parameter configuration.

Args:
runtime: Reference to one of existing Runtimes.
initializer:
Configuration for the dataset and model initializers.
trainer:
Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.
runtime: Optional reference to one of the existing runtimes. Defaults to the
torch-distributed runtime if not provided.
initializer: Optional configuration for the dataset and model initializers.
trainer: Optional configuration for a CustomTrainer or BuiltinTrainer. If not specified,
the TrainJob will use the runtime's default values.

Returns:
str: The unique name of the TrainJob that has been generated.
The unique name of the TrainJob that has been generated.

Raises:
ValueError: Input arguments are invalid.
TimeoutError: Timeout to create TrainJobs.
RuntimeError: Failed to create TrainJobs.
"""
return self.backend.train(runtime=runtime, initializer=initializer, trainer=trainer)

def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
"""List of all TrainJobs.
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
"""List of the created TrainJobs. If a runtime is specified, only TrainJobs associated with
that runtime are returned.

Args:
runtime: Reference to one of the existing runtimes.

Returns:
List: List of created TrainJobs.
Expand All @@ -120,7 +130,19 @@ def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.Train
return self.backend.list_jobs(runtime=runtime)

def get_job(self, name: str) -> types.TrainJob:
"""Get the TrainJob object"""
"""Get the TrainJob object

Args:
name: Name of the TrainJob.

Returns:
A TrainJob object.

Raises:
TimeoutError: Timeout to get a TrainJob.
RuntimeError: Failed to get a TrainJob.
"""

return self.backend.get_job(name=name)

def get_job_logs(
Expand All @@ -129,28 +151,29 @@ def get_job_logs(
follow: Optional[bool] = False,
step: str = constants.NODE,
node_rank: int = 0,
) -> Dict[str, str]:
) -> dict[str, str]:
"""Get the logs from TrainJob"""
return self.backend.get_job_logs(name=name, follow=follow, step=step, node_rank=node_rank)

def wait_for_job_status(
self,
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
status: set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 2,
) -> types.TrainJob:
"""Wait for TrainJob to reach the desired status
"""Wait for a TrainJob to reach a desired status.

Args:
name: Name of the TrainJob.
status: Set of expected statuses. It must be subset of Created, Running, Complete, and
status: Expected statuses. Must be a subset of Created, Running, Complete, and
Failed statuses.
timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
timeout: Maximum number of seconds to wait for the TrainJob to reach one of the
expected statuses.
polling_interval: The polling interval in seconds to check TrainJob status.

Returns:
TrainJob: The training job that reaches the desired status.
A TrainJob object that reaches the desired status.

Raises:
ValueError: The input values are incorrect.
Expand Down
10 changes: 5 additions & 5 deletions kubeflow/trainer/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import abc

from typing import Dict, List, Optional, Set, Union
from typing import Optional, Union
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types


class ExecutionBackend(abc.ABC):
def list_runtimes(self) -> List[types.Runtime]:
def list_runtimes(self) -> list[types.Runtime]:
raise NotImplementedError()

def get_runtime(self, name: str) -> types.Runtime:
Expand All @@ -37,7 +37,7 @@ def train(
) -> str:
raise NotImplementedError()

def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
raise NotImplementedError()

def get_job(self, name: str) -> types.TrainJob:
Expand All @@ -49,13 +49,13 @@ def get_job_logs(
follow: Optional[bool] = False,
step: str = constants.NODE,
node_rank: int = 0,
) -> Dict[str, str]:
) -> dict[str, str]:
raise NotImplementedError()

def wait_for_job_status(
self,
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
status: set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 2,
) -> types.TrainJob:
Expand Down
Loading