diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index 901a9e9028..34db0d9942 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -354,6 +354,8 @@ def create_job( env_vars: Optional[ Union[Dict[str, str], List[Union[models.V1EnvVar, models.V1EnvVar]]] ] = None, + volumes: Optional[List[models.V1Volume]] = None, + volume_mounts: Optional[List[models.V1VolumeMount]] = None, ): """Create the Training Job. Job can be created using one of the following options: @@ -418,6 +420,8 @@ def create_job( https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) or a kubernetes.client.models.V1EnvFromSource (documented here: https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) + volumes: Volume(s) to be attached to the replicas. + volume_mounts: VolumeMount(s) specifying where to mount the volume(s) into the replicas. Raises: ValueError: Invalid input parameters. @@ -448,6 +452,12 @@ def create_job( f"Job kind must be one of these: {constants.JOB_PARAMETERS.keys()}" ) + if len(volumes or []) != len(volume_mounts or []): + raise ValueError( + "Volumes and VolumeMounts must be the same length: " + f"{len(volumes or [])} vs. {len(volume_mounts or [])}" + ) + # If Training function or base image is set, configure Job template. if job is None and (train_func is not None or base_image is not None): # Job name must be set to configure Job template. @@ -496,11 +506,13 @@ def create_job( args=args, resources=resources_per_worker, env_vars=env_vars, + volume_mounts=volume_mounts, ) # Get Pod template spec using the above container. pod_template_spec = utils.get_pod_template_spec( containers=[container_spec], + volumes=volumes, ) # Configure template for different Jobs. diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index bc5366f078..97060805dd 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -22,6 +22,8 @@ V1ObjectMeta, V1PodSpec, V1PodTemplateSpec, + V1Volume, + V1VolumeMount, ) TEST_NAME = "test" @@ -142,6 +144,8 @@ def create_job( args=None, num_workers=2, env_vars=None, + volumes=None, + volume_mounts=None, ): # Handle env_vars as either a dict or a list if env_vars: @@ -158,6 +162,7 @@ def create_job( command=command, args=args, env=env_vars, + volume_mounts=volume_mounts, ) master = KubeflowOrgV1ReplicaSpec( @@ -166,7 +171,10 @@ def create_job( metadata=V1ObjectMeta( annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} ), - spec=V1PodSpec(containers=[container]), + spec=V1PodSpec( + containers=[container], + volumes=volumes, + ), ), ) @@ -180,7 +188,10 @@ def create_job( metadata=V1ObjectMeta( annotations={constants.ISTIO_SIDECAR_INJECTION: "false"} ), - spec=V1PodSpec(containers=[container]), + spec=V1PodSpec( + containers=[container], + volumes=volumes, + ), ), ) @@ -530,6 +541,35 @@ def __init__(self): env_vars=[V1EnvVar(name="ENV_VAR", value="env_value")], num_workers=2 ), ), + ( + "create job with a volume and a volume mount", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "base_image": TEST_IMAGE, + "num_workers": 1, + "volumes": [V1Volume(name="vol")], + "volume_mounts": [V1VolumeMount(name="vol", mount_path="/mnt")], + }, + SUCCESS, + create_job( + num_workers=1, + volumes=[V1Volume(name="vol")], + volume_mounts=[V1VolumeMount(name="vol", mount_path="/mnt")], + ), + ), + ( + "invalid number of volume mount", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "base_image": TEST_IMAGE, + "num_workers": 1, + "volumes": [V1Volume(name="vol")], + }, + ValueError, + None, + ), ] test_data_get_job_pods = [