Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ uv-venv:
.PHONY: test-python
test-python: uv-venv
@uv sync
@uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest ./kubeflow/trainer/backends/kubernetes/backend_test.py
@uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest ./kubeflow/trainer/backends/kubernetes/backend_test.py ./kubeflow/trainer/utils/utils_test.py
@uv run coverage report -m kubeflow/trainer/backends/kubernetes/backend.py kubeflow/trainer/utils/utils.py
ifeq ($(report),xml)
@uv run coverage xml
Expand Down
37 changes: 13 additions & 24 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import random
import string
import uuid
from dataclasses import asdict, dataclass, field
from typing import Any, Optional, Type
from dataclasses import asdict
from typing import Optional
from unittest.mock import Mock, patch

import pytest
Expand All @@ -34,28 +34,17 @@
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
from kubeflow.trainer.test.common import TestCase
from kubeflow.trainer.test.common import (
SUCCESS,
FAILED,
DEFAULT_NAMESPACE,
TIMEOUT,
RUNTIME,
)


@dataclass
class TestCase:
    """Declarative description of one parametrized test scenario."""

    # Human-readable scenario name.
    name: str
    # Outcome label the scenario is expected to end with; required here.
    expected_status: str
    # Scenario-specific inputs; default_factory gives every instance its own dict.
    config: dict[str, Any] = field(default_factory=dict)
    # Value the code under test should produce, when one is expected.
    expected_output: Optional[Any] = None
    # Exception class the code under test should raise, when one is expected.
    expected_error: Optional[Type[Exception]] = None
    # Un-annotated, so this is a plain class attribute and not a dataclass field.
    # Prevents pytest from collecting this Test*-named class as a test.
    __test__ = False


# --------------------------
# Constants for test scenarios
# --------------------------
TIMEOUT = "timeout"
RUNTIME = "runtime"
SUCCESS = "success"
# NOTE(review): capitalized, unlike the other labels — presumably mirrors an
# upstream status string; confirm before normalizing.
FAILED = "Failed"
DEFAULT_NAMESPACE = "default"
# In all tests, the runtime name is equal to the framework name.
TORCH_RUNTIME = "torch"
TORCH_TUNE_RUNTIME = "torchtune"
Expand Down Expand Up @@ -238,9 +227,9 @@ def get_custom_trainer(
'\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
"|| python -m ensurepip --user || apt-get install python-pip"
"\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
f" --no-warn-script-location {pip_command} {packages_command}"
f" --no-warn-script-location {pip_command} {packages_command}"
"\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
'print("Hello World"),\n\n<lambda>('
'print("Hello World"),\n\n<lambda>(**'
"{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
'"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
],
Expand Down
23 changes: 23 additions & 0 deletions kubeflow/trainer/test/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Shared test utilities and types for Kubeflow Trainer tests.

from dataclasses import dataclass, field
from typing import Any, Optional, Type


# Common status constants
# SUCCESS doubles as the default value of TestCase.expected_status.
SUCCESS = "success"
# NOTE(review): capitalized, unlike the other labels — presumably mirrors an
# upstream status string; confirm before normalizing.
FAILED = "Failed"
DEFAULT_NAMESPACE = "default"
TIMEOUT = "timeout"
RUNTIME = "runtime"


@dataclass
class TestCase:
    """Declarative description of one parametrized test scenario."""

    # Human-readable scenario name; the only required field.
    name: str
    # Outcome label the scenario is expected to end with.
    expected_status: str = SUCCESS
    # Scenario-specific inputs; default_factory gives every instance its own dict.
    config: dict[str, Any] = field(default_factory=dict)
    # Value the code under test should produce, when one is expected.
    expected_output: Optional[Any] = None
    # Exception class the code under test should raise, when one is expected.
    expected_error: Optional[Type[Exception]] = None
    # Prevent pytest from collecting this dataclass as a test
    __test__ = False
8 changes: 4 additions & 4 deletions kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class CustomTrainer:

Args:
func (`Callable`): The function that encapsulates the entire model training process.
func_args (`Optional[Dict]`): The arguments to pass to the function.
packages_to_install (`Optional[List[str]]`):
func_args (`Optional[dict]`): The arguments to pass to the function.
packages_to_install (`Optional[list[str]]`):
A list of Python packages to install before running the function.
pip_index_urls (`list[str]`): The PyPI URLs from which to install
Python packages. The first URL will be the index-url, and remaining ones
are extra-index-urls.
num_nodes (`Optional[int]`): The number of nodes to use for training.
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
"""

func: Callable
Expand Down
26 changes: 17 additions & 9 deletions kubeflow/trainer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import inspect
import os
import textwrap
from typing import Any, Callable, Optional
from urllib.parse import urlparse

from typing import Callable, Optional, Any
from urllib.parse import urlparse
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow_trainer_api import models
Expand Down Expand Up @@ -268,15 +268,19 @@ def get_script_for_python_packages(
if is_mpi:
options.append("--user")

script_for_python_packages = textwrap.dedent(
header_script = textwrap.dedent(
"""
if ! [ -x "$(command -v pip)" ]; then
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
fi

PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
--no-warn-script-location {} {}
""".format(
"""
)

