diff --git a/python/kubeflow/trainer/api/trainer_client_test.py b/python/kubeflow/trainer/api/trainer_client_test.py index 410dbfaf..da1609cd 100644 --- a/python/kubeflow/trainer/api/trainer_client_test.py +++ b/python/kubeflow/trainer/api/trainer_client_test.py @@ -69,6 +69,7 @@ class TestCase: TRAIN_JOBS = "trainjobs" TRAIN_JOB_WITH_BUILT_IN_TRAINER = "train-job-with-built-in-trainer" TRAIN_JOB_WITH_CUSTOM_TRAINER = "train-job-with-custom-trainer" +TRAIN_JOB_WITH_CUSTOM_TRAINER_ENV = "train-job-with-custom-trainer-env" # -------------------------- @@ -221,7 +222,9 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements: ) -def get_custom_trainer() -> models.TrainerV1alpha1Trainer: +def get_custom_trainer( + env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None, +) -> models.TrainerV1alpha1Trainer: """ Get the custom trainer for the TrainJob. """ @@ -239,9 +242,9 @@ def get_custom_trainer() -> models.TrainerV1alpha1Trainer: '"$SCRIPT" > "trainer_client_test.py"\ntorchrun "trainer_client_test.py"' ], numNodes=2, + env=env, ) - def get_builtin_trainer() -> models.TrainerV1alpha1Trainer: """ Get the builtin trainer for the TrainJob. @@ -695,6 +698,32 @@ def test_list_runtimes(training_client, test_case): train_job_trainer=get_custom_trainer(), ), ), + TestCase( + name="valid flow with custom trainer and env vars", + expected_status=SUCCESS, + config={ + "trainer": types.CustomTrainer( + func=lambda: print("Hello World"), + func_args={"learning_rate": 0.001, "batch_size": 32}, + packages_to_install=["torch", "numpy"], + pip_index_url=constants.DEFAULT_PIP_INDEX_URL, + num_nodes=2, + env={ + "TEST_ENV": "test_value", + "ANOTHER_ENV": "another_value", + }, + ) + }, + expected_output=get_train_job( + train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER_ENV, + train_job_trainer=get_custom_trainer( + env = [ + models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"), + models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"), + ], + ), + ), + ), TestCase( name="timeout error when deleting job", expected_status=FAILED, diff --git a/python/kubeflow/trainer/types/types.py b/python/kubeflow/trainer/types/types.py index 2220b97c..39ed8793 100644 --- a/python/kubeflow/trainer/types/types.py +++ b/python/kubeflow/trainer/types/types.py @@ -35,6 +35,7 @@ class CustomTrainer: pip_index_url (`Optional[str]`): The PyPI URL from which to install Python packages. num_nodes (`Optional[int]`): The number of nodes to use for training. resources_per_node (`Optional[Dict]`): The computing resources to allocate per node. + env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes. """ func: Callable @@ -43,6 +44,7 @@ class CustomTrainer: pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL num_nodes: Optional[int] = None resources_per_node: Optional[Dict] = None + env: Optional[Dict[str, str]] = None # TODO(Electronic-Waste): Add more loss functions. diff --git a/python/kubeflow/trainer/utils/utils.py b/python/kubeflow/trainer/utils/utils.py index 8837c5fc..98eb777d 100644 --- a/python/kubeflow/trainer/utils/utils.py +++ b/python/kubeflow/trainer/utils/utils.py @@ -413,6 +413,13 @@ def get_trainer_crd_from_custom_trainer( trainer.packages_to_install, ) + # Add environment variables to the Trainer. + if trainer.env: + trainer_crd.env = [ + models.IoK8sApiCoreV1EnvVar(name=key, value=value) + for key, value in trainer.env.items() + ] + return trainer_crd