Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/kubeflow/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
# Import the Kubeflow Trainer client.
from kubeflow.trainer.api.trainer_client import TrainerClient

# Import the Kubeflow Local Trainer client.
from kubeflow.trainer.api.local_trainer_client import LocalTrainerClient

# Import the Kubeflow Trainer constants.
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH

Expand Down
62 changes: 62 additions & 0 deletions python/kubeflow/trainer/api/abstract_trainer_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types


class AbstractTrainerClient(ABC):
    """Common interface implemented by every Kubeflow Trainer client.

    Concrete clients must implement every method below; the class cannot be
    instantiated directly because all of its methods are abstract.
    """

    # ----- Runtimes -----

    @abstractmethod
    def list_runtimes(self) -> List[types.Runtime]:
        """List all available runtimes.

        Returns:
            A list of runtime objects.
        """

    @abstractmethod
    def get_runtime(self, name: str) -> types.Runtime:
        """Get a specific runtime by name.

        Args:
            name: The name of the runtime.

        Returns:
            A runtime object.
        """

    # ----- Training -----

    @abstractmethod
    def train(
        self,
        runtime: types.Runtime = types.DEFAULT_RUNTIME,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[types.CustomTrainer] = None,
    ) -> str:
        """Start a training job.

        Args:
            runtime: Config for the train job's runtime.
            initializer: Config for dataset and model initialization.
            trainer: Config for the function that encapsulates the model
                training process.

        Returns:
            The name of the training job.
        """

    # ----- Jobs -----

    @abstractmethod
    def list_jobs(
        self, runtime: Optional[types.Runtime] = None
    ) -> List[types.TrainJob]:
        """List training jobs.

        Args:
            runtime: If provided, only return jobs that use the given runtime.

        Returns:
            A list of training jobs.
        """

    @abstractmethod
    def get_job(self, name: str) -> types.TrainJob:
        """Get a specific training job by name.

        Args:
            name: The name of the training job to get.

        Returns:
            A training job.
        """

    @abstractmethod
    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE,
        node_rank: int = 0,
    ) -> Dict[str, str]:
        """Get the logs of a training job.

        Args:
            name: The name of the training job.
            follow: Whether to follow the job logs (default False).
            step: The training job step to target (default ``constants.NODE``).
            node_rank: The node rank to retrieve logs from (default 0).

        Returns:
            The logs of the training job as a string-to-string mapping.
        """

    @abstractmethod
    def delete_job(self, name: str):
        """Delete a specific training job.

        Args:
            name: The name of the training job to delete.
        """
244 changes: 244 additions & 0 deletions python/kubeflow/trainer/api/local_trainer_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from importlib import resources
from pathlib import Path
from typing import Dict, List, Optional

import yaml
from kubeflow.trainer import models
from kubeflow.trainer.api.abstract_trainer_client import AbstractTrainerClient
from kubeflow.trainer.constants import constants
from kubeflow.trainer.job_runners import DockerJobRunner, JobRunner
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils


