16
16
import logging
17
17
import multiprocessing
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional , Union
19
+ from typing import Any , Callable , Dict , List , Optional , Union , TypedDict
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
30
30
31
31
logger = logging .getLogger (__name__ )
32
32
33
# Shape of one entry of the `storage_per_trial` argument accepted by `tune`.
#   volume:     a ready-made client.V1Volume, or a plain dict describing the
#               volume (e.g. its "name" and "type" plus type-specific fields).
#   mount_path: a ready-made client.V1VolumeMount, or the container path (str)
#               at which the volume should be mounted.
# The typename passed to TypedDict matches the variable it is bound to, as
# static type checkers require for the functional TypedDict syntax.
TuneStoragePerTrialType = TypedDict(
    "TuneStoragePerTrialType",
    {
        "volume": Union[client.V1Volume, Dict[str, Any]],
        "mount_path": Union[str, client.V1VolumeMount],
    },
)
33
37
34
38
class KatibClient (object ):
35
39
def __init__ (
@@ -186,6 +190,7 @@ def tune(
186
190
env_per_trial : Optional [
187
191
Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
188
192
] = None ,
193
+ storage_per_trial : Optional [List [TuneStoragePerTrialType ]] = None ,
189
194
algorithm_name : str = "random" ,
190
195
algorithm_settings : Union [
191
196
dict , List [models .V1beta1AlgorithmSetting ], None
@@ -276,6 +281,21 @@ class name in this argument.
276
281
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
277
282
or a kubernetes.client.models.V1EnvFromSource (documented here:
278
283
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
284
+ storage_per_trial: List of storage configurations for each trial container.
285
+ Each element in the list should be a dictionary with two keys:
286
+ - volume: Either a kubernetes.client.V1Volume object or a dictionary
287
+ containing volume configuration with required fields:
288
+ - name: Name of the volume
289
+ - type: One of "pvc", "secret", "config_map", or "empty_dir"
290
+ Additional fields based on volume type:
291
+ - For pvc: claim_name, read_only (optional)
292
+ - For secret: secret_name, items (optional), default_mode (optional),
293
+ optional (optional)
294
+ - For config_map: config_map_name, items (optional), default_mode
295
+ (optional), optional (optional)
296
+ - For empty_dir: medium (optional), size_limit (optional)
297
+ - mount_path: Either a kubernetes.client.V1VolumeMount object or a string
298
+ specifying the path where the volume should be mounted in the container
279
299
algorithm_name: Search algorithm for the HyperParameter tuning.
280
300
algorithm_settings: Settings for the search algorithm given.
281
301
For available fields, check this doc:
@@ -468,6 +488,79 @@ class name in this argument.
468
488
f"Incorrect value for env_per_trial: { env_per_trial } "
469
489
)
470
490
491
# Translate `storage_per_trial` into the Kubernetes volumes and volume mounts
# that are later attached to the Trial pod spec and its container.
volumes: List[client.V1Volume] = []
volume_mounts: List[client.V1VolumeMount] = []
if storage_per_trial:
    # Accept a single storage dictionary as shorthand for a one-element list.
    if isinstance(storage_per_trial, dict):
        storage_per_trial = [storage_per_trial]

    for storage in storage_per_trial:
        if not isinstance(storage, dict):
            raise ValueError(
                "Each storage_per_trial element must be a dict with "
                "'volume' and 'mount_path' keys"
            )

        # Use .get() so a missing key is reported by the ValueError branches
        # below instead of surfacing as a bare KeyError.
        volume_spec = storage.get("volume")
        mount_spec = storage.get("mount_path")

        if isinstance(volume_spec, client.V1Volume):
            volume = volume_spec
            # Needed when mount_path is a plain string: the mount must
            # reference the volume by its name.
            volume_name = volume.name
        elif isinstance(volume_spec, dict):
            volume_name = volume_spec.get("name")
            volume_type = volume_spec.get("type")

            if not volume_name:
                raise ValueError("storage_per_trial['volume'] does not have a 'name' key")
            if not volume_type:
                raise ValueError("storage_per_trial['volume'] does not have a 'type' key")

            if volume_type == "pvc":
                volume_claim_name = volume_spec.get("claim_name")
                if not volume_claim_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'claim_name' key for type pvc"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                        claim_name=volume_claim_name,
                        read_only=volume_spec.get("read_only", False),
                    ),
                )
            elif volume_type == "secret":
                secret_name = volume_spec.get("secret_name")
                if not secret_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'secret_name' key "
                        "for type secret"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    secret=client.V1SecretVolumeSource(
                        secret_name=secret_name,
                        items=volume_spec.get("items"),
                        default_mode=volume_spec.get("default_mode"),
                        optional=volume_spec.get("optional", False),
                    ),
                )
            elif volume_type == "config_map":
                config_map_name = volume_spec.get("config_map_name")
                if not config_map_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a 'config_map_name' key "
                        "for type config_map"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    config_map=client.V1ConfigMapVolumeSource(
                        name=config_map_name,
                        items=volume_spec.get("items"),
                        default_mode=volume_spec.get("default_mode"),
                        optional=volume_spec.get("optional", False),
                    ),
                )
            elif volume_type == "empty_dir":
                volume = client.V1Volume(
                    name=volume_name,
                    empty_dir=client.V1EmptyDirVolumeSource(
                        medium=volume_spec.get("medium"),
                        size_limit=volume_spec.get("size_limit"),
                    ),
                )
            else:
                raise ValueError(
                    f"storage_per_trial['volume'] has unsupported type '{volume_type}'; "
                    "it must be one of 'pvc', 'secret', 'config_map' or 'empty_dir'"
                )
        else:
            raise ValueError("storage_per_trial['volume'] must be a client.V1Volume or a dict")

        # Resolve the mount before recording the volume so that an invalid
        # entry never leaves a volume without a matching mount.
        if isinstance(mount_spec, client.V1VolumeMount):
            volume_mount = mount_spec
        elif isinstance(mount_spec, str):
            volume_mount = client.V1VolumeMount(name=volume_name, mount_path=mount_spec)
        else:
            raise ValueError(
                "storage_per_trial['mount_path'] must be a client.V1VolumeMount or a str"
            )

        volumes.append(volume)
        volume_mounts.append(volume_mount)
563
+
471
564
# Create Trial specification.
472
565
trial_spec = client .V1Job (
473
566
api_version = "batch/v1" ,
@@ -488,8 +581,10 @@ class name in this argument.
488
581
env = env if env else None ,
489
582
env_from = env_from if env_from else None ,
490
583
resources = resources_per_trial ,
584
+ volume_mounts = volume_mounts if volume_mounts else None ,
491
585
)
492
586
],
587
+ volumes = volumes if volumes else None ,
493
588
),
494
589
)
495
590
),
@@ -576,7 +671,7 @@ class name in this argument.
576
671
f"It must also start and end with an alphanumeric character."
577
672
)
578
673
elif hasattr (e , "status" ) and e .status == 409 :
579
- print (f"PVC '{ name } ' already exists in namespace " f" { namespace } ." )
674
+ print (f"PVC '{ name } ' already exists in namespace { namespace } ." )
580
675
else :
581
676
raise RuntimeError (f"failed to create PVC. Error: { e } " )
582
677
0 commit comments