Skip to content

Commit a2d2c5f

Browse files
committed
feat(sdk): support volume mount in tune API
1 parent 4d2a230 commit a2d2c5f

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import multiprocessing
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional, Union
19+
from typing import Any, Callable, Dict, List, Optional, Union, TypedDict
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -30,6 +30,10 @@
3030

3131
logger = logging.getLogger(__name__)
3232

33+
TuneStoragePerTrialType = TypedDict(
34+
"TuneStoragePerTrial",
35+
{"volume": client.V1Volume, "mount_path": str},
36+
)
3337

3438
class KatibClient(object):
3539
def __init__(
@@ -186,6 +190,7 @@ def tune(
186190
env_per_trial: Optional[
187191
Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]
188192
] = None,
193+
storage_per_trial: Optional[Dict[str, TuneStoragePerTrialType]] = None,
189194
algorithm_name: str = "random",
190195
algorithm_settings: Union[
191196
dict, List[models.V1beta1AlgorithmSetting], None
@@ -468,6 +473,19 @@ class name in this argument.
468473
f"Incorrect value for env_per_trial: {env_per_trial}"
469474
)
470475

476+
volumes: List[client.V1Volume] = []
477+
volume_mounts: List[client.V1VolumeMount] = []
478+
if storage_per_trial:
479+
for name, storage in storage_per_trial.items():
480+
volumes.append(storage["volume"])
481+
volume_mounts.append(
482+
client.V1VolumeMount(name=name, mount_path=storage["mount_path"]),
483+
)
484+
print('='*100)
485+
print("volumes", volumes)
486+
print("volume_mounts", volume_mounts)
487+
print('='*100)
488+
471489
# Create Trial specification.
472490
trial_spec = client.V1Job(
473491
api_version="batch/v1",
@@ -488,8 +506,10 @@ class name in this argument.
488506
env=env if env else None,
489507
env_from=env_from if env_from else None,
490508
resources=resources_per_trial,
509+
volume_mounts=volume_mounts if volume_mounts else None,
491510
)
492511
],
512+
volumes=volumes if volumes else None,
493513
),
494514
)
495515
),
@@ -576,7 +596,7 @@ class name in this argument.
576596
f"It must also start and end with an alphanumeric character."
577597
)
578598
elif hasattr(e, "status") and e.status == 409:
579-
print(f"PVC '{name}' already exists in namespace " f"{namespace}.")
599+
print(f"PVC '{name}' already exists in namespace {namespace}.")
580600
else:
581601
raise RuntimeError(f"failed to create PVC. Error: {e}")
582602

0 commit comments

Comments
 (0)