class LocalTrainerClient(AbstractTrainerClient):
    """LocalTrainerClient exposes functionality for running training jobs locally.

    A Kubernetes cluster is not required.
    It exposes the same interface as the TrainerClient.

    Args:
        local_runtimes_path: The path to the directory containing runtime YAML files.
            Defaults to the runtimes included with the package.
        job_runner: The job runner to use for local training.
            Options include the DockerJobRunner and PodmanJobRunner.
            Defaults to the Docker job runner.
    """

    def __init__(
        self,
        local_runtimes_path: Optional[Path] = None,
        job_runner: Optional[JobRunner] = None,
    ):
        print(
            "Warning: LocalTrainerClient is an alpha feature for Kubeflow Trainer. "
            "Some features may be unstable or unimplemented."
        )

        if local_runtimes_path is None:
            # The packaged runtimes are located through importlib.resources, so
            # this value is a Traversable and not necessarily a pathlib.Path.
            self.local_runtimes_path = (
                resources.files(constants.PACKAGE_NAME) / constants.LOCAL_RUNTIMES_PATH
            )
        elif isinstance(local_runtimes_path, str):
            # Be lenient with callers that pass a plain string path; it would
            # otherwise fail later on `.iterdir()`.
            self.local_runtimes_path = Path(local_runtimes_path)
        else:
            self.local_runtimes_path = local_runtimes_path

        self.job_runner = DockerJobRunner() if job_runner is None else job_runner

    def list_runtimes(self) -> List[types.Runtime]:
        """Lists all runtimes.

        Returns:
            A list of runtime objects.
        """
        return [utils.get_runtime_from_crd(cr) for cr in self.__list_runtime_crs()]

    def get_runtime(self, name: str) -> types.Runtime:
        """Get a specific runtime by name.

        Args:
            name: The name of the runtime.

        Returns:
            A runtime object.

        Raises:
            RuntimeError: if the specified runtime cannot be found.
        """
        for r in self.list_runtimes():
            if r.name == name:
                return r
        raise RuntimeError(f"No runtime found with name '{name}'")

    def train(
        self,
        runtime: types.Runtime = types.DEFAULT_RUNTIME,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[types.CustomTrainer] = None,
    ) -> str:
        """Starts a training job.

        Args:
            runtime: Config for the train job's runtime.
            initializer: Config for dataset and model initialization.
                NOTE: this argument is accepted for interface compatibility but
                is not used by this method.
            trainer: Config for the function that encapsulates the model training process.

        Returns:
            The generated name of the training job.

        Raises:
            RuntimeError: if the specified runtime cannot be found,
                or the runtime container cannot be found,
                or the runtime container image is not specified.
        """
        runtime_cr = self.__get_runtime_cr(runtime.name)
        if runtime_cr is None:
            raise RuntimeError(f"No runtime found with name '{runtime.name}'")

        runtime_container = utils.get_runtime_trainer_container(
            runtime_cr.spec.template.spec.replicated_jobs
        )
        if runtime_container is None:
            raise RuntimeError("No runtime container found")

        image = runtime_container.image
        if image is None:
            raise RuntimeError("No runtime container image specified")

        if trainer and trainer.func:
            # A user-supplied training function replaces the container's own
            # entrypoint and arguments.
            entrypoint, command = utils.get_entrypoint_using_train_func(
                runtime,
                trainer.func,
                trainer.func_args,
                trainer.pip_index_url,
                trainer.packages_to_install,
            )
        else:
            entrypoint = runtime_container.command
            command = runtime_container.args

        # Fall back to a single node when the trainer does not request a count.
        num_nodes = trainer.num_nodes if trainer and trainer.num_nodes else 1

        return self.job_runner.create_job(
            image=image,
            entrypoint=entrypoint,
            command=command,
            num_nodes=num_nodes,
            framework=runtime.trainer.framework,
            runtime_name=runtime.name,
        )

    def list_jobs(
        self, runtime: Optional[types.Runtime] = None
    ) -> List[types.TrainJob]:
        """Lists all training jobs.

        Args:
            runtime: If provided, only return jobs that use the given runtime.

        Returns:
            A list of training jobs.

        Raises:
            RuntimeError: if the runtime referenced by a job cannot be found.
        """
        runtime_name = runtime.name if runtime else None
        container_jobs = self.job_runner.list_jobs(runtime_name)

        # Resolve each distinct runtime once per call instead of re-reading and
        # re-parsing every runtime file for every job.
        runtime_cache: Dict[str, types.Runtime] = {}
        return [
            self.__container_job_to_train_job(container_job, runtime_cache)
            for container_job in container_jobs
        ]

    def get_job(self, name: str) -> types.TrainJob:
        """Get a specific training job by name.

        Args:
            name: The name of the training job to get.

        Returns:
            A training job.
        """
        container_job = self.job_runner.get_job(name)
        return self.__container_job_to_train_job(container_job)

    def get_job_logs(
        self,
        name: str,
        follow: Optional[bool] = False,
        step: str = constants.NODE,
        node_rank: int = 0,
    ) -> Dict[str, str]:
        """Gets logs for the specified training job.

        Args:
            name (str): The name of the training job
            follow (bool): If true, follows job logs and prints them to standard out (default False)
            step (str): The training job step to target (default "node")
            node_rank (int): The node rank to retrieve logs from (default 0)

        Returns:
            Dict[str, str]: The logs of the training job, where the key is the
                step and node rank, and the value is the logs for that node.
        """
        return self.job_runner.get_job_logs(
            job_name=name, follow=follow, step=step, node_rank=node_rank
        )

    def delete_job(self, name: str):
        """Deletes a specific training job.

        Args:
            name: The name of the training job to delete.
        """
        self.job_runner.delete_job(job_name=name)

    def __list_runtime_crs(self) -> List[models.TrainerV1alpha1ClusterTrainingRuntime]:
        """Load every runtime custom resource from the runtimes directory.

        Only regular files with a ``.yaml`` or ``.yml`` suffix are parsed, so
        sub-directories and unrelated files in the directory are ignored.
        Entries are visited in name order so results are deterministic.

        Returns:
            The parsed runtime custom resources.
        """
        runtime_crs = []
        entries = sorted(self.local_runtimes_path.iterdir(), key=lambda e: e.name)
        for entry in entries:
            if not entry.is_file():
                continue
            if not entry.name.lower().endswith((".yaml", ".yml")):
                continue
            # `read_text` is available on both pathlib.Path and the Traversable
            # objects returned by importlib.resources, unlike builtin `open`.
            cr_dict = yaml.safe_load(entry.read_text(encoding="utf-8"))
            cr = models.TrainerV1alpha1ClusterTrainingRuntime.from_dict(cr_dict)
            if cr is not None:
                runtime_crs.append(cr)
        return runtime_crs

    def __get_runtime_cr(
        self,
        name: str,
    ) -> Optional[models.TrainerV1alpha1ClusterTrainingRuntime]:
        """Return the runtime custom resource named ``name``, or None if absent."""
        for cr in self.__list_runtime_crs():
            # A resource without metadata cannot match any name; skip it
            # rather than failing with AttributeError.
            if cr.metadata is not None and cr.metadata.name == name:
                return cr
        return None

    def __container_job_to_train_job(
        self,
        container_job: types.ContainerJob,
        runtime_cache: Optional[Dict[str, types.Runtime]] = None,
    ) -> types.TrainJob:
        """Convert a job runner's container job into a TrainJob.

        Args:
            container_job: The container job to convert.
            runtime_cache: Optional mapping of runtime name to runtime, shared
                by a caller converting many jobs. Missing entries are resolved
                with ``get_runtime`` and stored in the mapping.

        Returns:
            The equivalent training job.

        Raises:
            RuntimeError: if the job's runtime cannot be found.
        """
        runtime_name = container_job.runtime_name
        if runtime_cache is None:
            runtime = self.get_runtime(runtime_name)
        else:
            if runtime_name not in runtime_cache:
                runtime_cache[runtime_name] = self.get_runtime(runtime_name)
            runtime = runtime_cache[runtime_name]

        return types.TrainJob(
            name=container_job.name,
            creation_timestamp=container_job.creation_timestamp,
            steps=[container.to_step() for container in container_job.containers],
            runtime=runtime,
            status=container_job.status,
        )
