1616import logging
1717import multiprocessing
1818import time
19- from typing import Any , Callable , Dict , List , Optional , Union
19+ from typing import Any , Callable , Dict , List , Optional , Union , TypedDict
2020
2121import grpc
2222import kubeflow .katib .katib_api_pb2 as katib_api_pb2
3030
3131logger = logging .getLogger (__name__ )
3232
33+ TuneStoragePerTrialType = TypedDict (
34+ "TuneStoragePerTrial" ,
35+ {"volume" : client .V1Volume , "mount_path" : str },
36+ )
3337
3438class KatibClient (object ):
3539 def __init__ (
@@ -186,6 +190,7 @@ def tune(
186190 env_per_trial : Optional [
187191 Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
188192 ] = None ,
193+ storage_per_trial : Optional [Dict [str , TuneStoragePerTrialType ]] = None ,
189194 algorithm_name : str = "random" ,
190195 algorithm_settings : Union [
191196 dict , List [models .V1beta1AlgorithmSetting ], None
@@ -468,6 +473,19 @@ class name in this argument.
468473 f"Incorrect value for env_per_trial: { env_per_trial } "
469474 )
470475
476+ volumes : List [client .V1Volume ] = []
477+ volume_mounts : List [client .V1VolumeMount ] = []
478+ if storage_per_trial :
479+ for name , storage in storage_per_trial .items ():
480+ volumes .append (storage ["volume" ])
481+ volume_mounts .append (
482+ client .V1VolumeMount (name = name , mount_path = storage ["mount_path" ]),
483+ )
484+ print ('=' * 100 )
485+ print ("volumes" , volumes )
486+ print ("volume_mounts" , volume_mounts )
487+ print ('=' * 100 )
488+
471489 # Create Trial specification.
472490 trial_spec = client .V1Job (
473491 api_version = "batch/v1" ,
@@ -488,8 +506,10 @@ class name in this argument.
488506 env = env if env else None ,
489507 env_from = env_from if env_from else None ,
490508 resources = resources_per_trial ,
509+ volume_mounts = volume_mounts if volume_mounts else None ,
491510 )
492511 ],
512+ volumes = volumes if volumes else None ,
493513 ),
494514 )
495515 ),
@@ -576,7 +596,7 @@ class name in this argument.
576596 f"It must also start and end with an alphanumeric character."
577597 )
578598 elif hasattr (e , "status" ) and e .status == 409 :
579- print (f"PVC '{ name } ' already exists in namespace " f" { namespace } ." )
599+ print (f"PVC '{ name } ' already exists in namespace { namespace } ." )
580600 else :
581601 raise RuntimeError (f"failed to create PVC. Error: { e } " )
582602
0 commit comments