Merged

38 commits
3524abc
Implement TrainerClient Backends & Local Process
szaher Jun 21, 2025
0f4e504
Merge branch 'main' of github.com:kubeflow/sdk into training-backends
szaher Jun 21, 2025
908af68
Implement Job Cancellation
szaher Jun 25, 2025
71e83ae
Merge branch 'main' into training-backends
szaher Jul 3, 2025
3d578c7
update local job to add resouce limitation in k8s style
szaher Jul 9, 2025
bed8f70
Update python/kubeflow/trainer/api/trainer_client.py
szaher Jul 9, 2025
3781c23
Merge with latest changes from main
szaher Aug 11, 2025
28db17f
Fix linting issues
szaher Aug 12, 2025
7977cc4
fix unit tests
szaher Aug 12, 2025
da0ce2f
add support wait_for_job_status
szaher Aug 12, 2025
ca564d6
Update data types
szaher Aug 19, 2025
d9af6f2
fix merge conflict
szaher Aug 19, 2025
2383c52
Merge branch 'main' into training-backends
szaher Aug 19, 2025
46961ba
fix unit tests
szaher Aug 20, 2025
e226167
remove TypeAlias
szaher Aug 20, 2025
2ef70db
Replace TRAINER_BACKEND_REGISTRY with TRAINER_BACKEND
szaher Aug 20, 2025
822a262
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 21, 2025
f00280a
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 21, 2025
e0c714f
Restructure training backends into separate dirs
szaher Aug 22, 2025
1dbc3e9
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 22, 2025
460aae2
Merge branch 'main' into training-backends
szaher Sep 6, 2025
46a5fd7
add get_runtime_packages as not supported by local-exec
szaher Sep 7, 2025
ea3e9cf
move backends and its configs to kubeflow.trainer
szaher Sep 7, 2025
2976f8a
fix typo in delete_job
szaher Sep 8, 2025
e4a57b3
Move local_runtimes to constants
szaher Sep 8, 2025
8d6b1e7
use google style docstring for LocalJob
szaher Sep 8, 2025
c3719b5
remove debug opt from LocalProcessConfig
szaher Sep 9, 2025
64cdcba
only use imports from kubeflow.trainer for backends
szaher Sep 9, 2025
511b22b
upload local-exec to use only one step
szaher Sep 13, 2025
74d60a4
optimize loops when getting runtime
szaher Sep 13, 2025
9d9a14c
add LocalRuntimeTrainer
szaher Sep 14, 2025
60d96d0
rename cleanup config item to cleanup_venv
szaher Sep 14, 2025
8e9190e
convert local runtime to runtime
szaher Sep 14, 2025
ac0be0c
convert runtimes before returning
szaher Sep 14, 2025
4b1db8c
fix get_job_logs to align with parent interface
szaher Sep 14, 2025
4fe7baa
rename get_runtime_trainer func
szaher Sep 14, 2025
9775f3f
rename get_training_job_command to get_local_train_job_script
szaher Sep 14, 2025
42c7769
Ignore failures in Coveralls action
andreyvelich Sep 15, 2025
7 changes: 7 additions & 0 deletions kubeflow/trainer/__init__.py
@@ -38,6 +38,11 @@
TrainerType,
)

# import backends and their associated configs
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig


__all__ = [
"BuiltinTrainer",
"CustomTrainer",
@@ -55,4 +60,6 @@
"RuntimeTrainer",
"TrainerClient",
"TrainerType",
"LocalProcessBackendConfig",
"KubernetesBackendConfig",
]
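
With these additions, both backend configs are re-exported from the package root, so callers no longer need to import from kubeflow.trainer.backends.* directly. A minimal sketch of the new import surface (assuming LocalProcessBackendConfig can be constructed with its defaults):

# Both configs are now part of the public API alongside TrainerClient.
from kubeflow.trainer import (
    KubernetesBackendConfig,
    LocalProcessBackendConfig,
    TrainerClient,
)

# Choosing a backend is just a matter of which config object is passed in;
# omitting the argument keeps the existing Kubernetes behaviour.
client = TrainerClient(backend_config=LocalProcessBackendConfig())
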
8 changes: 7 additions & 1 deletion kubeflow/trainer/api/trainer_client.py
@@ -19,6 +19,8 @@
from kubeflow.trainer.types import types
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackendConfig


logger = logging.getLogger(__name__)
@@ -27,7 +29,9 @@
class TrainerClient:
def __init__(
self,
backend_config: KubernetesBackendConfig = KubernetesBackendConfig(),
backend_config: Union[
KubernetesBackendConfig, LocalProcessBackendConfig
] = KubernetesBackendConfig(),
):
"""Initialize a Kubeflow Trainer client.
@@ -43,6 +47,8 @@ def __init__(
# initialize training backend
if isinstance(backend_config, KubernetesBackendConfig):
self.backend = KubernetesBackend(backend_config)
elif isinstance(backend_config, LocalProcessBackendConfig):
self.backend = LocalProcessBackend(backend_config)
else:
raise ValueError("Invalid backend config '{}'".format(backend_config))

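
From the caller's side, the constructor now dispatches on the config type: KubernetesBackendConfig selects KubernetesBackend, LocalProcessBackendConfig selects the new LocalProcessBackend, and anything else raises ValueError. A hedged sketch of that dispatch; cleanup_venv is taken from its use in the local backend below, and any other LocalProcessBackendConfig fields are not shown in this diff:

from kubeflow.trainer import (
    KubernetesBackendConfig,
    LocalProcessBackendConfig,
    TrainerClient,
)

# Explicitly equivalent to TrainerClient(): the Kubernetes backend stays the default.
k8s_client = TrainerClient(backend_config=KubernetesBackendConfig())

# Routes all client calls to LocalProcessBackend; cleanup_venv controls whether
# the per-job virtualenv is removed once training finishes.
local_client = TrainerClient(backend_config=LocalProcessBackendConfig(cleanup_venv=True))

# Unknown config types are rejected up front.
try:
    TrainerClient(backend_config="not-a-config")
except ValueError as err:
    print(err)
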
Empty file.
257 changes: 257 additions & 0 deletions kubeflow/trainer/backends/localprocess/backend.py
@@ -0,0 +1,257 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import string
import tempfile
import uuid
import random
from datetime import datetime
from typing import List, Optional, Set, Union, Iterator

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.backends.base import ExecutionBackend
from kubeflow.trainer.backends.localprocess.types import (
LocalProcessBackendConfig,
LocalBackendJobs,
LocalBackendStep,
)
from kubeflow.trainer.backends.localprocess.constants import local_runtimes
from kubeflow.trainer.backends.localprocess.job import LocalJob
from kubeflow.trainer.backends.localprocess import utils as local_utils

logger = logging.getLogger(__name__)


class LocalProcessBackend(ExecutionBackend):
def __init__(
self,
cfg: LocalProcessBackendConfig,
):
# list of running subprocesses
self.__local_jobs: List[LocalBackendJobs] = []
self.cfg = cfg

def list_runtimes(self) -> List[types.Runtime]:
return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes]

def get_runtime(self, name: str) -> types.Runtime:
runtime = next(
(
self.__convert_local_runtime_to_runtime(rt)
for rt in local_runtimes
if rt.name == name
),
None,
)
if not runtime:
raise ValueError(f"Runtime '{name}' not found.")

return runtime

def get_runtime_packages(self, runtime: types.Runtime):
local_runtime = next((rt for rt in local_runtimes if rt.name == runtime.name), None)
if not local_runtime:
raise ValueError(f"Runtime '{runtime.name}' not found.")

return local_runtime.trainer.packages

def train(
self,
runtime: Optional[types.Runtime] = None,
initializer: Optional[types.Initializer] = None,
trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
) -> str:
# set train job name
train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]
# localprocess backend only supports CustomTrainer
if not isinstance(trainer, types.CustomTrainer):
raise ValueError("CustomTrainer must be set with LocalProcessBackend")

# create temp dir
venv_dir = tempfile.mkdtemp(prefix=train_job_name)
logger.debug("operating in {}".format(venv_dir))

runtime.trainer = local_utils.get_local_runtime_trainer(
runtime_name=runtime.name,
venv_dir=venv_dir,
framework=runtime.trainer.framework,
)

# build training job command
training_command = local_utils.get_training_job_command(
trainer=trainer,
runtime=runtime,
train_job_name=train_job_name,
venv_dir=venv_dir,
cleanup_venv=self.cfg.cleanup_venv,
)

# set the command in the runtime trainer
runtime.trainer.set_command(training_command)

# create subprocess object
train_job = LocalJob(
name="{}-train".format(train_job_name),
command=training_command,
execution_dir=venv_dir,
env=trainer.env,
dependencies=[],
)

self.__register_job(
train_job_name=train_job_name,
step_name="train",
job=train_job,
runtime=runtime,
)
# start the job.
train_job.start()

return train_job_name

def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
result = []

for _job in self.__local_jobs:
if runtime and _job.runtime.name != runtime.name:
continue
result.append(
types.TrainJob(
name=_job.name,
creation_timestamp=_job.created,
runtime=runtime,
num_nodes=1,
steps=[
types.Step(name=s.step_name, pod_name=s.step_name, status=s.job.status)
for s in _job.steps
],
)
)
return result

def get_job(self, name: str) -> Optional[types.TrainJob]:
_job = next((j for j in self.__local_jobs if j.name == name), None)
if _job is None:
raise ValueError("No TrainJob with name '%s'" % name)

# check and set the correct job status to match `TrainerClient` supported statuses
status = self.__get_job_status(_job)

return types.TrainJob(
name=_job.name,
creation_timestamp=_job.created,
steps=[
types.Step(name=_step.step_name, pod_name=_step.step_name, status=_step.job.status)
for _step in _job.steps
],
runtime=_job.runtime,
num_nodes=1,
status=status,
)

def get_job_logs(
self,
name: str,
step: str = constants.NODE + "-0",
follow: Optional[bool] = False,
) -> Iterator[str]:
_job = [j for j in self.__local_jobs if j.name == name]
if not _job:
raise ValueError("No TrainJob with name '%s'" % name)

want_all_steps = step == constants.NODE + "-0"

for _step in _job[0].steps:
if not want_all_steps and _step.step_name != step:
continue
# Flatten the generator and pass through flags so it behaves as expected
# (adjust args if stream_logs has different signature)
yield from _step.job.logs(follow=follow)

def wait_for_job_status(
self,
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 2,
) -> types.TrainJob:
# find first match or fallback
_job = next((_job for _job in self.__local_jobs if _job.name == name), None)

if _job is None:
raise ValueError("No TrainJob with name '%s'" % name)
# find a better implementation for this
for _step in _job.steps:
if _step.job.status in [constants.TRAINJOB_RUNNING, constants.TRAINJOB_CREATED]:
_step.job.join(timeout=timeout)
return self.get_job(name)

def delete_job(self, name: str):
# find job first.
_job = next((j for j in self.__local_jobs if j.name == name), None)
if _job is None:
raise ValueError("No TrainJob with name '%s'" % name)

# cancel all nested step jobs in target job
_ = [step.job.cancel() for step in _job.steps]
# remove the job from the list of jobs
self.__local_jobs.remove(_job)

def __get_job_status(self, job: LocalBackendJobs) -> str:
statuses = [_step.job.status for _step in job.steps]
# a running or failed step status takes precedence over completed
if constants.TRAINJOB_FAILED in statuses:
status = constants.TRAINJOB_FAILED
elif constants.TRAINJOB_RUNNING in statuses:
status = constants.TRAINJOB_RUNNING
elif constants.TRAINJOB_CREATED in statuses:
status = constants.TRAINJOB_CREATED
else:
status = constants.TRAINJOB_CREATED

return status

def __register_job(
self,
train_job_name: str,
step_name: str,
job: LocalJob,
runtime: types.Runtime = None,
):
_job = [j for j in self.__local_jobs if j.name == train_job_name]
if not _job:
_job = LocalBackendJobs(name=train_job_name, runtime=runtime, created=datetime.now())
self.__local_jobs.append(_job)
else:
_job = _job[0]
_step = [s for s in _job.steps if s.step_name == step_name]
if not _step:
_step = LocalBackendStep(step_name=step_name, job=job)
_job.steps.append(_step)
else:
logger.warning("Step '{}' already registered.".format(step_name))

def __convert_local_runtime_to_runtime(self, local_runtime) -> types.Runtime:
return types.Runtime(
name=local_runtime.name,
trainer=types.RuntimeTrainer(
trainer_type=local_runtime.trainer.trainer_type,
framework=local_runtime.trainer.framework,
num_nodes=local_runtime.trainer.num_nodes,
device_count=local_runtime.trainer.device_count,
device=local_runtime.trainer.device,
),
pretrained_model=local_runtime.pretrained_model,
)
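
Taken together, the backend provides a subprocess-based implementation of the same ExecutionBackend interface the Kubernetes backend implements. A hedged end-to-end sketch of the flow through TrainerClient; the CustomTrainer(func=...) field name is an assumption based on how the backend consumes the trainer, and the runtime is simply taken from list_runtimes() rather than named explicitly:

from kubeflow.trainer import CustomTrainer, LocalProcessBackendConfig, TrainerClient


def train_fn():
    # Runs inside the per-job virtualenv that the backend creates.
    print("hello from a local training step")


client = TrainerClient(backend_config=LocalProcessBackendConfig())

# The local backend exposes its own runtime catalogue (local_runtimes in constants.py).
runtime = client.list_runtimes()[0]

# train() builds the venv script, registers a single "train" step, and starts it.
job_name = client.train(runtime=runtime, trainer=CustomTrainer(func=train_fn))

# Block until the step subprocess finishes, then stream its captured output.
client.wait_for_job_status(job_name)
for line in client.get_job_logs(job_name):
    print(line, end="")

# Cancel any remaining step subprocesses and drop the job from the registry.
client.delete_job(job_name)
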
84 changes: 84 additions & 0 deletions kubeflow/trainer/backends/localprocess/constants.py
@@ -0,0 +1,84 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap
import re
from kubeflow.trainer.types import types as base_types
from kubeflow.trainer.constants import constants
from kubeflow.trainer.backends.localprocess import types

TORCH_FRAMEWORK_TYPE = "torch"

local_runtimes = [
base_types.Runtime(
name=constants.TORCH_RUNTIME,
trainer=types.LocalRuntimeTrainer(
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
framework=TORCH_FRAMEWORK_TYPE,
num_nodes=1,
device_count=constants.UNKNOWN,
device=constants.UNKNOWN,
packages=["torch"],
),
)
]


# Install the trainer's Python dependencies into the per-job virtualenv.
DEPENDENCIES_SCRIPT = textwrap.dedent(
"""
PIP_DISABLE_PIP_VERSION_CHECK=1 pip install $QUIET \
--no-warn-script-location $PIP_INDEX $PACKAGE_STR
"""
)

# activate virtualenv, then run the entrypoint from the virtualenv bin
LOCAL_EXEC_ENTRYPOINT = textwrap.dedent(
"""
$ENTRYPOINT "$FUNC_FILE" "$PARAMETERS"
"""
)

TORCH_COMMAND = "torchrun"

# default command, will run from within the virtualenv
DEFAULT_COMMAND = "python"

# remove virtualenv after training is completed.
LOCAL_EXEC_JOB_CLEANUP_SCRIPT = textwrap.dedent(
"""
rm -rf $PYENV_LOCATION
"""
)


LOCAL_EXEC_JOB_TEMPLATE = textwrap.dedent(
"""
set -e
$OS_PYTHON_BIN -m venv --without-pip $PYENV_LOCATION
echo "Operating inside $PYENV_LOCATION"
source $PYENV_LOCATION/bin/activate
$PYENV_LOCATION/bin/python -m ensurepip --upgrade --default-pip
$DEPENDENCIES_SCRIPT
$ENTRYPOINT
$CLEANUP_SCRIPT
"""
)

LOCAL_EXEC_FILENAME = "train_{}.py"

PYTHON_PACKAGE_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")
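
The job template above uses $VAR placeholders, which the localprocess utils module (not shown in this diff) fills in when it builds the training script. Purely as an illustration, the same substitution can be reproduced with string.Template, and PYTHON_PACKAGE_NAME_RE strips version specifiers down to the distribution name; the values below are made up:

import string

from kubeflow.trainer.backends.localprocess import constants as local_constants

# Render the venv bootstrap script with placeholder values (illustrative only).
script = string.Template(local_constants.LOCAL_EXEC_JOB_TEMPLATE).safe_substitute(
    OS_PYTHON_BIN="python3",
    PYENV_LOCATION="/tmp/example-venv",
    DEPENDENCIES_SCRIPT="pip install --quiet torch",
    ENTRYPOINT='python "train_example.py" ""',
    CLEANUP_SCRIPT="rm -rf /tmp/example-venv",
)
print(script)

# The regex keeps only the distribution name, e.g. "torch==2.4.0" -> "torch".
print(local_constants.PYTHON_PACKAGE_NAME_RE.match("torch==2.4.0").group(1))
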