diff --git a/.github/workflows/test-python.yaml b/.github/workflows/test-python.yaml index 850bb314..9023f5ff 100644 --- a/.github/workflows/test-python.yaml +++ b/.github/workflows/test-python.yaml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.9', '3.11'] + python-version: ["3.9", "3.11"] name: Test (Python ${{ matrix.python-version }}) @@ -36,6 +36,7 @@ jobs: - name: Upload coverage to Coveralls uses: coverallsapp/github-action@v2 + continue-on-error: true with: github-token: ${{ secrets.GITHUB_TOKEN }} parallel: true @@ -48,6 +49,7 @@ jobs: steps: - name: Close parallel build uses: coverallsapp/github-action@v2 + continue-on-error: true with: github-token: ${{ secrets.GITHUB_TOKEN }} parallel-finished: true diff --git a/kubeflow/trainer/__init__.py b/kubeflow/trainer/__init__.py index dce03bbb..346ef507 100644 --- a/kubeflow/trainer/__init__.py +++ b/kubeflow/trainer/__init__.py @@ -38,6 +38,11 @@ TrainerType, ) +# import backends and its associated configs +from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig +from kubeflow.trainer.backends.localprocess.types import LocalProcessBackendConfig + + __all__ = [ "BuiltinTrainer", "CustomTrainer", @@ -55,4 +60,6 @@ "RuntimeTrainer", "TrainerClient", "TrainerType", + "LocalProcessBackendConfig", + "KubernetesBackendConfig", ] diff --git a/kubeflow/trainer/api/trainer_client.py b/kubeflow/trainer/api/trainer_client.py index c57ff711..a61c8084 100644 --- a/kubeflow/trainer/api/trainer_client.py +++ b/kubeflow/trainer/api/trainer_client.py @@ -19,6 +19,8 @@ from kubeflow.trainer.types import types from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig +from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackend +from kubeflow.trainer.backends.localprocess.backend import LocalProcessBackendConfig logger = logging.getLogger(__name__) @@ -27,7 
+29,9 @@ class TrainerClient: def __init__( self, - backend_config: KubernetesBackendConfig = KubernetesBackendConfig(), + backend_config: Union[ + KubernetesBackendConfig, LocalProcessBackendConfig + ] = KubernetesBackendConfig(), ): """Initialize a Kubeflow Trainer client. @@ -43,6 +47,8 @@ def __init__( # initialize training backend if isinstance(backend_config, KubernetesBackendConfig): self.backend = KubernetesBackend(backend_config) + elif isinstance(backend_config, LocalProcessBackendConfig): + self.backend = LocalProcessBackend(backend_config) else: raise ValueError("Invalid backend config '{}'".format(backend_config)) diff --git a/kubeflow/trainer/backends/localprocess/__init__.py b/kubeflow/trainer/backends/localprocess/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/kubeflow/trainer/backends/localprocess/backend.py b/kubeflow/trainer/backends/localprocess/backend.py new file mode 100644 index 00000000..4fbb2d46 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/backend.py @@ -0,0 +1,257 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging
import random
import string
import tempfile
import uuid
from datetime import datetime
from typing import Iterator, List, Optional, Set, Union

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
from kubeflow.trainer.backends.base import ExecutionBackend
from kubeflow.trainer.backends.localprocess.types import (
    LocalProcessBackendConfig,
    LocalBackendJobs,
    LocalBackendStep,
)
from kubeflow.trainer.backends.localprocess.constants import local_runtimes
from kubeflow.trainer.backends.localprocess.job import LocalJob
from kubeflow.trainer.backends.localprocess import utils as local_utils

logger = logging.getLogger(__name__)


class LocalProcessBackend(ExecutionBackend):
    """Execution backend that runs TrainJobs as local subprocesses.

    ``train`` builds a throwaway virtualenv in a temporary directory, writes
    the user's training function to a script, and launches it through a
    background ``LocalJob`` thread. All bookkeeping is in-memory only, so
    jobs are lost when the client process exits.
    """

    def __init__(
        self,
        cfg: LocalProcessBackendConfig,
    ):
        # In-memory registry of jobs started through this backend instance.
        self.__local_jobs: List[LocalBackendJobs] = []
        self.cfg = cfg

    def list_runtimes(self) -> List[types.Runtime]:
        """Return every runtime supported by the local backend."""
        return [self.__convert_local_runtime_to_runtime(local_runtime=rt) for rt in local_runtimes]

    def get_runtime(self, name: str) -> types.Runtime:
        """Return the runtime called `name`.

        Raises:
            ValueError: If no local runtime with that name exists.
        """
        runtime = next(
            (
                self.__convert_local_runtime_to_runtime(rt)
                for rt in local_runtimes
                if rt.name == name
            ),
            None,
        )
        if not runtime:
            raise ValueError(f"Runtime '{name}' not found.")

        return runtime

    def get_runtime_packages(self, runtime: types.Runtime):
        """Return the pip packages pre-installed for the given runtime.

        Raises:
            ValueError: If the runtime is unknown to the local backend.
        """
        # BUG FIX: the original rebound `runtime` to the lookup result, so the
        # not-found branch evaluated `None.name` and raised AttributeError
        # instead of the intended ValueError. Keep argument and result apart.
        local_runtime = next((rt for rt in local_runtimes if rt.name == runtime.name), None)
        if not local_runtime:
            raise ValueError(f"Runtime '{runtime.name}' not found.")

        return local_runtime.trainer.packages

    def train(
        self,
        runtime: Optional[types.Runtime] = None,
        initializer: Optional[types.Initializer] = None,
        trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None,
    ) -> str:
        """Start a local training job and return its generated name.

        Raises:
            ValueError: If `runtime` is missing or `trainer` is not a
                CustomTrainer (the only kind this backend supports).
        """
        # Job name: leading letter keeps it shell/DNS friendly, hex suffix
        # keeps it unique (12 chars total).
        train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11]

        # Guard before dereferencing: the parameter is Optional but the
        # original unconditionally read `runtime.name` / `runtime.trainer`.
        if runtime is None:
            raise ValueError("Runtime must be set with LocalProcessBackend")
        # localprocess backend only supports CustomTrainer.
        if not isinstance(trainer, types.CustomTrainer):
            raise ValueError("CustomTrainer must be set with LocalProcessBackend")

        # NOTE(review): `initializer` is accepted but never used by this
        # backend — confirm whether dataset/model initialization should be
        # supported locally.

        # Temp dir that will hold the virtualenv and the training script.
        venv_dir = tempfile.mkdtemp(prefix=train_job_name)
        logger.debug("operating in {}".format(venv_dir))

        # NOTE(review): this mutates the caller-provided runtime object.
        runtime.trainer = local_utils.get_local_runtime_trainer(
            runtime_name=runtime.name,
            venv_dir=venv_dir,
            framework=runtime.trainer.framework,
        )

        # Build the ("bash", "-c", script) training job command.
        training_command = local_utils.get_local_train_job_script(
            trainer=trainer,
            runtime=runtime,
            train_job_name=train_job_name,
            venv_dir=venv_dir,
            cleanup_venv=self.cfg.cleanup_venv,
        )

        # Record the command in the runtime trainer for introspection.
        runtime.trainer.set_command(training_command)

        # Create the background subprocess wrapper.
        train_job = LocalJob(
            name="{}-train".format(train_job_name),
            command=training_command,
            execution_dir=venv_dir,
            env=trainer.env,
            dependencies=[],
        )

        self.__register_job(
            train_job_name=train_job_name,
            step_name="train",
            job=train_job,
            runtime=runtime,
        )
        # Start the job thread.
        train_job.start()

        return train_job_name

    def list_jobs(self, runtime: Optional[types.Runtime] = None) -> List[types.TrainJob]:
        """List known jobs, optionally filtered to one runtime."""
        result = []

        for _job in self.__local_jobs:
            if runtime and _job.runtime.name != runtime.name:
                continue
            result.append(
                types.TrainJob(
                    name=_job.name,
                    creation_timestamp=_job.created,
                    # BUG FIX: report the job's own stored runtime; the
                    # original passed the (possibly None) filter argument.
                    runtime=_job.runtime,
                    num_nodes=1,
                    steps=[
                        types.Step(name=s.step_name, pod_name=s.step_name, status=s.job.status)
                        for s in _job.steps
                    ],
                )
            )
        return result

    def get_job(self, name: str) -> Optional[types.TrainJob]:
        """Return the TrainJob called `name`, or raise ValueError."""
        _job = next((j for j in self.__local_jobs if j.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        # Map step statuses onto the statuses `TrainerClient` understands.
        status = self.__get_job_status(_job)

        return types.TrainJob(
            name=_job.name,
            creation_timestamp=_job.created,
            steps=[
                types.Step(name=_step.step_name, pod_name=_step.step_name, status=_step.job.status)
                for _step in _job.steps
            ],
            runtime=_job.runtime,
            num_nodes=1,
            status=status,
        )

    def get_job_logs(
        self,
        name: str,
        step: str = constants.NODE + "-0",
        follow: Optional[bool] = False,
    ) -> Iterator[str]:
        """Yield log lines for one step, or for all steps by default.

        The default `step` value is treated as "all steps".
        """
        _job = next((j for j in self.__local_jobs if j.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        want_all_steps = step == constants.NODE + "-0"

        for _step in _job.steps:
            if not want_all_steps and _step.step_name != step:
                continue
            # Flatten the per-step generator, passing through `follow`.
            yield from _step.job.logs(follow=follow)

    def wait_for_job_status(
        self,
        name: str,
        status: Optional[Set[str]] = None,
        timeout: int = 600,
        polling_interval: int = 2,
    ) -> types.TrainJob:
        """Block until the job's running steps finish, then return the job.

        NOTE(review): `status` and `polling_interval` are currently not
        consulted — every running/created step is simply joined. Confirm
        whether waiting for a *specific* status set is required.
        """
        # Idiom fix: avoid a shared mutable default argument.
        if status is None:
            status = {constants.TRAINJOB_COMPLETE}

        _job = next((_job for _job in self.__local_jobs if _job.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        for _step in _job.steps:
            if _step.job.status in [constants.TRAINJOB_RUNNING, constants.TRAINJOB_CREATED]:
                _step.job.join(timeout=timeout)
        return self.get_job(name)

    def delete_job(self, name: str):
        """Cancel all steps of the job and drop it from the registry."""
        _job = next((j for j in self.__local_jobs if j.name == name), None)
        if _job is None:
            raise ValueError("No TrainJob with name '%s'" % name)

        # Cancel all nested step jobs (plain loop — we run it for the side
        # effect, not to build a list).
        for step in _job.steps:
            step.job.cancel()
        self.__local_jobs.remove(_job)

    def __get_job_status(self, job: LocalBackendJobs) -> str:
        """Aggregate step statuses: Failed > Running > Created > Complete."""
        statuses = [_step.job.status for _step in job.steps]
        if constants.TRAINJOB_FAILED in statuses:
            return constants.TRAINJOB_FAILED
        if constants.TRAINJOB_RUNNING in statuses:
            return constants.TRAINJOB_RUNNING
        if constants.TRAINJOB_CREATED in statuses:
            return constants.TRAINJOB_CREATED
        # BUG FIX: the original fell through to CREATED here, so a job whose
        # steps had all completed was never reported as complete.
        if statuses:
            return constants.TRAINJOB_COMPLETE
        return constants.TRAINJOB_CREATED

    def __register_job(
        self,
        train_job_name: str,
        step_name: str,
        job: LocalJob,
        runtime: Optional[types.Runtime] = None,
    ):
        """Attach `job` as step `step_name` of `train_job_name`, creating the
        parent record on first use; duplicate steps are logged and ignored."""
        _job = next((j for j in self.__local_jobs if j.name == train_job_name), None)
        if _job is None:
            _job = LocalBackendJobs(name=train_job_name, runtime=runtime, created=datetime.now())
            self.__local_jobs.append(_job)
        if any(s.step_name == step_name for s in _job.steps):
            logger.warning("Step '{}' already registered.".format(step_name))
            return
        _job.steps.append(LocalBackendStep(step_name=step_name, job=job))

    def __convert_local_runtime_to_runtime(self, local_runtime) -> types.Runtime:
        """Convert an internal local-runtime record to the public type."""
        return types.Runtime(
            name=local_runtime.name,
            trainer=types.RuntimeTrainer(
                trainer_type=local_runtime.trainer.trainer_type,
                framework=local_runtime.trainer.framework,
                num_nodes=local_runtime.trainer.num_nodes,
                device_count=local_runtime.trainer.device_count,
                device=local_runtime.trainer.device,
            ),
            pretrained_model=local_runtime.pretrained_model,
        )
+DEPENDENCIES_SCRIPT = textwrap.dedent( + """ + PIP_DISABLE_PIP_VERSION_CHECK=1 pip install $QUIET \ + --no-warn-script-location $PIP_INDEX $PACKAGE_STR + """ +) + +# activate virtualenv, then run the entrypoint from the virtualenv bin +LOCAL_EXEC_ENTRYPOINT = textwrap.dedent( + """ + $ENTRYPOINT "$FUNC_FILE" "$PARAMETERS" + """ +) + +TORCH_COMMAND = "torchrun" + +# default command, will run from within the virtualenv +DEFAULT_COMMAND = "python" + +# remove virtualenv after training is completed. +LOCAL_EXEC_JOB_CLEANUP_SCRIPT = textwrap.dedent( + """ + rm -rf $PYENV_LOCATION + """ +) + + +LOCAL_EXEC_JOB_TEMPLATE = textwrap.dedent( + """ + set -e + $OS_PYTHON_BIN -m venv --without-pip $PYENV_LOCATION + echo "Operating inside $PYENV_LOCATION" + source $PYENV_LOCATION/bin/activate + $PYENV_LOCATION/bin/python -m ensurepip --upgrade --default-pip + $DEPENDENCIES_SCRIPT + $ENTRYPOINT + $CLEANUP_SCRIPT + """ +) + +LOCAL_EXEC_FILENAME = "train_{}.py" + +PYTHON_PACKAGE_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)") diff --git a/kubeflow/trainer/backends/localprocess/job.py b/kubeflow/trainer/backends/localprocess/job.py new file mode 100644 index 00000000..3ef8c58e --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/job.py @@ -0,0 +1,184 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import threading +import subprocess +import logging +from datetime import datetime +from typing import List, Union, Dict, Tuple + +from kubeflow.trainer.constants import constants + +logger = logging.getLogger(__name__) + + +class LocalJob(threading.Thread): + def __init__( + self, + name, + command: Union[List, Tuple[str], str], + execution_dir: str = None, + env: Dict[str, str] = None, + dependencies: List = None, + ): + """Creates a LocalJob. + + Creates a local subprocess with threading to allow users to create background jobs. + + Args: + name (str): The name of the job. + command (str): The command to run. + execution_dir (str): The execution directory. + env (Dict[str, str], optional): Environment variables. Defaults to None. + dependencies (List[str], optional): List of dependencies. Defaults to None. + """ + super().__init__() + self.name = name + self.command = command + self._stdout = "" + self._returncode = None + self._success = False + self._status = constants.TRAINJOB_CREATED + self._lock = threading.Lock() + self._process = None + self._output_updated = threading.Event() + self._cancel_requested = threading.Event() + self._start_time = None + self._end_time = None + self.env = env or {} + self.dependencies = dependencies or [] + self.execution_dir = execution_dir or os.getcwd() + + def run(self): + for dep in self.dependencies: + dep.join() + if not dep.success: + with self._lock: + self._stdout = f"Dependency {dep.name} failed. 
Skipping" + return + + current_dir = os.getcwd() + try: + self._start_time = datetime.now() + _c = " ".join(self.command) + logger.debug(f"[{self.name}] Started at {self._start_time} with command: \n {_c}") + + # change working directory to venv before executing script + os.chdir(self.execution_dir) + + self._process = subprocess.Popen( + self.command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + encoding="utf-8", + bufsize=1, + env=self.env, + ) + # set job status + self._status = constants.TRAINJOB_RUNNING + + while True: + if self._cancel_requested.is_set(): + self._process.terminate() + self._stdout += "[JobCancelled]\n" + self._status = constants.TRAINJOB_FAILED + self._success = False + return + + # Read output line by line (for streaming) + output_line = self._process.stdout.readline() + with self._lock: + if output_line: + self._stdout += output_line + self._output_updated.set() + + if not output_line and self._process.poll() is not None: + break + + self._process.stdout.close() + self._returncode = self._process.wait() + self._end_time = datetime.now() + self._success = self._process.returncode == 0 + msg = ( + f"[{self.name}] Completed with code {self._returncode}" + f" in {self._end_time - self._start_time} seconds." 
+ ) + # set status based on success or failure + self._status = ( + constants.TRAINJOB_COMPLETE if self._success else (constants.TRAINJOB_FAILED) + ) + self._stdout += msg + logger.debug("Job output: ", self._stdout) + + except Exception as e: + with self._lock: + self._stdout += f"Exception: {e}\n" + self._success = False + self._status = constants.TRAINJOB_FAILED + finally: + os.chdir(current_dir) + + @property + def stdout(self): + with self._lock: + return self._stdout + + @property + def success(self): + return self._success + + @property + def status(self): + return self._status + + def cancel(self): + self._cancel_requested.set() + + @property + def returncode(self): + return self._returncode + + def logs(self, follow=False) -> List[str]: + if not follow: + return self._stdout.splitlines() + + try: + for chunk in self.stream_logs(): + print(chunk, end="", flush=True) # stream to console live + except StopIteration: + pass + + return self._stdout.splitlines() + + def stream_logs(self): + """Generator that yields new output lines as they come in.""" + last_index = 0 + while self.is_alive() or last_index < len(self._stdout): + self._output_updated.wait(timeout=1) + with self._lock: + data = self._stdout + new_data = data[last_index:] + last_index = len(data) + self._output_updated.clear() + if new_data: + yield new_data + + @property + def creation_time(self): + return self._start_time + + @property + def completion_time(self): + return self._end_time diff --git a/kubeflow/trainer/backends/localprocess/types.py b/kubeflow/trainer/backends/localprocess/types.py new file mode 100644 index 00000000..d2ee0145 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/types.py @@ -0,0 +1,51 @@ +# Copyright 2025 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from dataclasses import dataclass, field +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel + +from kubeflow.trainer.backends.localprocess.job import LocalJob +from kubeflow.trainer.types import types + + +class LocalProcessBackendConfig(BaseModel): + cleanup_venv: bool = True + + +@dataclass +class LocalRuntimeTrainer(types.RuntimeTrainer): + packages: List[str] = field(default_factory=list) + + +class LocalBackendStep(BaseModel): + step_name: str + job: LocalJob + + class Config: + arbitrary_types_allowed = True + + +class LocalBackendJobs(BaseModel): + steps: Optional[List[LocalBackendStep]] = [] + runtime: Optional[types.Runtime] = None + name: str + created: typing.Optional[datetime] = None + completed: typing.Optional[datetime] = None + + class Config: + arbitrary_types_allowed = True diff --git a/kubeflow/trainer/backends/localprocess/utils.py b/kubeflow/trainer/backends/localprocess/utils.py new file mode 100644 index 00000000..c86e6f81 --- /dev/null +++ b/kubeflow/trainer/backends/localprocess/utils.py @@ -0,0 +1,298 @@ +import inspect +import os +import shutil +import re +import textwrap +from pathlib import Path +from string import Template +from typing import List, Callable, Optional, Dict, Any, Tuple, Set + +from kubeflow.trainer.backends.localprocess import constants as local_exec_constants +from kubeflow.trainer.constants import constants +from kubeflow.trainer.types import types +from kubeflow.trainer.backends.localprocess.types import LocalRuntimeTrainer + + +def 
# Leading distribution name of a PEP 508 requirement. Compiled locally:
# this pattern is consumed only by _extract_name below.
_PACKAGE_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def _extract_name(requirement: str) -> str:
    """Extract the base distribution name from a requirement string.

    Supports common PEP 508 patterns without external dependencies:
    'package', 'package[extra1,extra2]', 'package==1.2.3', 'package @ url';
    environment markers after ';' are irrelevant for name extraction.

    Returns the *raw* (un-normalized) name as it appears.

    Raises:
        ValueError: If the requirement is None, empty, or unparsable.
    """
    if requirement is None:
        raise ValueError("Requirement string cannot be None")
    stripped = requirement.strip()
    if not stripped:
        raise ValueError("Empty requirement string")

    match = _PACKAGE_NAME_RE.match(stripped)
    if not match:
        raise ValueError(f"Could not parse package name from requirement: {requirement!r}")
    return match.group(1)


def _canonicalize_name(name: str) -> str:
    """PEP 503-style normalization: lowercase, runs of -, _, . become '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def get_install_packages(
    runtime_packages: List[str],
    trainer_packages: Optional[List[str]] = None,
) -> List[str]:
    """Merge runtime and trainer requirement lists into one list.

    Rules:
    1) A trainer entry overwrites a runtime entry with the same canonical
       name; the trainer string is kept verbatim.
    2) Names are matched case-insensitively (PEP 503 normalization).
    3) Duplicate names within trainer_packages raise ValueError.
    4) Duplicate names within runtime_packages are allowed; the last
       occurrence wins.
    5) Ordering: runtime-only packages first (original order, last
       occurrence emitted), then all trainer packages in original order.

    Raises:
        ValueError: On duplicate trainer entries or unparsable requirements.
    """
    if not trainer_packages:
        return runtime_packages

    # Parse runtime entries; remember the last index per canonical name so
    # that later runtime duplicates silently win over earlier ones.
    runtime_parsed: List[Tuple[str, str]] = []  # (original string, canonical name)
    last_runtime_index: Dict[str, int] = {}
    for idx, original in enumerate(runtime_packages):
        canon = _canonicalize_name(_extract_name(original))
        runtime_parsed.append((original, canon))
        last_runtime_index[canon] = idx

    # Parse trainer entries, rejecting duplicates outright.
    trainer_parsed: List[Tuple[str, str]] = []
    seen_trainer: Set[str] = set()
    for original in trainer_packages:
        raw_name = _extract_name(original)
        canon = _canonicalize_name(raw_name)
        if canon in seen_trainer:
            raise ValueError(
                f"Duplicate dependency in trainer_packages: '{raw_name}' (canonical: '{canon}')"
            )
        seen_trainer.add(canon)
        trainer_parsed.append((original, canon))

    trainer_names: Set[str] = {canon for _, canon in trainer_parsed}

    merged: List[str] = []

    # 1) Runtime-only packages: emit each name once, at its last occurrence.
    emitted: Set[str] = set()
    for idx, (original, canon) in enumerate(runtime_parsed):
        if canon in trainer_names:
            continue  # overwritten by a trainer entry
        if last_runtime_index[canon] == idx and canon not in emitted:
            merged.append(original)
            emitted.add(canon)

    # 2) Trainer packages, verbatim and in their original order.
    merged.extend(original for original, _ in trainer_parsed)

    return merged