script_for_python_packages = (
header_script
+ "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
+ "--no-warn-script-location {} {}\n".format(
" ".join(options),
packages_str,
)
Expand Down Expand Up @@ -318,12 +322,16 @@ def get_command_using_train_func(
# Wrap function code to execute it from the file. For example:
# TODO (andreyvelich): Find a better way to run users' scripts.
# def train(parameters):
# print('Start Training...')
# print('Start Training...')
# train({'lr': 0.01})
if train_func_parameters is None:
func_code = f"{func_code}\n{train_func.__name__}()\n"
func_call = f"{train_func.__name__}()"
else:
func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n"
# Always unpack kwargs for training function calls.
func_call = f"{train_func.__name__}(**{train_func_parameters})"

# Combine everything into the final code string.
func_code = f"{func_code}\n{func_call}\n"

is_mpi = runtime.trainer.command[0] == "mpirun"
# The default file location for OpenMPI is: /home/mpiuser/<FILE_NAME>.py
Expand Down
150 changes: 139 additions & 11 deletions kubeflow/trainer/utils/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict

import pytest

from kubeflow.trainer.utils import utils
from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.test.common import TestCase, SUCCESS, FAILED


@dataclass
class TestCase:
name: str
config: Dict[str, Any]
expected_output: str
__test__ = False

def _build_runtime() -> types.Runtime:
    """Return the runtime used by the command-building test cases.

    The runtime is named "test-runtime" and wraps a custom-trainer
    ``RuntimeTrainer`` (torch framework, one CPU device) whose command is set
    to ``constants.DEFAULT_COMMAND``.
    """
    trainer = types.RuntimeTrainer(
        device_count="1",
        device="cpu",
        framework="torch",
        trainer_type=types.TrainerType.CUSTOM_TRAINER,
    )
    # The command is applied through the setter after construction.
    trainer.set_command(constants.DEFAULT_COMMAND)

    runtime = types.Runtime(trainer=trainer, name="test-runtime")
    return runtime

@pytest.mark.parametrize(
"test_case",
Expand Down Expand Up @@ -124,3 +124,131 @@ def test_get_script_for_python_packages(test_case):
)

assert test_case.expected_output == script

@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="with args dict always unpacks kwargs",
expected_status=SUCCESS,
config={
"func": (lambda: print("Hello World")),
"func_args": {"batch_size": 128, "learning_rate": 0.001, "epochs": 20},
"runtime": _build_runtime(),
},
expected_output=[
'bash',
'-c',
(
"\nread -r -d '' SCRIPT << EOM\n\n"
'"func": (lambda: print("Hello World")),\n\n'
"<lambda>(**{'batch_size': 128, 'learning_rate': 0.001, 'epochs': 20})\n\n"
'EOM\n'
'printf "%s" "$SCRIPT" > "utils_test.py"\n'
'python "utils_test.py"'
),
]
),
TestCase(
name="without args calls function with no params",
expected_status=SUCCESS,
config={
"func": (lambda: print("Hello World")),
"func_args": None,
"runtime": _build_runtime(),
},
expected_output=[
'bash',
'-c',
(
"\nread -r -d '' SCRIPT << EOM\n\n"
'"func": (lambda: print("Hello World")),\n\n'
'<lambda>()\n\n'
'EOM\n'
'printf "%s" "$SCRIPT" > "utils_test.py"\n'
'python "utils_test.py"'
),
],
),
TestCase(
name="raises when runtime has no trainer",
expected_status=FAILED,
config={
"func": (lambda: print("Hello World")),
"func_args": None,
"runtime": types.Runtime(name="no-trainer", trainer=None),
},
expected_error=ValueError,
),
TestCase(
name="raises when train_func is not callable",
expected_status=FAILED,
config={
"func": "not callable",
"func_args": None,
"runtime": _build_runtime(),
},
expected_error=ValueError,
),
TestCase(
name="single dict param also unpacks kwargs",
expected_status=SUCCESS,
config={
"func": (lambda: print("Hello World")),
"func_args": {"a": 1, "b": 2},
"runtime": _build_runtime(),
},
expected_output=[
'bash',
'-c',
(
"\nread -r -d '' SCRIPT << EOM\n\n"
'"func": (lambda: print("Hello World")),\n\n'
"<lambda>(**{'a': 1, 'b': 2})\n\n"
'EOM\n'
'printf "%s" "$SCRIPT" > "utils_test.py"\n'
'python "utils_test.py"'
),
],
),
TestCase(
name="multi-param function uses kwargs-unpacking",
expected_status=SUCCESS,
config={
"func": (lambda **kwargs: "ok"),
"func_args": {"a": 3, "b": "hi", "c": 0.2},
"runtime": _build_runtime(),
},
expected_output=[
"bash",
"-c",
(
"\nread -r -d '' SCRIPT << EOM\n\n"
'"func": (lambda **kwargs: "ok"),\n\n'
"<lambda>(**{'a': 3, 'b': 'hi', 'c': 0.2})\n\n"
'EOM\n'
'printf "%s" "$SCRIPT" > "utils_test.py"\n'
'python "utils_test.py"'
),
],
),
],
)
def test_get_command_using_train_func(test_case: TestCase):
    """Check the command built by ``utils.get_command_using_train_func``.

    Success cases (no ``expected_error``) compare the returned command with
    ``expected_output``. Failure cases require the call to raise exactly
    ``expected_error``.

    Only the call under test is guarded for exceptions. A mismatching command
    therefore surfaces as its own assertion failure (with pytest's diff)
    instead of being caught and re-reported as an unexpected exception type.
    """
    print("Executing test:", test_case.name)

    def build_command():
        # Single place that invokes the function under test for both branches.
        return utils.get_command_using_train_func(
            runtime=test_case.config["runtime"],
            train_func=test_case.config.get("func"),
            train_func_parameters=test_case.config.get("func_args"),
            pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
            packages_to_install=[],
        )

    if test_case.expected_error is None:
        assert test_case.expected_status == SUCCESS
        assert build_command() == test_case.expected_output
    else:
        # A case that expects an error must also be labelled as failed.
        assert test_case.expected_status == FAILED
        with pytest.raises(test_case.expected_error) as exc_info:
            build_command()
        # pytest.raises also accepts subclasses; keep the exact-type check.
        assert exc_info.type is test_case.expected_error

    print("test execution complete")
Loading