Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
3524abc
Implement TrainerClient Backends & Local Process
szaher Jun 21, 2025
0f4e504
Merge branch 'main' of github.com:kubeflow/sdk into training-backends
szaher Jun 21, 2025
908af68
Implement Job Cancellation
szaher Jun 25, 2025
71e83ae
Merge branch 'main' into training-backends
szaher Jul 3, 2025
3d578c7
update local job to add resouce limitation in k8s style
szaher Jul 9, 2025
bed8f70
Update python/kubeflow/trainer/api/trainer_client.py
szaher Jul 9, 2025
3781c23
Merge with latest changes from main
szaher Aug 11, 2025
28db17f
Fix linting issues
szaher Aug 12, 2025
7977cc4
fix unit tests
szaher Aug 12, 2025
da0ce2f
add support wait_for_job_status
szaher Aug 12, 2025
ca564d6
Update data types
szaher Aug 19, 2025
d9af6f2
fix merge conflict
szaher Aug 19, 2025
2383c52
Merge branch 'main' into training-backends
szaher Aug 19, 2025
46961ba
fix unit tests
szaher Aug 20, 2025
e226167
remove TypeAlias
szaher Aug 20, 2025
2ef70db
Replace TRAINER_BACKEND_REGISTRY with TRAINER_BACKEND
szaher Aug 20, 2025
822a262
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 21, 2025
f00280a
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 21, 2025
e0c714f
Restructure training backends into separate dirs
szaher Aug 22, 2025
1dbc3e9
Update kubeflow/trainer/api/trainer_client.py
szaher Aug 22, 2025
460aae2
Merge branch 'main' into training-backends
szaher Sep 6, 2025
46a5fd7
add get_runtime_packages as not supported by local-exec
szaher Sep 7, 2025
ea3e9cf
move backends and its configs to kubeflow.trainer
szaher Sep 7, 2025
2976f8a
fix typo in delete_job
szaher Sep 8, 2025
e4a57b3
Move local_runtimes to constants
szaher Sep 8, 2025
8d6b1e7
use google style docstring for LocalJob
szaher Sep 8, 2025
c3719b5
remove debug opt from LocalProcessConfig
szaher Sep 9, 2025
64cdcba
only use imports from kubeflow.trainer for backends
szaher Sep 9, 2025
511b22b
upload local-exec to use only one step
szaher Sep 13, 2025
74d60a4
optimize loops when getting runtime
szaher Sep 13, 2025
9d9a14c
add LocalRuntimeTrainer
szaher Sep 14, 2025
60d96d0
rename cleanup config item to cleanup_venv
szaher Sep 14, 2025
8e9190e
convert local runtime to runtime
szaher Sep 14, 2025
ac0be0c
convert runtimes before returning
szaher Sep 14, 2025
4b1db8c
fix get_job_logs to align with parent interface
szaher Sep 14, 2025
4fe7baa
rename get_runtime_trainer func
szaher Sep 14, 2025
9775f3f
rename get_training_job_command to get_local_train_job_script
szaher Sep 14, 2025
42c7769
Ignore failures in Coveralls action
andreyvelich Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
771 changes: 86 additions & 685 deletions kubeflow/trainer/api/trainer_client.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions kubeflow/trainer/api/trainer_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
from unittest.mock import Mock, patch

import pytest
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.backends import K8SBackend
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.types.backends import K8SBackendConfig
from kubeflow.trainer.utils import utils
from kubeflow_trainer_api import models

Expand Down Expand Up @@ -100,7 +101,7 @@ def trainer_client(request):
read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log),
),
):
yield TrainerClient()
yield K8SBackend(K8SBackendConfig())


# --------------------------
Expand Down
28 changes: 28 additions & 0 deletions kubeflow/trainer/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kubeflow.trainer.backends.k8s import K8SBackend
from kubeflow.trainer.backends.local_process import LocalProcessBackend
from kubeflow.trainer.types.backends import K8SBackendConfig, LocalProcessBackendConfig

TRAINER_BACKEND_REGISTRY = {
"kubernetes": {
"backend_cls": K8SBackend,
"config_cls": K8SBackendConfig,
},
"local": {
"backend_cls": LocalProcessBackend,
"config_cls": LocalProcessBackendConfig,
}
}
88 changes: 88 additions & 0 deletions kubeflow/trainer/backends/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

from typing import Dict, List, Optional, Set
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types


class TrainingBackend(abc.ABC):

@abc.abstractmethod
def list_runtimes(self) -> List[types.Runtime]:
raise NotImplementedError()

@abc.abstractmethod
def get_runtime(self, name: str) -> Optional[types.Runtime]:
raise NotImplementedError()

@abc.abstractmethod
def train(self,
runtime: types.Runtime,
initializer: Optional[types.Initializer] = None,
trainer: Optional[types.RuntimeTrainer] = None,
) -> str:
raise NotImplementedError()

@abc.abstractmethod
def list_jobs(
self, runtime: Optional[types.Runtime] = None
) -> List[types.TrainJob]:
raise NotImplementedError()

@abc.abstractmethod
def get_job(self, name: str) -> Optional[types.TrainJob]:
raise NotImplementedError()

@abc.abstractmethod
def get_job_logs(self,
name: str,
follow: Optional[bool] = False,
step: str = constants.NODE,
node_rank: int = 0,
) -> Dict[str, str]:
raise NotImplementedError()

@abc.abstractmethod
def delete_job(self, name: str) -> None:
raise NotImplementedError()

@abc.abstractmethod
def wait_for_job_status(
self,
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 2,
) -> types.TrainJob:
"""Wait for TrainJob to reach the desired status

Args:
name: Name of the TrainJob.
status: Set of expected statuses. It must be subset of Created, Running, Complete, and
Failed statuses.
timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
polling_interval: The polling interval in seconds to check TrainJob status.

Returns:
TrainJob: The training job that reaches the desired status.

Raises:
ValueError: The input values are incorrect.
RuntimeError: Failed to get TrainJob or TrainJob reaches unexpected Failed status.
TimeoutError: Timeout to wait for TrainJob status.
"""
raise NotImplementedError()
Loading