+ """ + local_runtime = next( + (rt for rt in local_exec_constants.local_runtimes if rt.name == runtime_name), None + ) + if not local_runtime: + raise ValueError(f"Runtime {runtime_name} not found") + + trainer = LocalRuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework=framework, + packages=local_runtime.trainer.packages, + ) + + # set command to run from venv + venv_bin_dir = str(Path(venv_dir) / "bin") + default_cmd = [str(Path(venv_bin_dir) / local_exec_constants.DEFAULT_COMMAND)] + # Set the Trainer entrypoint. + if framework == local_exec_constants.TORCH_FRAMEWORK_TYPE: + _c = [os.path.join(venv_bin_dir, local_exec_constants.TORCH_COMMAND)] + trainer.set_command(tuple(_c)) + else: + trainer.set_command(tuple(default_cmd)) + + return trainer + + +def get_dependencies_command( + runtime_packages: List[str], + pip_index_urls: str, + trainer_packages: List[str], + quiet: bool = True, +) -> str: + # resolve runtime dependencies and trainer dependencies. + packages = get_install_packages( + runtime_packages=runtime_packages, + trainer_packages=trainer_packages, + ) + + options = [f"--index-url {pip_index_urls[0]}"] + options.extend(f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:]) + + """ + PIP_DISABLE_PIP_VERSION_CHECK=1 pip install $QUIET $AS_USER \ + --no-warn-script-location $PIP_INDEX $PACKAGE_STR + """ + mapping = { + "QUIET": "--quiet" if quiet else "", + "PIP_INDEX": " ".join(options), + "PACKAGE_STR": '"{}"'.format('" "'.join(packages)), # quote deps + } + t = Template(local_exec_constants.DEPENDENCIES_SCRIPT) + result = t.substitute(**mapping) + return result + + +def get_command_using_train_func( + runtime: types.Runtime, + train_func: Callable, + train_func_parameters: Optional[Dict[str, Any]], + venv_dir: str, + train_job_name: str, +) -> str: + """ + Get the Trainer container command from the given training function and parameters. + """ + # Check if the runtime has a Trainer. 
+ if not runtime.trainer: + raise ValueError(f"Runtime must have a trainer: {runtime}") + + # Check if training function is callable. + if not callable(train_func): + raise ValueError( + f"Training function must be callable, got function type: {type(train_func)}" + ) + + # Extract the function implementation. + func_code = inspect.getsource(train_func) + + # Extract the file name where the function is defined and move it the venv directory. + func_file = Path(venv_dir) / local_exec_constants.LOCAL_EXEC_FILENAME.format(train_job_name) + + # Function might be defined in some indented scope (e.g. in another function). + # We need to dedent the function code. + func_code = textwrap.dedent(func_code) + + # Wrap function code to execute it from the file. For example: + # TODO (andreyvelich): Find a better way to run users' scripts. + # def train(parameters): + # print('Start Training...') + # train({'lr': 0.01}) + if train_func_parameters is None: + func_code = f"{func_code}\n{train_func.__name__}()\n" + else: + func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n" + + with open(func_file, "w") as f: + f.write(func_code) + f.close() + + t = Template(local_exec_constants.LOCAL_EXEC_ENTRYPOINT) + mapping = { + "PARAMETERS": "", ## Torch Parameters if any + "PYENV_LOCATION": venv_dir, + "ENTRYPOINT": " ".join(runtime.trainer.command), + "FUNC_FILE": func_file, + } + entrypoint = t.safe_substitute(**mapping) + + return entrypoint + + +def get_cleanup_venv_script(venv_dir: str, cleanup_venv: bool = True) -> str: + script = "\n" + if not cleanup_venv: + return script + + t = Template(local_exec_constants.LOCAL_EXEC_JOB_CLEANUP_SCRIPT) + mapping = { + "PYENV_LOCATION": venv_dir, + } + return t.substitute(**mapping) + + +def get_local_train_job_script( + train_job_name: str, + venv_dir: str, + trainer: types.CustomTrainer, + runtime: types.Runtime, + cleanup_venv: bool = True, +) -> tuple: + # use local-exec train job template + t = 
Template(local_exec_constants.LOCAL_EXEC_JOB_TEMPLATE) + # find os python binary to create venv + python_bin = shutil.which("python") + if not python_bin: + python_bin = shutil.which("python3") + if not python_bin: + raise ValueError("No python executable found") + + # workout if dependencies needs to be installed + if isinstance(runtime.trainer, LocalRuntimeTrainer): + runtime_trainer: LocalRuntimeTrainer = runtime.trainer + else: + raise ValueError("Invalid Runtime Trainer type: {type(runtime.trainer)}") + dependency_script = "\n" + if trainer.packages_to_install: + dependency_script = get_dependencies_command( + pip_index_urls=trainer.pip_index_urls + if trainer.pip_index_urls + else constants.DEFAULT_PIP_INDEX_URLS, + runtime_packages=runtime_trainer.packages, + trainer_packages=trainer.packages_to_install, + quiet=False, + ) + + entrypoint = get_command_using_train_func( + venv_dir=venv_dir, + runtime=runtime, + train_func=trainer.func, + train_func_parameters=trainer.func_args, + train_job_name=train_job_name, + ) + + cleanup_script = get_cleanup_venv_script(cleanup_venv=cleanup_venv, venv_dir=venv_dir) + + mapping = { + "OS_PYTHON_BIN": python_bin, + "PYENV_LOCATION": venv_dir, + "DEPENDENCIES_SCRIPT": dependency_script, + "ENTRYPOINT": entrypoint, + "CLEANUP_SCRIPT": cleanup_script, + } + + command = t.safe_substitute(**mapping) + + return "bash", "-c", command