Skip to content

Commit 898d047

Browse files
committed
feat(sdk): support volume mount in tune API
Signed-off-by: truc0 <22969604+truc0@users.noreply.github.com>
1 parent 4d2a230 commit 898d047

File tree

1 file changed

+97
-2
lines changed

1 file changed

+97
-2
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import multiprocessing
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional, Union
19+
from typing import Any, Callable, Dict, List, Optional, Union, TypedDict
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -30,6 +30,10 @@
3030

3131
logger = logging.getLogger(__name__)
3232

33+
TuneStoragePerTrialType = TypedDict(
34+
"TuneStoragePerTrial",
35+
{"volume": Union[client.V1Volume, Dict[str, Any]], "mount_path": Union[str, client.V1VolumeMount]},
36+
)
3337

3438
class KatibClient(object):
3539
def __init__(
@@ -186,6 +190,7 @@ def tune(
186190
env_per_trial: Optional[
187191
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
188192
] = None,
193+
storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None,
189194
algorithm_name: str = "random",
190195
algorithm_settings: Union[
191196
dict, List[models.V1beta1AlgorithmSetting], None
@@ -276,6 +281,21 @@ class name in this argument.
276281
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
277282
or a kubernetes.client.models.V1EnvFromSource (documented here:
278283
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
284+
storage_per_trial: List of storage configurations for each trial container.
285+
Each element in the list should be a dictionary with two keys:
286+
- volume: Either a kubernetes.client.V1Volume object or a dictionary
287+
containing volume configuration with required fields:
288+
- name: Name of the volume
289+
- type: One of "pvc", "secret", "config_map", or "empty_dir"
290+
Additional fields based on volume type:
291+
- For pvc: claim_name, read_only (optional)
292+
- For secret: secret_name, items (optional), default_mode (optional),
293+
optional (optional)
294+
- For config_map: config_map_name, items (optional), default_mode
295+
(optional), optional (optional)
296+
- For empty_dir: medium (optional), size_limit (optional)
297+
- mount_path: Either a kubernetes.client.V1VolumeMount object or a string
298+
specifying the path where the volume should be mounted in the container
279299
algorithm_name: Search algorithm for the HyperParameter tuning.
280300
algorithm_settings: Settings for the search algorithm given.
281301
For available fields, check this doc:
@@ -468,6 +488,79 @@ class name in this argument.
468488
f"Incorrect value for env_per_trial: {env_per_trial}"
469489
)
470490

491+
volumes: List[client.V1Volume] = []
492+
volume_mounts: List[client.V1VolumeMount] = []
493+
if storage_per_trial:
494+
if isinstance(storage_per_trial, dict):
495+
storage_per_trial = [storage_per_trial]
496+
for storage in storage_per_trial:
497+
print(f"storage: {storage}")
498+
volume = None
499+
if isinstance(storage["volume"], client.V1Volume):
500+
volume = storage["volume"]
501+
elif isinstance(storage["volume"], dict):
502+
volume_name = storage["volume"].get("name")
503+
volume_type = storage["volume"].get("type")
504+
505+
if not volume_name:
506+
raise ValueError("storage_per_trial['volume'] does not have a 'name' key")
507+
if not volume_type:
508+
raise ValueError("storage_per_trial['volume'] does not have a 'type' key")
509+
510+
if volume_type == "pvc":
511+
volume_claim_name = storage["volume"].get("claim_name")
512+
if not volume_claim_name:
513+
raise ValueError("storage_per_trial['volume'] should have a 'claim_name' key for type pvc")
514+
volume = client.V1Volume(
515+
name=volume_name,
516+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
517+
claim_name=volume_claim_name,
518+
read_only=storage["volume"].get("read_only", False),
519+
)
520+
)
521+
elif volume_type == "secret":
522+
volume = client.V1Volume(
523+
name=volume_name,
524+
secret=client.V1SecretVolumeSource(
525+
secret_name=storage["volume"].get("secret_name"),
526+
items=storage["volume"].get("items", None),
527+
default_mode=storage["volume"].get("default_mode", None),
528+
optional=storage["volume"].get("optional", False),
529+
)
530+
)
531+
elif volume_type == "config_map":
532+
volume = client.V1Volume(
533+
name=volume_name,
534+
config_map=client.V1ConfigMapVolumeSource(
535+
name=storage["volume"].get("config_map_name"),
536+
items=storage["volume"].get("items", []),
537+
default_mode=storage["volume"].get("default_mode", None),
538+
optional=storage["volume"].get("optional", False),
539+
)
540+
)
541+
elif volume_type == "empty_dir":
542+
volume = client.V1Volume(
543+
name=volume_name,
544+
empty_dir=client.V1EmptyDirVolumeSource(
545+
medium=storage["volume"].get("medium", None),
546+
size_limit=storage["volume"].get("size_limit", None),
547+
)
548+
)
549+
else:
550+
raise ValueError("storage_per_trial['volume'] must be a client.V1Volume or a dict")
551+
552+
else:
553+
raise ValueError("storage_per_trial['volume'] must be a client.V1Volume or a dict")
554+
555+
volumes.append(volume)
556+
557+
if isinstance(storage["mount_path"], client.V1VolumeMount):
558+
volume_mounts.append(storage["mount_path"])
559+
elif isinstance(storage["mount_path"], str):
560+
volume_mounts.append(client.V1VolumeMount(name=volume_name, mount_path=storage["mount_path"]))
561+
else:
562+
raise ValueError("storage_per_trial['mount_path'] must be a client.V1VolumeMount or a str")
563+
471564
# Create Trial specification.
472565
trial_spec = client.V1Job(
473566
api_version="batch/v1",
@@ -488,8 +581,10 @@ class name in this argument.
488581
env=env if env else None,
489582
env_from=env_from if env_from else None,
490583
resources=resources_per_trial,
584+
volume_mounts=volume_mounts if volume_mounts else None,
491585
)
492586
],
587+
volumes=volumes if volumes else None,
493588
),
494589
)
495590
),
@@ -576,7 +671,7 @@ class name in this argument.
576671
f"It must also start and end with an alphanumeric character."
577672
)
578673
elif hasattr(e, "status") and e.status == 409:
579-
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
674+
print(f"PVC '{name}' already exists in namespace {namespace}.")
580675
else:
581676
raise RuntimeError(f"failed to create PVC. Error: {e}")
582677

0 commit comments

Comments
 (0)