16
16
import logging
17
17
import multiprocessing
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional , Union
19
+ from typing import Any , Callable , Dict , List , Optional , Union , TypedDict
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
30
30
31
31
logger = logging .getLogger (__name__ )
32
32
33
+ TuneStoragePerTrialType = TypedDict (
34
+ "TuneStoragePerTrial" ,
35
+ {"volume" : client .V1Volume , "mount_path" : str },
36
+ )
33
37
34
38
class KatibClient (object ):
35
39
def __init__ (
@@ -186,6 +190,7 @@ def tune(
186
190
env_per_trial : Optional [
187
191
Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
188
192
] = None ,
193
+ storage_per_trial : Optional [Dict [str , TuneStoragePerTrialType ]] = None ,
189
194
algorithm_name : str = "random" ,
190
195
algorithm_settings : Union [
191
196
dict , List [models .V1beta1AlgorithmSetting ], None
@@ -468,6 +473,19 @@ class name in this argument.
468
473
f"Incorrect value for env_per_trial: { env_per_trial } "
469
474
)
470
475
476
+ volumes : List [client .V1Volume ] = []
477
+ volume_mounts : List [client .V1VolumeMount ] = []
478
+ if storage_per_trial :
479
+ for name , storage in storage_per_trial .items ():
480
+ volumes .append (storage ["volume" ])
481
+ volume_mounts .append (
482
+ client .V1VolumeMount (name = name , mount_path = storage ["mount_path" ]),
483
+ )
484
+ print ('=' * 100 )
485
+ print ("volumes" , volumes )
486
+ print ("volume_mounts" , volume_mounts )
487
+ print ('=' * 100 )
488
+
471
489
# Create Trial specification.
472
490
trial_spec = client .V1Job (
473
491
api_version = "batch/v1" ,
@@ -488,8 +506,10 @@ class name in this argument.
488
506
env = env if env else None ,
489
507
env_from = env_from if env_from else None ,
490
508
resources = resources_per_trial ,
509
+ volume_mounts = volume_mounts if volume_mounts else None ,
491
510
)
492
511
],
512
+ volumes = volumes if volumes else None ,
493
513
),
494
514
)
495
515
),
@@ -576,7 +596,7 @@ class name in this argument.
576
596
f"It must also start and end with an alphanumeric character."
577
597
)
578
598
elif hasattr (e , "status" ) and e .status == 409 :
579
- print (f"PVC '{ name } ' already exists in namespace " f" { namespace } ." )
599
+ print (f"PVC '{ name } ' already exists in namespace { namespace } ." )
580
600
else :
581
601
raise RuntimeError (f"failed to create PVC. Error: { e } " )
582
602
0 commit comments