Skip to content

Commit a005199

Browse files
committed
feat(sdk): support volume mount in tune API
Signed-off-by: truc0 <22969604+truc0@users.noreply.github.com>
1 parent c18035e commit a005199

File tree

1 file changed

+154
-2
lines changed

1 file changed

+154
-2
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 154 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import multiprocessing
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional, Union
19+
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -42,6 +42,14 @@
4242

4343
logger = logging.getLogger(__name__)
4444

45+
# Shape of one entry in the `storage_per_trial` list accepted by `tune()`:
#   - "volume": a ready-made client.V1Volume, or a plain dict describing the volume.
#   - "mount_path": a path string, or a ready-made client.V1VolumeMount.
# NOTE(review): the TypedDict is registered under the name "TuneStoragePerTrial"
# while the module-level alias is `TuneStoragePerTrialType` — confirm the
# mismatch is intentional (type checkers generally expect the two to match).
TuneStoragePerTrialType = TypedDict(
46+
"TuneStoragePerTrial",
47+
{
48+
"volume": Union[client.V1Volume, Dict[str, Any]],
49+
"mount_path": Union[str, client.V1VolumeMount],
50+
},
51+
)
52+
4553

4654
class KatibClient(object):
4755
def __init__(
@@ -198,6 +206,7 @@ def tune(
198206
env_per_trial: Optional[
199207
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
200208
] = None,
209+
storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None,
201210
algorithm_name: str = "random",
202211
algorithm_settings: Union[
203212
dict, List[models.V1beta1AlgorithmSetting], None
@@ -288,6 +297,21 @@ class name in this argument.
288297
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
289298
or a kubernetes.client.models.V1EnvFromSource (documented here:
290299
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
300+
storage_per_trial: List of storage configurations for each trial container.
301+
Each element in the list should be a dictionary with two keys:
302+
- volume: Either a kubernetes.client.V1Volume object or a dictionary
303+
containing volume configuration with required fields:
304+
- name: Name of the volume
305+
- type: One of "pvc", "secret", "config_map", or "empty_dir"
306+
Additional fields based on volume type:
307+
- For pvc: claim_name, read_only (optional)
308+
- For secret: secret_name, items (optional), default_mode (optional),
309+
optional (optional)
310+
- For config_map: config_map_name, items (optional), default_mode
311+
(optional), optional (optional)
312+
- For empty_dir: medium (optional), size_limit (optional)
313+
- mount_path: Either a kubernetes.client.V1VolumeMount object or a string
314+
specifying the path where the volume should be mounted in the container
291315
algorithm_name: Search algorithm for the HyperParameter tuning.
292316
algorithm_settings: Settings for the search algorithm given.
293317
For available fields, check this doc:
@@ -503,6 +527,134 @@ class name in this argument.
503527
container_spec.env = env if env else None
504528
container_spec.env_from = env_from if env_from else None
505529

530+
volumes: List[client.V1Volume] = []
531+
volume_mounts: List[client.V1VolumeMount] = []
532+
if storage_per_trial:
533+
if isinstance(storage_per_trial, dict):
534+
storage_per_trial = [storage_per_trial]
535+
for storage in storage_per_trial:
536+
print(f"storage: {storage}")
537+
volume = None
538+
if isinstance(storage["volume"], client.V1Volume):
539+
volume = storage["volume"]
540+
elif isinstance(storage["volume"], dict):
541+
volume_name = storage["volume"].get("name")
542+
volume_type = storage["volume"].get("type")
543+
544+
if not volume_name:
545+
raise ValueError(
546+
"storage_per_trial['volume'] does not have a 'name' key"
547+
)
548+
if not volume_type:
549+
raise ValueError(
550+
"storage_per_trial['volume'] does not have a 'type' key"
551+
)
552+
553+
if volume_type == "pvc":
554+
volume_claim_name = storage["volume"].get("claim_name")
555+
if not volume_claim_name:
556+
raise ValueError(
557+
"storage_per_trial['volume'] should have a "
558+
"'claim_name' key for type pvc"
559+
)
560+
volume = client.V1Volume(
561+
name=volume_name,
562+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
563+
claim_name=volume_claim_name,
564+
read_only=storage["volume"].get("read_only", False),
565+
),
566+
)
567+
elif volume_type == "secret":
568+
volume = client.V1Volume(
569+
name=volume_name,
570+
secret=client.V1SecretVolumeSource(
571+
secret_name=storage["volume"].get("secret_name"),
572+
items=storage["volume"].get("items", None),
573+
default_mode=storage["volume"].get(
574+
"default_mode", None
575+
),
576+
optional=storage["volume"].get("optional", False),
577+
),
578+
)
579+
elif volume_type == "config_map":
580+
volume = client.V1Volume(
581+
name=volume_name,
582+
config_map=client.V1ConfigMapVolumeSource(
583+
name=storage["volume"].get("config_map_name"),
584+
items=storage["volume"].get("items", []),
585+
default_mode=storage["volume"].get(
586+
"default_mode", None
587+
),
588+
optional=storage["volume"].get("optional", False),
589+
),
590+
)
591+
elif volume_type == "empty_dir":
592+
volume = client.V1Volume(
593+
name=volume_name,
594+
empty_dir=client.V1EmptyDirVolumeSource(
595+
medium=storage["volume"].get("medium", None),
596+
size_limit=storage["volume"].get(
597+
"size_limit", None
598+
),
599+
),
600+
)
601+
else:
602+
raise ValueError(
603+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
604+
)
605+
606+
else:
607+
raise ValueError(
608+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
609+
)
610+
611+
volumes.append(volume)
612+
613+
if isinstance(storage["mount_path"], client.V1VolumeMount):
614+
volume_mounts.append(storage["mount_path"])
615+
elif isinstance(storage["mount_path"], str):
616+
volume_mounts.append(
617+
client.V1VolumeMount(
618+
name=volume_name, mount_path=storage["mount_path"]
619+
)
620+
)
621+
else:
622+
raise ValueError(
623+
"storage_per_trial['mount_path'] must be a "
624+
"client.V1VolumeMount or a str"
625+
)
626+
627+
# Create Trial specification.
628+
trial_spec = client.V1Job(
629+
api_version="batch/v1",
630+
kind="Job",
631+
spec=client.V1JobSpec(
632+
template=client.V1PodTemplateSpec(
633+
metadata=models.V1ObjectMeta(
634+
annotations={"sidecar.istio.io/inject": "false"}
635+
),
636+
spec=client.V1PodSpec(
637+
restart_policy="Never",
638+
containers=[
639+
client.V1Container(
640+
name=constants.DEFAULT_PRIMARY_CONTAINER_NAME,
641+
image=base_image,
642+
command=["bash", "-c"],
643+
args=[exec_script],
644+
env=env if env else None,
645+
env_from=env_from if env_from else None,
646+
resources=resources_per_trial,
647+
volume_mounts=(
648+
volume_mounts if volume_mounts else None
649+
),
650+
)
651+
],
652+
volumes=volumes if volumes else None,
653+
),
654+
)
655+
),
656+
)
657+
506658
# Trial uses PyTorchJob for distributed training if TrainerResources is set.
507659
if isinstance(resources_per_trial, TrainerResources):
508660
trial_template = utils.get_trial_template_with_pytorchjob(
@@ -584,7 +736,7 @@ class name in this argument.
584736
f"It must also start and end with an alphanumeric character."
585737
)
586738
elif hasattr(e, "status") and e.status == 409:
587-
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
739+
print(f"PVC '{name}' already exists in namespace {namespace}.")
588740
else:
589741
raise RuntimeError(f"failed to create PVC. Error: {e}")
590742

0 commit comments

Comments
 (0)