36 changes: 5 additions & 31 deletions python/kubeflow/trainer/api/trainer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@
import logging
import multiprocessing
import queue
import random
import string
import uuid
from typing import Dict, List, Optional

import kubeflow.trainer.models as models
from kubeflow.trainer.api.abstract_trainer_client import AbstractTrainerClient
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils
Expand All @@ -29,7 +27,7 @@
logger = logging.getLogger(__name__)


class TrainerClient:
class TrainerClient(AbstractTrainerClient):
def __init__(
self,
config_file: Optional[str] = None,
Expand Down Expand Up @@ -105,7 +103,7 @@ def list_runtimes(self) -> List[types.Runtime]:
return result

for runtime in runtime_list.items:
result.append(self.__get_runtime_from_crd(runtime))
result.append(utils.get_runtime_from_crd(runtime))

except multiprocessing.TimeoutError:
raise TimeoutError(
Expand Down Expand Up @@ -147,7 +145,7 @@ def get_runtime(self, name: str) -> types.Runtime:
f"{self.namespace}/{name}"
)

return self.__get_runtime_from_crd(runtime) # type: ignore
return utils.get_runtime_from_crd(runtime) # type: ignore

def train(
self,
Expand Down Expand Up @@ -179,7 +177,7 @@ def train(

# Generate unique name for the TrainJob.
# TODO (andreyvelich): Discuss this TrainJob name generation.
train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
train_job_name = utils.generate_train_job_name()

# Build the Trainer.
trainer_crd = models.TrainerV1alpha1Trainer()
Expand Down Expand Up @@ -463,30 +461,6 @@ def delete_job(self, name: str):
f"{constants.TRAINJOB_KIND} {self.namespace}/{name} has been deleted"
)

def __get_runtime_from_crd(
    self,
    runtime_crd: models.TrainerV1alpha1ClusterTrainingRuntime,
) -> types.Runtime:
    """Build a Runtime from a ClusterTrainingRuntime custom resource.

    Args:
        runtime_crd: The ClusterTrainingRuntime custom resource to convert.

    Returns:
        A runtime carrying the resource's name and its trainer configuration.

    Raises:
        Exception: if the resource lacks metadata, a name, a spec, an ML
            policy, a template, a template spec, or replicated jobs.
    """
    # Every level of the attribute chain is checked before it is dereferenced,
    # including `spec.template`, so an incomplete resource is reported as
    # invalid instead of failing with AttributeError.
    if not (
        runtime_crd.metadata
        and runtime_crd.metadata.name
        and runtime_crd.spec
        and runtime_crd.spec.ml_policy
        and runtime_crd.spec.template
        and runtime_crd.spec.template.spec
        and runtime_crd.spec.template.spec.replicated_jobs
    ):
        raise Exception(f"ClusterTrainingRuntime CRD is invalid: {runtime_crd}")

    return types.Runtime(
        name=runtime_crd.metadata.name,
        trainer=utils.get_runtime_trainer(
            runtime_crd.spec.template.spec.replicated_jobs,
            runtime_crd.spec.ml_policy,
            runtime_crd.metadata,
        ),
    )

def __get_trainjob_from_crd(
self,
trainjob_crd: models.TrainerV1alpha1TrainJob,
Expand Down
Loading