diff --git a/python/kubeflow/trainer/__init__.py b/python/kubeflow/trainer/__init__.py index b87715f7..bc780d1b 100644 --- a/python/kubeflow/trainer/__init__.py +++ b/python/kubeflow/trainer/__init__.py @@ -29,7 +29,6 @@ CustomTrainer, DataFormat, DataType, - Framework, HuggingFaceDatasetInitializer, HuggingFaceModelInitializer, Initializer, @@ -47,7 +46,6 @@ "DataFormat", "DATASET_PATH", "DataType", - "Framework", "HuggingFaceDatasetInitializer", "HuggingFaceModelInitializer", "Initializer", diff --git a/python/kubeflow/trainer/api/trainer_client.py b/python/kubeflow/trainer/api/trainer_client.py index 6db05833..d3bad21c 100644 --- a/python/kubeflow/trainer/api/trainer_client.py +++ b/python/kubeflow/trainer/api/trainer_client.py @@ -105,6 +105,16 @@ def list_runtimes(self) -> List[types.Runtime]: return result for runtime in runtime_list.items: + if not ( + runtime.metadata + and runtime.metadata.labels + and constants.RUNTIME_FRAMEWORK_LABEL in runtime.metadata.labels + ): + logger.warning( + f"Runtime {runtime.metadata.name} must have " # type: ignore + f"{constants.RUNTIME_FRAMEWORK_LABEL} label." + ) + continue result.append(self.__get_runtime_from_crd(runtime)) except multiprocessing.TimeoutError: @@ -151,7 +161,7 @@ def get_runtime(self, name: str) -> types.Runtime: def train( self, - runtime: types.Runtime = types.DEFAULT_RUNTIME, + runtime: Optional[types.Runtime] = None, initializer: Optional[types.Initializer] = None, trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, ) -> str: @@ -164,7 +174,8 @@ def train( the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`. Args: - runtime (`types.Runtime`): Reference to one of existing Runtimes. + runtime (`Optional[types.Runtime]`): Reference to one of the existing Runtimes. Defaults to the + torch-distributed Runtime. initializer (`Optional[types.Initializer]`): Configuration for the dataset and model initializers. trainer (`Optional[types.CustomTrainer, types.BuiltinTrainer]`): @@ -179,6 +190,9 @@ def train( RuntimeError: Failed to create TrainJobs. """ + if runtime is None: + runtime = self.get_runtime(constants.TORCH_RUNTIME) + # Generate unique name for the TrainJob. # TODO (andreyvelich): Discuss this TrainJob name generation. train_job_name = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11] @@ -189,14 +203,22 @@ def train( if trainer: # If users choose to use a custom training function. if isinstance(trainer, types.CustomTrainer): + if runtime.trainer.trainer_type != types.TrainerType.CUSTOM_TRAINER: + raise ValueError( + f"CustomTrainer can't be used with {runtime} runtime" + ) trainer_crd = utils.get_trainer_crd_from_custom_trainer( - trainer, runtime + runtime, trainer ) # If users choose to use a builtin trainer for post-training. 
elif isinstance(trainer, types.BuiltinTrainer): + if runtime.trainer.trainer_type != types.TrainerType.BUILTIN_TRAINER: + raise ValueError( + f"BuiltinTrainer can't be used with {runtime} runtime" + ) trainer_crd = utils.get_trainer_crd_from_builtin_trainer( - trainer, initializer + runtime, trainer, initializer ) else: @@ -549,9 +571,19 @@ def __get_runtime_from_crd( ): raise Exception(f"ClusterTrainingRuntime CRD is invalid: {runtime_crd}") + if not ( + runtime_crd.metadata.labels + and constants.RUNTIME_FRAMEWORK_LABEL in runtime_crd.metadata.labels + ): + raise Exception( + f"Runtime {runtime_crd.metadata.name} must have " + f"{constants.RUNTIME_FRAMEWORK_LABEL} label" + ) + return types.Runtime( name=runtime_crd.metadata.name, trainer=utils.get_runtime_trainer( + runtime_crd.metadata.labels[constants.RUNTIME_FRAMEWORK_LABEL], runtime_crd.spec.template.spec.replicated_jobs, runtime_crd.spec.ml_policy, ), diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index da1609cd..1b16b61c 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -50,26 +50,18 @@ class TestCase: # -------------------------- TIMEOUT = "timeout" RUNTIME = "runtime" -INVALID_RUNTIME = "invalid_runtime" SUCCESS = "success" FAILED = "Failed" -CREATED = "Created" -RUNNING = "Running" -RESTARTING = "Restarting" -NO_PODS = "no_pods" -SUCCEEDED = "Succeeded" -INVALID = "invalid" DEFAULT_NAMESPACE = "default" -PYTORCH = "pytorch" -MOCK_POD_OBJ = "mock_pod_obj" +# In all tests runtime name is equal to the framework name. +TORCH_RUNTIME = "torch" +TORCH_TUNE_RUNTIME = "torchtune" FAIL_LOGS = "fail_logs" -TORCH_DISTRIBUTED = "torch-distributed" LIST_RUNTIMES = "list_runtimes" BASIC_TRAIN_JOB_NAME = "basic-job" TRAIN_JOBS = "trainjobs" TRAIN_JOB_WITH_BUILT_IN_TRAINER = "train-job-with-built-in-trainer" TRAIN_JOB_WITH_CUSTOM_TRAINER = "train-job-with-custom-trainer" -TRAIN_JOB_WITH_CUSTOM_TRAINER_ENV = "train-job-with-custom-trainer-env" # -------------------------- @@ -78,7 +70,7 @@ class TestCase: @pytest.fixture -def training_client(request): +def trainer_client(request): """Provide a TrainerClient with mocked Kubernetes APIs.""" with patch("kubernetes.config.load_kube_config", return_value=None), patch( "kubernetes.client.CustomObjectsApi", @@ -223,15 +215,16 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements: def get_custom_trainer( - env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None, + env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None, ) -> models.TrainerV1alpha1Trainer: """ Get the custom trainer for the TrainJob. """ return models.TrainerV1alpha1Trainer( - command=["bash", "-c"], - args=[ + command=[ + "bash", + "-c", '\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip ' "|| python -m ensurepip --user || apt-get install python-pip" "\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet" @@ -239,12 +232,13 @@ def get_custom_trainer( "torch numpy \n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: " 'print("Hello World"),\n\n(' "{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" " - '"$SCRIPT" > "trainer_client_test.py"\ntorchrun "trainer_client_test.py"' + '"$SCRIPT" > "trainer_client_test.py"\ntorchrun "trainer_client_test.py"', ], numNodes=2, env=env, ) + def get_builtin_trainer() -> models.TrainerV1alpha1Trainer: """ Get the builtin trainer for the TrainJob. 
@@ -257,6 +251,7 @@ def get_builtin_trainer() -> models.TrainerV1alpha1Trainer: def get_train_job( + runtime_name: str, train_job_name: str = BASIC_TRAIN_JOB_NAME, train_job_trainer: Optional[models.TrainerV1alpha1Trainer] = None, ) -> models.TrainerV1alpha1TrainJob: @@ -268,7 +263,7 @@ def get_train_job( kind=constants.TRAINJOB_KIND, metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta(name=train_job_name), spec=models.TrainerV1alpha1TrainJobSpec( - runtimeRef=models.TrainerV1alpha1RuntimeRef(name=TORCH_DISTRIBUTED), + runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime_name), trainer=train_job_trainer, ), ) @@ -285,7 +280,7 @@ def get_cluster_custom_object_response(*args, **kwargs): raise RuntimeError() if args[2] == constants.CLUSTER_TRAINING_RUNTIME_PLURAL: mock_thread.get.return_value = normalize_model( - create_cluster_training_runtime(), + create_cluster_training_runtime(name=args[3]), models.TrainerV1alpha1ClusterTrainingRuntime, ) @@ -417,7 +412,6 @@ def normalize_model(model_obj, model_class): def create_train_job( train_job_name: str = random.choice(string.ascii_lowercase) + uuid.uuid4().hex[:11], namespace: str = "default", - runtime: str = PYTORCH, image: str = "pytorch/pytorch:latest", initializer: Optional[types.Initializer] = None, command: Optional[list] = None, @@ -433,7 +427,7 @@ def create_train_job( creationTimestamp=datetime.datetime(2025, 6, 1, 10, 30, 0), ), spec=models.TrainerV1alpha1TrainJobSpec( - runtimeRef=models.TrainerV1alpha1RuntimeRef(name=runtime), + runtimeRef=models.TrainerV1alpha1RuntimeRef(name=TORCH_RUNTIME), trainer=None, initializer=( models.TrainerV1alpha1Initializer( @@ -448,16 +442,18 @@ def create_train_job( def create_cluster_training_runtime( + name: str, namespace: str = "default", - name: str = TORCH_DISTRIBUTED, ) -> models.TrainerV1alpha1ClusterTrainingRuntime: """Create a mock ClusterTrainingRuntime object.""" - runtime = models.TrainerV1alpha1ClusterTrainingRuntime( + + return models.TrainerV1alpha1ClusterTrainingRuntime( apiVersion=constants.API_VERSION, kind="ClusterTrainingRuntime", metadata=models.IoK8sApimachineryPkgApisMetaV1ObjectMeta( name=name, namespace=namespace, + labels={constants.RUNTIME_FRAMEWORK_LABEL: name}, ), spec=models.TrainerV1alpha1TrainingRuntimeSpec( mlPolicy=models.TrainerV1alpha1MLPolicy( @@ -477,7 +473,6 @@ def create_cluster_training_runtime( ), ), ) - return runtime def get_replicated_job() -> models.JobsetV1alpha2ReplicatedJob: @@ -507,46 +502,43 @@ def get_container() -> models.IoK8sApiCoreV1Container: def create_runtime_type( - name: str = TORCH_DISTRIBUTED, + name: str, ) -> types.Runtime: """Create a mock Runtime object for testing.""" + trainer = types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework=name, + num_nodes=2, + accelerator_count=4, + ) + trainer.set_command(constants.TORCH_COMMAND) return types.Runtime( name=name, pretrained_model=None, - trainer=types.RuntimeTrainer( - trainer_type=types.TrainerType.CUSTOM_TRAINER, - framework=types.Framework.TORCH, - entrypoint=[constants.TORCH_ENTRYPOINT], - accelerator_count=4, - num_nodes=2, - ), + trainer=trainer, ) def get_train_job_data_type( - train_job_name: str = BASIC_TRAIN_JOB_NAME, + runtime_name: str, + train_job_name: str, ) -> types.TrainJob: - """Create a mock TrainJob object with the expected structure for testing. 
- - Args: - train_job_name: Name of the training job + """Create a mock TrainJob object with the expected structure for testing.""" - Returns: - A TrainJob object with predefined structure for testing - """ + trainer = types.RuntimeTrainer( + trainer_type=types.TrainerType.CUSTOM_TRAINER, + framework=runtime_name, + accelerator_count=4, + num_nodes=2, + ) + trainer.set_command(constants.TORCH_COMMAND) return types.TrainJob( name=train_job_name, creation_timestamp=datetime.datetime(2025, 6, 1, 10, 30, 0), runtime=types.Runtime( - name=TORCH_DISTRIBUTED, + name=runtime_name, pretrained_model=None, - trainer=types.RuntimeTrainer( - trainer_type=types.TrainerType.CUSTOM_TRAINER, - framework=types.Framework.TORCH, - entrypoint=["torchrun"], - accelerator_count=4, - num_nodes=2, - ), + trainer=trainer, ), steps=[ types.Step( @@ -587,8 +579,8 @@ def get_train_job_data_type( TestCase( name="valid flow with all defaults", expected_status=SUCCESS, - config={}, - expected_output=create_runtime_type(), + config={"name": TORCH_RUNTIME}, + expected_output=create_runtime_type(name=TORCH_RUNTIME), ), TestCase( name="timeout error when getting runtime", @@ -604,13 +596,11 @@ def get_train_job_data_type( ), ], ) -def test_get_runtime(training_client, test_case): +def test_get_runtime(trainer_client, test_case): """Test TrainerClient.get_runtime with basic success path.""" print("Executing test:", test_case.name) try: - runtime = training_client.get_runtime( - test_case.config.get("name", TORCH_DISTRIBUTED) - ) + runtime = trainer_client.get_runtime(**test_case.config) assert test_case.expected_status == SUCCESS assert isinstance(runtime, types.Runtime) @@ -635,12 +625,12 @@ def test_get_runtime(training_client, test_case): ), ], ) -def test_list_runtimes(training_client, test_case): +def test_list_runtimes(trainer_client, test_case): """Test TrainerClient.list_runtimes with basic success path.""" print("Executing test:", test_case.name) try: - training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - runtimes = training_client.list_runtimes() + trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + runtimes = trainer_client.list_runtimes() assert test_case.expected_status == SUCCESS assert isinstance(runtimes, list) @@ -661,7 +651,10 @@ def test_list_runtimes(training_client, test_case): name="valid flow with all defaults", expected_status=SUCCESS, config={}, - expected_output=get_train_job(train_job_name=BASIC_TRAIN_JOB_NAME), + expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, + ), ), TestCase( name="valid flow with built in trainer", @@ -674,9 +667,11 @@ def test_list_runtimes(training_client, test_case): epochs=2, loss=types.Loss.CEWithChunkedOutputLoss, ) - ) + ), + "runtime": TORCH_TUNE_RUNTIME, }, expected_output=get_train_job( + runtime_name=TORCH_TUNE_RUNTIME, train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER, train_job_trainer=get_builtin_trainer(), ), @@ -694,6 +689,7 @@ def test_list_runtimes(training_client, test_case): ) }, expected_output=get_train_job( + runtime_name=TORCH_RUNTIME, train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, train_job_trainer=get_custom_trainer(), ), @@ -715,11 +711,16 @@ def test_list_runtimes(training_client, test_case): ) }, expected_output=get_train_job( - train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER_ENV, + runtime_name=TORCH_RUNTIME, + train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER, train_job_trainer=get_custom_trainer( - env = [ - 
models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"), - models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"), + env=[ + models.IoK8sApiCoreV1EnvVar( + name="TEST_ENV", value="test_value" + ), + models.IoK8sApiCoreV1EnvVar( + name="ANOTHER_ENV", value="another_value" + ), ], ), ), @@ -740,17 +741,32 @@ def test_list_runtimes(training_client, test_case): }, expected_error=RuntimeError, ), + TestCase( + name="value error when runtime doesn't support CustomTrainer", + expected_status=FAILED, + config={ + "trainer": types.CustomTrainer( + func=lambda: print("Hello World"), + num_nodes=2, + ), + "runtime": TORCH_TUNE_RUNTIME, + }, + expected_error=ValueError, + ), ], ) -def test_train(training_client, test_case): +def test_train(trainer_client, test_case): """Test TrainerClient.train with basic success path.""" print("Executing test:", test_case.name) try: - training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - test_case.config.pop( - "namespace", None - ) # None is the default value if key doesn't exist - train_job_name = training_client.train(**test_case.config) + trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + runtime = trainer_client.get_runtime( + test_case.config.get("runtime", TORCH_RUNTIME) + ) + + train_job_name = trainer_client.train( + runtime=runtime, trainer=test_case.config.get("trainer", None) + ) assert test_case.expected_status == SUCCESS @@ -759,7 +775,7 @@ def test_train(training_client, test_case): expected_output = test_case.expected_output expected_output.metadata.name = train_job_name - training_client.custom_api.create_namespaced_custom_object.assert_called_with( + trainer_client.custom_api.create_namespaced_custom_object.assert_called_with( constants.GROUP, constants.VERSION, DEFAULT_NAMESPACE, @@ -780,7 +796,8 @@ def test_train(training_client, test_case): expected_status=SUCCESS, config={"name": BASIC_TRAIN_JOB_NAME}, expected_output=get_train_job_data_type( - train_job_name=BASIC_TRAIN_JOB_NAME + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, ), ), TestCase( @@ -797,11 +814,11 @@ def test_train(training_client, test_case): ), ], ) -def test_get_job(training_client, test_case): +def test_get_job(trainer_client, test_case): """Test TrainerClient.get_job with basic success path.""" print("Executing test:", test_case.name) try: - job = training_client.get_job(**test_case.config) + job = trainer_client.get_job(**test_case.config) assert test_case.expected_status == SUCCESS assert asdict(job) == asdict(test_case.expected_output) @@ -819,8 +836,14 @@ def test_get_job(training_client, test_case): expected_status=SUCCESS, config={}, expected_output=[ - get_train_job_data_type(train_job_name="basic-job-1"), - get_train_job_data_type(train_job_name="basic-job-2"), + get_train_job_data_type( + runtime_name=TORCH_RUNTIME, + train_job_name="basic-job-1", + ), + get_train_job_data_type( + runtime_name=TORCH_RUNTIME, + train_job_name="basic-job-2", + ), ], ), TestCase( @@ -837,12 +860,12 @@ def test_get_job(training_client, test_case): ), ], ) -def test_list_jobs(training_client, test_case): +def test_list_jobs(trainer_client, test_case): """Test TrainerClient.list_jobs with basic success path.""" print("Executing test:", test_case.name) try: - training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - jobs = training_client.list_jobs() + trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + jobs = 
trainer_client.list_jobs() assert test_case.expected_status == SUCCESS assert isinstance(jobs, list) @@ -875,11 +898,11 @@ def test_list_jobs(training_client, test_case): ), ], ) -def test_get_job_logs(training_client, test_case): +def test_get_job_logs(trainer_client, test_case): """Test TrainerClient.get_job_logs with basic success path.""" print("Executing test:", test_case.name) try: - logs = training_client.get_job_logs(test_case.config.get("name")) + logs = trainer_client.get_job_logs(test_case.config.get("name")) assert test_case.expected_status == SUCCESS assert logs == test_case.expected_output @@ -896,7 +919,8 @@ def test_get_job_logs(training_client, test_case): expected_status=SUCCESS, config={"name": BASIC_TRAIN_JOB_NAME}, expected_output=get_train_job_data_type( - train_job_name=BASIC_TRAIN_JOB_NAME + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, ), ), TestCase( @@ -907,7 +931,8 @@ def test_get_job_logs(training_client, test_case): "status": {constants.TRAINJOB_RUNNING, constants.TRAINJOB_COMPLETE}, }, expected_output=get_train_job_data_type( - train_job_name=BASIC_TRAIN_JOB_NAME + runtime_name=TORCH_RUNTIME, + train_job_name=BASIC_TRAIN_JOB_NAME, ), ), TestCase( @@ -940,11 +965,11 @@ def test_get_job_logs(training_client, test_case): ), ], ) -def test_wait_for_job_status(training_client, test_case): +def test_wait_for_job_status(trainer_client, test_case): """Test TrainerClient.wait_for_job_status with various scenarios.""" print("Executing test:", test_case.name) - original_get_job = training_client.get_job + original_get_job = trainer_client.get_job # TrainJob has unexpected failed status. def mock_get_job(name): @@ -953,10 +978,10 @@ def mock_get_job(name): job.status = constants.TRAINJOB_FAILED return job - training_client.get_job = mock_get_job + trainer_client.get_job = mock_get_job try: - job = training_client.wait_for_job_status(**test_case.config) + job = trainer_client.wait_for_job_status(**test_case.config) assert test_case.expected_status == SUCCESS assert isinstance(job, types.TrainJob) @@ -994,15 +1019,15 @@ def mock_get_job(name): ), ], ) -def test_delete_job(training_client, test_case): +def test_delete_job(trainer_client, test_case): """Test TrainerClient.delete_job with basic success path.""" print("Executing test:", test_case.name) try: - training_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) - training_client.delete_job(test_case.config.get("name")) + trainer_client.namespace = test_case.config.get("namespace", DEFAULT_NAMESPACE) + trainer_client.delete_job(test_case.config.get("name")) assert test_case.expected_status == SUCCESS - training_client.custom_api.delete_namespaced_custom_object.assert_called_with( + trainer_client.custom_api.delete_namespaced_custom_object.assert_called_with( constants.GROUP, constants.VERSION, test_case.config.get("namespace", DEFAULT_NAMESPACE), diff --git a/python/kubeflow/trainer/constants/constants.py b/python/kubeflow/trainer/constants/constants.py index f4900ab1..f22d8f34 100644 --- a/python/kubeflow/trainer/constants/constants.py +++ b/python/kubeflow/trainer/constants/constants.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import textwrap # How long to wait in seconds for requests to the Kubernetes API Server. DEFAULT_TIMEOUT = 120 @@ -53,6 +54,9 @@ # For example, what PodTemplate must be overridden by TrainJob's .spec.trainer APIs. 
TRAINJOB_ANCESTOR_LABEL = "trainer.kubeflow.org/trainjob-ancestor-step" +# The label key to identify the ML framework that the runtime uses (e.g. torch, deepspeed, torchtune, etc.). +RUNTIME_FRAMEWORK_LABEL = "trainer.kubeflow.org/framework" + # The name of the ReplicatedJob and container of the dataset initializer. # Also, it represents the `trainjob-ancestor-step` label value for the dataset initializer step. DATASET_INITIALIZER = "dataset-initializer" @@ -120,27 +124,49 @@ # The default PIP index URL to download Python packages. DEFAULT_PIP_INDEX_URL = os.getenv("DEFAULT_PIP_INDEX_URL", "https://pypi.org/simple") -# The default command for the Custom Trainer. -DEFAULT_CUSTOM_COMMAND = ["bash", "-c"] - -# The default command for the TorchTune Trainer. -DEFAULT_TORCHTUNE_COMMAND = ["tune", "run"] - -# The default entrypoint for torchrun. -TORCH_ENTRYPOINT = "torchrun" +# The exec script to embed the training function into the container command. +# __ENTRYPOINT__ depends on the MLPolicy; func_code and func_file are substituted in the `train` API. +EXEC_FUNC_SCRIPT = textwrap.dedent( + """ + read -r -d '' SCRIPT << EOM\n + {func_code} + EOM + printf "%s" \"$SCRIPT\" > \"{func_file}\" + __ENTRYPOINT__ \"{func_file}\"""" ) -# The Torch env name for the number of procs per node (e.g. number of GPUs per Pod). -TORCH_ENV_NUM_PROC_PER_NODE = "PET_NPROC_PER_NODE" +# The default command for the PlainML CustomTrainer. +DEFAULT_COMMAND = ( + "bash", + "-c", + EXEC_FUNC_SCRIPT.replace("__ENTRYPOINT__", "python"), +) # The default home directory for the MPI user. DEFAULT_MPI_USER_HOME = os.getenv("DEFAULT_MPI_USER_HOME", "/home/mpiuser") -# The default location for MPI hostfile. -# TODO (andreyvelich): We should get this info from Runtime CRD. -MPI_HOSTFILE = "/etc/mpi/hostfile" +# The default command for the OpenMPI CustomTrainer. +MPI_COMMAND = ( + "mpirun", + "--hostfile", + "/etc/mpi/hostfile", + *DEFAULT_COMMAND, +) + +# The default name for the Torch runtime. +TORCH_RUNTIME = "torch-distributed" + +# The default container command for the Torch CustomTrainer. +TORCH_COMMAND = ( + "bash", + "-c", + EXEC_FUNC_SCRIPT.replace("__ENTRYPOINT__", "torchrun"), +) +# The Torch env name for the number of procs per node (e.g. number of GPUs per Pod). +TORCH_ENV_NUM_PROC_PER_NODE = "PET_NPROC_PER_NODE" -# The default entrypoint for mpirun. -MPI_ENTRYPOINT = "mpirun" +# The default command for the TorchTune BuiltinTrainer. +TORCH_TUNE_COMMAND = ("tune", "run") # The Instruct Datasets class in torchtune -TORCHTUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset" +TORCH_TUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset" diff --git a/python/kubeflow/trainer/types/types.py b/python/kubeflow/trainer/types/types.py index 39ed8793..bd493a89 100644 --- a/python/kubeflow/trainer/types/types.py +++ b/python/kubeflow/trainer/types/types.py @@ -13,10 +13,10 @@ # limitations under the License. 
-from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Union from kubeflow.trainer.constants import constants @@ -40,7 +40,7 @@ class CustomTrainer: func: Callable func_args: Optional[Dict] = None - packages_to_install: Optional[List[str]] = None + packages_to_install: Optional[list[str]] = None pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL num_nodes: Optional[int] = None resources_per_node: Optional[Dict] = None @@ -151,26 +151,32 @@ class BuiltinTrainer: config: TorchTuneConfig +# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs. +TORCH_TUNE = ( + BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "") +) + + class TrainerType(Enum): CUSTOM_TRAINER = CustomTrainer.__name__ BUILTIN_TRAINER = BuiltinTrainer.__name__ -class Framework(Enum): - TORCH = "torch" - DEEPSPEED = "deepspeed" - MLX = "mlx" - TORCHTUNE = "torchtune" - - # Representation for the Trainer of the runtime. @dataclass class RuntimeTrainer: trainer_type: TrainerType - framework: Framework + framework: str num_nodes: int = 1 # The default value is set in the APIs. - entrypoint: Optional[List[str]] = None accelerator_count: Union[str, float, int] = constants.UNKNOWN + __command: tuple[str, ...] = field(init=False, repr=False) + + @property + def command(self) -> tuple[str, ...]: + return self.__command + + def set_command(self, command: tuple[str, ...]): + self.__command = command # Representation for the Training Runtime. @@ -198,7 +204,7 @@ class TrainJob: name: str creation_timestamp: datetime runtime: Runtime - steps: List[Step] + steps: list[Step] num_nodes: int status: Optional[str] = constants.UNKNOWN @@ -231,57 +237,3 @@ class Initializer: dataset: Optional[HuggingFaceDatasetInitializer] = None model: Optional[HuggingFaceModelInitializer] = None - - -# The dict where key is the container image and value its representation. -# Each Trainer representation defines trainer parameters (e.g. type, framework, entrypoint). -# TODO (andreyvelich): We should allow user to overrides the default image names. -ALL_TRAINERS: Dict[str, RuntimeTrainer] = { - # Custom Trainers. - "pytorch/pytorch": RuntimeTrainer( - trainer_type=TrainerType.CUSTOM_TRAINER, - framework=Framework.TORCH, - entrypoint=[constants.TORCH_ENTRYPOINT], - ), - "ghcr.io/kubeflow/trainer/mlx-runtime": RuntimeTrainer( - trainer_type=TrainerType.CUSTOM_TRAINER, - framework=Framework.MLX, - entrypoint=[ - constants.MPI_ENTRYPOINT, - "--hostfile", - constants.MPI_HOSTFILE, - "bash", - "-c", - ], - ), - "ghcr.io/kubeflow/trainer/deepspeed-runtime": RuntimeTrainer( - trainer_type=TrainerType.CUSTOM_TRAINER, - framework=Framework.DEEPSPEED, - entrypoint=[ - constants.MPI_ENTRYPOINT, - "--hostfile", - constants.MPI_HOSTFILE, - "bash", - "-c", - ], - ), - # Builtin Trainers. 
- "ghcr.io/kubeflow/trainer/torchtune-trainer": RuntimeTrainer( - trainer_type=TrainerType.BUILTIN_TRAINER, - framework=Framework.TORCHTUNE, - entrypoint=constants.DEFAULT_TORCHTUNE_COMMAND, - ), -} - -# The default trainer configuration when runtime detection fails -DEFAULT_TRAINER = RuntimeTrainer( - trainer_type=TrainerType.CUSTOM_TRAINER, - framework=Framework.TORCH, - entrypoint=[constants.TORCH_ENTRYPOINT], -) - -# The default runtime configuration for the train() API -DEFAULT_RUNTIME = Runtime( - name="torch-distributed", - trainer=DEFAULT_TRAINER, -) diff --git a/python/kubeflow/trainer/utils/utils.py b/python/kubeflow/trainer/utils/utils.py index 98eb777d..91573e65 100644 --- a/python/kubeflow/trainer/utils/utils.py +++ b/python/kubeflow/trainer/utils/utils.py @@ -17,7 +17,7 @@ import queue import textwrap import threading -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, Optional from urllib.parse import urlparse from kubeflow.trainer.constants import constants @@ -49,7 +49,7 @@ def get_default_target_namespace(context: Optional[str] = None) -> str: def get_container_devices( resources: Optional[models.IoK8sApiCoreV1ResourceRequirements], -) -> Optional[Tuple[str, str]]: +) -> Optional[tuple[str, str]]: """ Get the device type and device count for the given container. """ @@ -82,7 +82,7 @@ def get_container_devices( def get_runtime_trainer_container( - replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob], + replicated_jobs: list[models.JobsetV1alpha2ReplicatedJob], ) -> Optional[models.IoK8sApiCoreV1Container]: """ Get the runtime node container from the given replicated jobs. @@ -107,11 +107,12 @@ def get_runtime_trainer_container( def get_runtime_trainer( - replicated_jobs: List[models.JobsetV1alpha2ReplicatedJob], + framework: str, + replicated_jobs: list[models.JobsetV1alpha2ReplicatedJob], ml_policy: models.TrainerV1alpha1MLPolicy, ) -> types.RuntimeTrainer: """ - Get the runtime trainer object. + Get the RuntimeTrainer object. """ trainer_container = get_runtime_trainer_container(replicated_jobs) @@ -119,9 +120,14 @@ def get_runtime_trainer( if not (trainer_container and trainer_container.image): raise Exception(f"Runtime doesn't have trainer container {replicated_jobs}") - # Extract image name from the container image to get appropriate Trainer. - image_name = trainer_container.image.split(":")[0] - trainer = types.ALL_TRAINERS.get(image_name, types.DEFAULT_TRAINER) + trainer = types.RuntimeTrainer( + trainer_type=( + types.TrainerType.BUILTIN_TRAINER + if framework == types.TORCH_TUNE + else types.TrainerType.CUSTOM_TRAINER + ), + framework=framework, + ) # Get the container devices. if devices := get_container_devices(trainer_container.resources): @@ -143,6 +149,16 @@ def get_runtime_trainer( if ml_policy.num_nodes: trainer.num_nodes = ml_policy.num_nodes + # Set the Trainer entrypoint. + if framework == types.TORCH_TUNE: + trainer.set_command(constants.TORCH_TUNE_COMMAND) + elif ml_policy.torch: + trainer.set_command(constants.TORCH_COMMAND) + elif ml_policy.mpi: + trainer.set_command(constants.MPI_COMMAND) + else: + trainer.set_command(constants.DEFAULT_COMMAND) + return trainer @@ -199,8 +215,7 @@ def get_trainjob_node_step( # For the MPI use-cases, the launcher container is always node-0 # Thus, we should increase the index for other nodes. 
if ( - trainjob_runtime.trainer.entrypoint - and trainjob_runtime.trainer.entrypoint[0] == constants.MPI_ENTRYPOINT + trainjob_runtime.trainer.command[0] == "mpirun" and replicated_job_name != constants.LAUNCHER ): # TODO (andreyvelich): We should also override the device_count @@ -240,17 +255,47 @@ def get_resources_per_node( return resources -def get_entrypoint_using_train_func( +def get_script_for_python_packages( + packages_to_install: list[str], + pip_index_url: str, + is_mpi: bool, +) -> str: + """ + Get init script to install Python packages from the given pip index URL. + """ + # packages_str = " ".join([str(package) for package in packages_to_install]) + packages_str = " ".join(packages_to_install) + + script_for_python_packages = textwrap.dedent( + """ + if ! [ -x "$(command -v pip)" ]; then + python -m ensurepip || python -m ensurepip --user || apt-get install python-pip + fi + + PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \ + --no-warn-script-location --index-url {} {} {} + """.format( + pip_index_url, + packages_str, + # For the OpenMPI, the packages must be installed for the mpiuser. + "--user" if is_mpi else "", + ) + ) + + return script_for_python_packages + + +def get_command_using_train_func( runtime: types.Runtime, train_func: Callable, train_func_parameters: Optional[Dict[str, Any]], pip_index_url: str, - packages_to_install: Optional[List[str]] = None, -) -> Tuple[List[str], List[str]]: + packages_to_install: Optional[list[str]] = None, +) -> list[str]: """ - Get the Trainer command and args from the given training function and parameters. + Get the Trainer container command from the given training function and parameters. """ - # Check if the runtime has a trainer. + # Check if the runtime has a Trainer. if not runtime.trainer: raise ValueError(f"Runtime must have a trainer: {runtime}") @@ -280,112 +325,37 @@ def get_entrypoint_using_train_func( else: func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n" - # Prepare the template to execute script. - # Currently, we override the file where the training function is defined. - # That allows to execute the training script with the entrypoint. - if runtime.trainer.entrypoint is None: - raise Exception(f"Runtime trainer must have an entrypoint: {runtime.trainer}") - - # We don't allow to override python entrypoint for `mpirun` - if runtime.trainer.entrypoint[0] == constants.MPI_ENTRYPOINT: - container_command = runtime.trainer.entrypoint - python_entrypoint = "python" - # The default file location is: /home/mpiuser/.py + is_mpi = runtime.trainer.command[0] == "mpirun" + # The default file location for OpenMPI is: /home/mpiuser/.py + if is_mpi: func_file = os.path.join(constants.DEFAULT_MPI_USER_HOME, func_file) - else: - container_command = constants.DEFAULT_CUSTOM_COMMAND - python_entrypoint = " ".join(runtime.trainer.entrypoint) - - exec_script = textwrap.dedent( - """ - read -r -d '' SCRIPT << EOM\n - {func_code} - EOM - printf "%s" \"$SCRIPT\" > \"{func_file}\" - {python_entrypoint} \"{func_file}\"""" - ) - - # Add function code to the execute script. - exec_script = exec_script.format( - func_code=func_code, - func_file=func_file, - python_entrypoint=python_entrypoint, - ) # Install Python packages if that is required. - if packages_to_install is not None: - exec_script = ( - get_script_for_python_packages( - packages_to_install, pip_index_url, runtime.trainer.entrypoint - ) - + exec_script - ) - - # Return container command and args to execute training function. 
- return container_command, [exec_script] - - -def get_args_using_torchtune_config( - fine_tuning_config: types.TorchTuneConfig, - initializer: Optional[types.Initializer] = None, -) -> Tuple[List[str], List[str]]: - """ - Get the Trainer args from the TorchTuneConfig. - """ - args = [] - - # Override the dtype if it is provided. - if fine_tuning_config.dtype: - if not isinstance(fine_tuning_config.dtype, types.DataType): - raise ValueError(f"Invalid dtype: {fine_tuning_config.dtype}.") - - args.append(f"dtype={fine_tuning_config.dtype}") - - # Override the batch size if it is provided. - if fine_tuning_config.batch_size: - args.append(f"batch_size={fine_tuning_config.batch_size}") - - # Override the epochs if it is provided. - if fine_tuning_config.epochs: - args.append(f"epochs={fine_tuning_config.epochs}") - - # Override the loss if it is provided. - if fine_tuning_config.loss: - args.append(f"loss={fine_tuning_config.loss}") - - # Override the data dir or data files if it is provided. - if isinstance(initializer, types.Initializer) and isinstance( - initializer.dataset, types.HuggingFaceDatasetInitializer - ): - storage_uri = ( - "hf://" + initializer.dataset.storage_uri - if not initializer.dataset.storage_uri.startswith("hf://") - else initializer.dataset.storage_uri + install_packages = "" + if packages_to_install: + install_packages = get_script_for_python_packages( + packages_to_install, + pip_index_url, + is_mpi, ) - storage_uri_parsed = urlparse(storage_uri) - parts = storage_uri_parsed.path.strip("/").split("/") - relative_path = "/".join(parts[1:]) if len(parts) > 1 else "." - if relative_path != "." and "." in relative_path: - args.append( - f"dataset.data_files={os.path.join(constants.DATASET_PATH, relative_path)}" - ) + # Add function code to the Trainer command. + command = [] + for c in runtime.trainer.command: + if "{func_file}" in c: + exec_script = c.format(func_code=func_code, func_file=func_file) + if install_packages: + exec_script = install_packages + exec_script + command.append(exec_script) else: - args.append( - f"dataset.data_dir={os.path.join(constants.DATASET_PATH, relative_path)}" - ) - - if fine_tuning_config.dataset_preprocess_config: - args += get_args_in_dataset_preprocess_config( - fine_tuning_config.dataset_preprocess_config - ) + command.append(c) - return constants.DEFAULT_TORCHTUNE_COMMAND, args + return command def get_trainer_crd_from_custom_trainer( - trainer: types.CustomTrainer, runtime: types.Runtime, + trainer: types.CustomTrainer, ) -> models.TrainerV1alpha1Trainer: """ Get the Trainer CRD from the custom trainer. @@ -402,10 +372,9 @@ def get_trainer_crd_from_custom_trainer( trainer.resources_per_node ) - # Add command and args to the Trainer. - trainer_crd.command = constants.DEFAULT_CUSTOM_COMMAND + # Add command to the Trainer. # TODO: Support train function parameters. 
- trainer_crd.command, trainer_crd.args = get_entrypoint_using_train_func( + trainer_crd.command = get_command_using_train_func( runtime, trainer.func, trainer.func_args, @@ -424,6 +393,7 @@ def get_trainer_crd_from_custom_trainer( def get_trainer_crd_from_builtin_trainer( + runtime: types.Runtime, trainer: types.BuiltinTrainer, initializer: Optional[types.Initializer] = None, ) -> models.TrainerV1alpha1Trainer: @@ -445,43 +415,119 @@ def get_trainer_crd_from_builtin_trainer( trainer.config.resources_per_node ) + trainer_crd.command = list(runtime.trainer.command) # Parse args in the TorchTuneConfig to the Trainer, preparing for the mutation of # the torchtune config in the runtime plugin. # Ref:https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2 - trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config( - trainer.config, initializer - ) + trainer_crd.args = get_args_using_torchtune_config(trainer.config, initializer) return trainer_crd -def get_script_for_python_packages( - packages_to_install: List[str], - pip_index_url: str, - runtime_entrypoint: List[str], -) -> str: +def get_args_using_torchtune_config( + fine_tuning_config: types.TorchTuneConfig, + initializer: Optional[types.Initializer] = None, +) -> list[str]: """ - Get init script to install Python packages from the given pip index URL. + Get the Trainer args from the TorchTuneConfig. """ - packages_str = " ".join([str(package) for package in packages_to_install]) + args = [] - script_for_python_packages = textwrap.dedent( - """ - if ! [ -x "$(command -v pip)" ]; then - python -m ensurepip || python -m ensurepip --user || apt-get install python-pip - fi + # Override the dtype if it is provided. + if fine_tuning_config.dtype: + if not isinstance(fine_tuning_config.dtype, types.DataType): + raise ValueError(f"Invalid dtype: {fine_tuning_config.dtype}.") - PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \ - --no-warn-script-location --index-url {} {} {} - """.format( - pip_index_url, - packages_str, - # For the OpenMPI, the packages must be installed for the mpiuser. - "--user" if runtime_entrypoint[0] == constants.MPI_ENTRYPOINT else "", + args.append(f"dtype={fine_tuning_config.dtype}") + + # Override the batch size if it is provided. + if fine_tuning_config.batch_size: + args.append(f"batch_size={fine_tuning_config.batch_size}") + + # Override the epochs if it is provided. + if fine_tuning_config.epochs: + args.append(f"epochs={fine_tuning_config.epochs}") + + # Override the loss if it is provided. + if fine_tuning_config.loss: + args.append(f"loss={fine_tuning_config.loss}") + + # Override the data dir or data files if it is provided. + if isinstance(initializer, types.Initializer) and isinstance( + initializer.dataset, types.HuggingFaceDatasetInitializer + ): + storage_uri = ( + "hf://" + initializer.dataset.storage_uri + if not initializer.dataset.storage_uri.startswith("hf://") + else initializer.dataset.storage_uri ) - ) + storage_uri_parsed = urlparse(storage_uri) + parts = storage_uri_parsed.path.strip("/").split("/") + relative_path = "/".join(parts[1:]) if len(parts) > 1 else "." - return script_for_python_packages + if relative_path != "." and "." 
in relative_path: + args.append( + f"dataset.data_files={os.path.join(constants.DATASET_PATH, relative_path)}" + ) + else: + args.append( + f"dataset.data_dir={os.path.join(constants.DATASET_PATH, relative_path)}" + ) + + if fine_tuning_config.dataset_preprocess_config: + args += get_args_in_dataset_preprocess_config( + fine_tuning_config.dataset_preprocess_config + ) + + return args + + +def get_args_in_dataset_preprocess_config( + dataset_preprocess_config: types.TorchTuneInstructDataset, +) -> list[str]: + """ + Get the args from the given dataset preprocess config. + """ + args = [] + + if not isinstance(dataset_preprocess_config, types.TorchTuneInstructDataset): + raise ValueError( + f"Invalid dataset preprocess config type: {type(dataset_preprocess_config)}." + ) + + # Override the dataset type field in the torchtune config. + args.append(f"dataset={constants.TORCH_TUNE_INSTRUCT_DATASET}") + + # Override the dataset source field if it is provided. + if dataset_preprocess_config.source: + if not isinstance(dataset_preprocess_config.source, types.DataFormat): + raise ValueError( + f"Invalid data format: {dataset_preprocess_config.source.value}." + ) + + args.append(f"dataset.source={dataset_preprocess_config.source.value}") + + # Override the split field if it is provided. + if dataset_preprocess_config.split: + args.append(f"dataset.split={dataset_preprocess_config.split}") + + # Override the train_on_input field if it is provided. + if dataset_preprocess_config.train_on_input: + args.append( + f"dataset.train_on_input={dataset_preprocess_config.train_on_input}" + ) + + # Override the new_system_prompt field if it is provided. + if dataset_preprocess_config.new_system_prompt: + args.append( + f"dataset.new_system_prompt={dataset_preprocess_config.new_system_prompt}" + ) + + # Override the column_map field if it is provided. + if dataset_preprocess_config.column_map: + args.append(f"dataset.column_map={dataset_preprocess_config.column_map}") + + return args def get_dataset_initializer( @@ -556,58 +602,10 @@ def wrap_log_stream(q: queue.Queue, log_stream: Any): return -def get_log_queue_pool(log_streams: List[Any]) -> List[queue.Queue]: +def get_log_queue_pool(log_streams: list[Any]) -> list[queue.Queue]: pool = [] for log_stream in log_streams: q = queue.Queue(maxsize=100) pool.append(q) threading.Thread(target=wrap_log_stream, args=(q, log_stream)).start() return pool - - -def get_args_in_dataset_preprocess_config( - dataset_preprocess_config: types.TorchTuneInstructDataset, -) -> List[str]: - """ - Get the args from the given dataset preprocess config. - """ - args = [] - - if not isinstance(dataset_preprocess_config, types.TorchTuneInstructDataset): - raise ValueError( - f"Invalid dataset preprocess config type: {type(dataset_preprocess_config)}." - ) - - # Override the dataset type field in the torchtune config. - args.append(f"dataset={constants.TORCHTUNE_INSTRUCT_DATASET}") - - # Override the dataset source field if it is provided. - if dataset_preprocess_config.source: - if not isinstance(dataset_preprocess_config.source, types.DataFormat): - raise ValueError( - f"Invalid data format: {dataset_preprocess_config.source.value}." - ) - - args.append(f"dataset.source={dataset_preprocess_config.source.value}") - - # Override the split field if it is provided. - if dataset_preprocess_config.split: - args.append(f"dataset.split={dataset_preprocess_config.split}") - - # Override the train_on_input field if it is provided. 
- if dataset_preprocess_config.train_on_input: - args.append( - f"dataset.train_on_input={dataset_preprocess_config.train_on_input}" - ) - - # Override the new_system_prompt field if it is provided. - if dataset_preprocess_config.new_system_prompt: - args.append( - f"dataset.new_system_prompt={dataset_preprocess_config.new_system_prompt}" - ) - - # Override the column_map field if it is provided. - if dataset_preprocess_config.column_map: - args.append(f"dataset.column_map={dataset_preprocess_config.column_map}") - - return args
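Usage sketch for reviewers: the snippet below illustrates how the reworked SDK surface in this diff is intended to be called, i.e. runtimes are resolved by name (with `torch-distributed` as the default inside `train()`), each runtime now carries a framework string and a `command` tuple instead of an `entrypoint` list, and mismatched trainer/runtime combinations raise `ValueError`. It is illustrative only: it assumes `TrainerClient` and `CustomTrainer` remain importable from `kubeflow.trainer` as in the existing package layout, and that the cluster has a `torch-distributed` ClusterTrainingRuntime labeled with `trainer.kubeflow.org/framework`.

```python
from kubeflow.trainer import CustomTrainer, TrainerClient


def train_func():
    # Placeholder training function; a real torch training loop goes here.
    print("Hello World")


client = TrainerClient()

# Runtimes without the trainer.kubeflow.org/framework label are skipped by
# list_runtimes() and rejected by get_runtime() after this change.
for runtime in client.list_runtimes():
    print(runtime.name, runtime.trainer.framework, runtime.trainer.command)

# Explicitly resolve the torch runtime; omitting `runtime` in train() falls
# back to get_runtime("torch-distributed").
torch_runtime = client.get_runtime("torch-distributed")

# CustomTrainer is only accepted by runtimes whose trainer_type is
# CUSTOM_TRAINER; pairing it with a BuiltinTrainer runtime (e.g. torchtune)
# now raises ValueError instead of producing an invalid TrainJob.
job_name = client.train(
    runtime=torch_runtime,
    trainer=CustomTrainer(
        func=train_func,
        num_nodes=2,
        packages_to_install=["torch", "numpy"],
    ),
)

client.wait_for_job_status(name=job_name)
print(client.get_job_logs(job_name))
```

Under this change the container command is assembled from the runtime's `RuntimeTrainer.command` (built from constants such as `TORCH_COMMAND`, `MPI_COMMAND`, or `TORCH_TUNE_COMMAND`) rather than looked up from the removed image-keyed `ALL_TRAINERS` map.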