Skip to content

Commit ec38767

Browse files
committed
feat(sdk): support volume mount in tune API
Signed-off-by: truc0 <22969604+truc0@users.noreply.github.com>
1 parent 4d2a230 commit ec38767

File tree

1 file changed

+127
-2
lines changed

1 file changed

+127
-2
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import multiprocessing
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional, Union
19+
from typing import Any, Callable, Dict, List, Optional, TypedDict, Union
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -30,6 +30,14 @@
3030

3131
logger = logging.getLogger(__name__)
3232

33+
TuneStoragePerTrialType = TypedDict(
34+
"TuneStoragePerTrial",
35+
{
36+
"volume": Union[client.V1Volume, Dict[str, Any]],
37+
"mount_path": Union[str, client.V1VolumeMount],
38+
},
39+
)
40+
3341

3442
class KatibClient(object):
3543
def __init__(
@@ -186,6 +194,7 @@ def tune(
186194
env_per_trial: Optional[
187195
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
188196
] = None,
197+
storage_per_trial: Optional[List[TuneStoragePerTrialType]] = None,
189198
algorithm_name: str = "random",
190199
algorithm_settings: Union[
191200
dict, List[models.V1beta1AlgorithmSetting], None
@@ -276,6 +285,21 @@ class name in this argument.
276285
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
277286
or a kubernetes.client.models.V1EnvFromSource (documented here:
278287
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
288+
storage_per_trial: List of storage configurations for each trial container.
289+
Each element in the list should be a dictionary with two keys:
290+
- volume: Either a kubernetes.client.V1Volume object or a dictionary
291+
containing volume configuration with required fields:
292+
- name: Name of the volume
293+
- type: One of "pvc", "secret", "config_map", or "empty_dir"
294+
Additional fields based on volume type:
295+
- For pvc: claim_name, read_only (optional)
296+
- For secret: secret_name, items (optional), default_mode (optional),
297+
optional (optional)
298+
- For config_map: config_map_name, items (optional), default_mode
299+
(optional), optional (optional)
300+
- For empty_dir: medium (optional), size_limit (optional)
301+
- mount_path: Either a kubernetes.client.V1VolumeMount object or a string
302+
specifying the path where the volume should be mounted in the container
279303
algorithm_name: Search algorithm for the HyperParameter tuning.
280304
algorithm_settings: Settings for the search algorithm given.
281305
For available fields, check this doc:
@@ -468,6 +492,103 @@ class name in this argument.
468492
f"Incorrect value for env_per_trial: {env_per_trial}"
469493
)
470494

495+
volumes: List[client.V1Volume] = []
496+
volume_mounts: List[client.V1VolumeMount] = []
497+
if storage_per_trial:
498+
if isinstance(storage_per_trial, dict):
499+
storage_per_trial = [storage_per_trial]
500+
for storage in storage_per_trial:
501+
print(f"storage: {storage}")
502+
volume = None
503+
if isinstance(storage["volume"], client.V1Volume):
504+
volume = storage["volume"]
505+
elif isinstance(storage["volume"], dict):
506+
volume_name = storage["volume"].get("name")
507+
volume_type = storage["volume"].get("type")
508+
509+
if not volume_name:
510+
raise ValueError(
511+
"storage_per_trial['volume'] does not have a 'name' key"
512+
)
513+
if not volume_type:
514+
raise ValueError(
515+
"storage_per_trial['volume'] does not have a 'type' key"
516+
)
517+
518+
if volume_type == "pvc":
519+
volume_claim_name = storage["volume"].get("claim_name")
520+
if not volume_claim_name:
521+
raise ValueError(
522+
"storage_per_trial['volume'] should have a "
523+
"'claim_name' key for type pvc"
524+
)
525+
volume = client.V1Volume(
526+
name=volume_name,
527+
persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
528+
claim_name=volume_claim_name,
529+
read_only=storage["volume"].get("read_only", False),
530+
),
531+
)
532+
elif volume_type == "secret":
533+
volume = client.V1Volume(
534+
name=volume_name,
535+
secret=client.V1SecretVolumeSource(
536+
secret_name=storage["volume"].get("secret_name"),
537+
items=storage["volume"].get("items", None),
538+
default_mode=storage["volume"].get(
539+
"default_mode", None
540+
),
541+
optional=storage["volume"].get("optional", False),
542+
),
543+
)
544+
elif volume_type == "config_map":
545+
volume = client.V1Volume(
546+
name=volume_name,
547+
config_map=client.V1ConfigMapVolumeSource(
548+
name=storage["volume"].get("config_map_name"),
549+
items=storage["volume"].get("items", []),
550+
default_mode=storage["volume"].get(
551+
"default_mode", None
552+
),
553+
optional=storage["volume"].get("optional", False),
554+
),
555+
)
556+
elif volume_type == "empty_dir":
557+
volume = client.V1Volume(
558+
name=volume_name,
559+
empty_dir=client.V1EmptyDirVolumeSource(
560+
medium=storage["volume"].get("medium", None),
561+
size_limit=storage["volume"].get(
562+
"size_limit", None
563+
),
564+
),
565+
)
566+
else:
567+
raise ValueError(
568+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
569+
)
570+
571+
else:
572+
raise ValueError(
573+
"storage_per_trial['volume'] must be a client.V1Volume or a dict"
574+
)
575+
576+
volumes.append(volume)
577+
578+
if isinstance(storage["mount_path"], client.V1VolumeMount):
579+
volume_mounts.append(storage["mount_path"])
580+
elif isinstance(storage["mount_path"], str):
581+
volume_mounts.append(
582+
client.V1VolumeMount(
583+
name=volume_name, mount_path=storage["mount_path"]
584+
)
585+
)
586+
else:
587+
raise ValueError(
588+
"storage_per_trial['mount_path'] must be a "
589+
"client.V1VolumeMount or a str"
590+
)
591+
471592
# Create Trial specification.
472593
trial_spec = client.V1Job(
473594
api_version="batch/v1",
@@ -488,8 +609,12 @@ class name in this argument.
488609
env=env if env else None,
489610
env_from=env_from if env_from else None,
490611
resources=resources_per_trial,
612+
volume_mounts=(
613+
volume_mounts if volume_mounts else None
614+
),
491615
)
492616
],
617+
volumes=volumes if volumes else None,
493618
),
494619
)
495620
),
@@ -576,7 +701,7 @@ class name in this argument.
576701
f"It must also start and end with an alphanumeric character."
577702
)
578703
elif hasattr(e, "status") and e.status == 409:
579-
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
704+
print(f"PVC '{name}' already exists in namespace {namespace}.")
580705
else:
581706
raise RuntimeError(f"failed to create PVC. Error: {e}")
582707

0 commit comments

Comments
 (0)