Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
# Pinned to the latest Ruff release available when this change was made
rev: v0.12.10
hooks:
- id: ruff
exclude: |
(?x)^(
kubeflow/trainer/__init__.py|
kubeflow/trainer/api/__init__.py|
kubeflow/trainer/models/.*|
)$
# Lint + auto-fix (must run before format)
- id: ruff-check
args: [ --fix ]
# Format after fixes
- id: ruff-format
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ uv: ## Install UV

.PHONY: ruff
ruff: ## Install Ruff
@uvx ruff --help &> /dev/null || uv tool install ruff
@uv run ruff --help &> /dev/null || uv tool install ruff

.PHONY: verify
verify: install-dev ## Verify the lockfile, lint rules, and formatting
@uv lock --check
@uvx ruff check --show-fixes
@uv run ruff check --show-fixes --output-format=github .
@uv run ruff format --check kubeflow

.PHONY: uv-venv
uv-venv:
Expand Down
13 changes: 5 additions & 8 deletions kubeflow/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.


from __future__ import absolute_import

# Import the Kubeflow Trainer client.
from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401

# Import the backends and their associated configs.
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig

# Import the Kubeflow Trainer constants.
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401

Expand All @@ -32,17 +34,12 @@
Initializer,
Loss,
Runtime,
RuntimeTrainer,
TorchTuneConfig,
TorchTuneInstructDataset,
RuntimeTrainer,
TrainerType,
)

# Import the backends and their associated configs.
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig


__all__ = [
"BuiltinTrainer",
"CustomTrainer",
Expand Down
1 change: 0 additions & 1 deletion kubeflow/trainer/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# ruff: noqa

# import apis into api package

23 changes: 13 additions & 10 deletions kubeflow/trainer/api/trainer_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
import logging
from typing import Optional, Union, Iterator
from typing import Optional, Union

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend
from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackendConfig

from kubeflow.trainer.backends.localprocess.backend import (
LocalProcessBackend,
LocalProcessBackendConfig,
)
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types

logger = logging.getLogger(__name__)


class TrainerClient:
def __init__(
self,
backend_config: Union[
KubernetesBackendConfig, LocalProcessBackendConfig
] = KubernetesBackendConfig(),
backend_config: Union[KubernetesBackendConfig, LocalProcessBackendConfig] = None,
):
"""Initialize a Kubeflow Trainer client.

Expand All @@ -45,12 +45,15 @@ def __init__(

"""
# initialize training backend
if not backend_config:
backend_config = KubernetesBackendConfig()

if isinstance(backend_config, KubernetesBackendConfig):
self.backend = KubernetesBackend(backend_config)
elif isinstance(backend_config, LocalProcessBackendConfig):
self.backend = LocalProcessBackend(backend_config)
else:
raise ValueError("Invalid backend config '{}'".format(backend_config))
raise ValueError(f"Invalid backend config '{backend_config}'")

def list_runtimes(self) -> list[types.Runtime]:
"""List of the available runtimes.
Expand Down
12 changes: 11 additions & 1 deletion kubeflow/trainer/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,27 @@
# limitations under the License.

import abc
from collections.abc import Iterator
from typing import Optional, Union

from typing import Optional, Union, Iterator
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types


class ExecutionBackend(abc.ABC):
@abc.abstractmethod
def list_runtimes(self) -> list[types.Runtime]:
raise NotImplementedError()

@abc.abstractmethod
def get_runtime(self, name: str) -> types.Runtime:
raise NotImplementedError()

@abc.abstractmethod
def get_runtime_packages(self, runtime: types.Runtime):
raise NotImplementedError()

@abc.abstractmethod
def train(
self,
runtime: Optional[types.Runtime] = None,
Expand All @@ -37,12 +42,15 @@ def train(
) -> str:
raise NotImplementedError()

@abc.abstractmethod
def list_jobs(self, runtime: Optional[types.Runtime] = None) -> list[types.TrainJob]:
raise NotImplementedError()

@abc.abstractmethod
def get_job(self, name: str) -> types.TrainJob:
raise NotImplementedError()

@abc.abstractmethod
def get_job_logs(
self,
name: str,
Expand All @@ -51,6 +59,7 @@ def get_job_logs(
) -> Iterator[str]:
raise NotImplementedError()

@abc.abstractmethod
def wait_for_job_status(
self,
name: str,
Expand All @@ -60,5 +69,6 @@ def wait_for_job_status(
) -> types.TrainJob:
raise NotImplementedError()

@abc.abstractmethod
def delete_job(self, name: str):
raise NotImplementedError()
29 changes: 16 additions & 13 deletions kubeflow/trainer/backends/kubernetes/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
import copy
import logging
import multiprocessing
import random
import re
import string
import time
from typing import Optional, Union
import uuid
from typing import Optional, Union, Iterator
import re

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils
from kubeflow_trainer_api import models
from kubernetes import client, config, watch

from kubeflow.trainer.backends.base import ExecutionBackend
from kubeflow.trainer.backends.kubernetes import types as k8s_types
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -141,8 +143,8 @@ def get_runtime_packages(self, runtime: types.Runtime):
runtime_copy.trainer.set_command(tuple(mpi_command))

def print_packages():
import subprocess
import shutil
import subprocess
import sys

# Print Python version.
Expand Down Expand Up @@ -353,17 +355,15 @@ def get_job_logs(
)

# Stream logs incrementally.
for logline in log_stream:
yield logline # type:ignore
yield from log_stream
else:
logs = self.core_api.read_namespaced_pod_log(
name=pod_name,
namespace=self.namespace,
container=container_name,
)

for line in logs.splitlines():
yield line
yield from logs.splitlines()

except Exception as e:
raise RuntimeError(
Expand Down Expand Up @@ -554,9 +554,12 @@ def __get_trainjob_from_crd(
# Update the TrainJob status from its conditions.
if trainjob_crd.status and trainjob_crd.status.conditions:
for c in trainjob_crd.status.conditions:
if c.type == constants.TRAINJOB_COMPLETE and c.status == "True":
trainjob.status = c.type
elif c.type == constants.TRAINJOB_FAILED and c.status == "True":
if (
c.type == constants.TRAINJOB_COMPLETE
and c.status == "True"
or c.type == constants.TRAINJOB_FAILED
and c.status == "True"
):
trainjob.status = c.type
else:
# The TrainJob running status is defined when all training nodes (e.g. Pods) are
Expand Down
21 changes: 10 additions & 11 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,31 @@
It tests KubernetesBackend's behavior across job listing, resource creation, etc.
"""

from dataclasses import asdict
import datetime
import multiprocessing
import random
import string
import uuid
from dataclasses import asdict
from typing import Optional
from unittest.mock import Mock, patch
import uuid

import pytest
from kubeflow_trainer_api import models
import pytest

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.test.common import TestCase
from kubeflow.trainer.constants import constants
from kubeflow.trainer.test.common import (
SUCCESS,
FAILED,
DEFAULT_NAMESPACE,
TIMEOUT,
FAILED,
RUNTIME,
SUCCESS,
TIMEOUT,
TestCase,
)
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils

# In all tests runtime name is equal to the framework name.
TORCH_RUNTIME = "torch"
Expand Down Expand Up @@ -788,7 +788,6 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
},
expected_error=ValueError,
),

],
)
def test_train(kubernetes_backend, test_case):
Expand Down
1 change: 1 addition & 0 deletions kubeflow/trainer/backends/kubernetes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Optional

from kubernetes import client
from pydantic import BaseModel

Expand Down
Loading