Skip to content

Commit 222020d

Browse files
committed
feat(sdk): support volume mount in tune API
Signed-off-by: truc0 <22969604+truc0@users.noreply.github.com>
1 parent 4d2a230 commit 222020d

File tree

1 file changed

+125
-2
lines changed

1 file changed

+125
-2
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import multiprocessing
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional, Union
19+
from typing import Any, Callable, Dict, List, Optional, Union, TypedDict
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -30,6 +30,14 @@
3030

3131
logger = logging.getLogger(__name__)
3232

33+
# Shape of one entry of the `storage_per_trial` argument of the tune API.
# The type name given to TypedDict matches the variable it is bound to, which
# is what static type checkers require for the functional TypedDict syntax.
TuneStoragePerTrialType = TypedDict(
    "TuneStoragePerTrialType",
    {
        # Either a ready-made V1Volume, or a dict describing the volume
        # (it carries at least the "name" and "type" keys, plus the fields
        # specific to that volume type).
        "volume": Union[client.V1Volume, Dict[str, Any]],
        # Either a ready-made V1VolumeMount, or the path (as a string) at
        # which the volume is mounted inside the trial container.
        "mount_path": Union[str, client.V1VolumeMount],
    },
)
40+
3341

3442
class KatibClient(object):
3543
def __init__(
@@ -186,6 +194,7 @@ def tune(
186194
env_per_trial: Optional[
187195
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
188196
] = None,
197+
storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None,
189198
algorithm_name: str = "random",
190199
algorithm_settings: Union[
191200
dict, List[models.V1beta1AlgorithmSetting], None
@@ -276,6 +285,21 @@ class name in this argument.
276285
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
277286
or a kubernetes.client.models.V1EnvFromSource (documented here:
278287
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
288+
storage_per_trial: List of storage configurations for each trial container.
289+
Each element in the list should be a dictionary with two keys:
290+
- volume: Either a kubernetes.client.V1Volume object or a dictionary
291+
containing volume configuration with required fields:
292+
- name: Name of the volume
293+
- type: One of "pvc", "secret", "config_map", or "empty_dir"
294+
Additional fields based on volume type:
295+
- For pvc: claim_name, read_only (optional)
296+
- For secret: secret_name, items (optional), default_mode (optional),
297+
optional (optional)
298+
- For config_map: config_map_name, items (optional), default_mode
299+
(optional), optional (optional)
300+
- For empty_dir: medium (optional), size_limit (optional)
301+
- mount_path: Either a kubernetes.client.V1VolumeMount object or a string
302+
specifying the path where the volume should be mounted in the container
279303
algorithm_name: Search algorithm for the HyperParameter tuning.
280304
algorithm_settings: Settings for the search algorithm given.
281305
For available fields, check this doc:
@@ -468,6 +492,101 @@ class name in this argument.
468492
f"Incorrect value for env_per_trial: {env_per_trial}"
469493
)
470494

495+
volumes: List[client.V1Volume] = []
496+
volume_mounts: List[client.V1VolumeMount] = []
497+
if storage_per_trial:
498+
if isinstance(storage_per_trial, dict):
499+
storage_per_trial = [storage_per_trial]
500+
for storage in storage_per_trial:
501+
print(f"storage: {storage}")
502+
volume = None
503+
if isinstance(storage["volume"], client.V1Volume):
504+
volume = storage["volume"]
505+
elif isinstance(storage["volume"], dict):
506+
volume_name = storage["volume"].get("name")
507+
volume_type = storage["volume"].get("type")
508+
509+
if not volume_name:
510+
raise ValueError(
511+
"storage_per_trial['volume'] does not have a 'name' key"
512+
)
513+
if not volume_type:
514+
raise ValueError(
515+
"storage_per_trial['volume'] does not have a 'type' key"
516+
)
517+
518+
if volume_type == "pvc":
519+
volume_claim_name = storage["volume"].get("claim_name")
520+
if not volume_claim_name:
521+
raise ValueError(
522+
"storage_per_trial['volume'] should have a 'claim_name' key for type pvc"
523+
)
524+
volume = client.V1Volume(
525+
name=volume_name,
526+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
527+
claim_name=volume_claim_name,
528+
read_only=storage["volume"].get("read_only", False),
529+
),
530+
)
531+
elif volume_type == "secret":
532+
volume = client.V1Volume(
533+
name=volume_name,
534+
secret=client.V1SecretVolumeSource(
535+
secret_name=storage["volume"].get("secret_name"),
536+
items=storage["volume"].get("items", None),
537+
default_mode=storage["volume"].get(
538+
"default_mode", None
539+
),
540+
optional=storage["volume"].get("optional", False),
541+
),
542+
)
543+
elif volume_type == "config_map":
544+
volume = client.V1Volume(
545+
name=volume_name,
546+
config_map=client.V1ConfigMapVolumeSource(
547+
name=storage["volume"].get("config_map_name"),
548+
items=storage["volume"].get("items", []),
549+
default_mode=storage["volume"].get(
550+
"default_mode", None
551+
),
552+
optional=storage["volume"].get("optional", False),
553+
),
554+
)
555+
elif volume_type == "empty_dir":
556+
volume = client.V1Volume(
557+
name=volume_name,
558+
empty_dir=client.V1EmptyDirVolumeSource(
559+
medium=storage["volume"].get("medium", None),
560+
size_limit=storage["volume"].get(
561+
"size_limit", None
562+
),
563+
),
564+
)
565+
else:
566+
raise ValueError(
567+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
568+
)
569+
570+
else:
571+
raise ValueError(
572+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
573+
)
574+
575+
volumes.append(volume)
576+
577+
if isinstance(storage["mount_path"], client.V1VolumeMount):
578+
volume_mounts.append(storage["mount_path"])
579+
elif isinstance(storage["mount_path"], str):
580+
volume_mounts.append(
581+
client.V1VolumeMount(
582+
name=volume_name, mount_path=storage["mount_path"]
583+
)
584+
)
585+
else:
586+
raise ValueError(
587+
"storage_per_trial['mount_path'] must be a client.V1VolumeMount or a str"
588+
)
589+
471590
# Create Trial specification.
472591
trial_spec = client.V1Job(
473592
api_version="batch/v1",
@@ -488,8 +607,12 @@ class name in this argument.
488607
env=env if env else None,
489608
env_from=env_from if env_from else None,
490609
resources=resources_per_trial,
610+
volume_mounts=(
611+
volume_mounts if volume_mounts else None
612+
),
491613
)
492614
],
615+
volumes=volumes if volumes else None,
493616
),
494617
)
495618
),
@@ -576,7 +699,7 @@ class name in this argument.
576699
f"It must also start and end with an alphanumeric character."
577700
)
578701
elif hasattr(e, "status") and e.status == 409:
579-
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
702+
print(f"PVC '{name}' already exists in namespace {namespace}.")
580703
else:
581704
raise RuntimeError(f"failed to create PVC. Error: {e}")
582705

0 commit comments

Comments
 (0)