16
16
import logging
17
17
import multiprocessing
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional , Union
19
+ from typing import Any , Callable , Dict , List , Optional , TypedDict , Union
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
30
30
31
31
logger = logging .getLogger (__name__ )
32
32
33
class TuneStoragePerTrialType(TypedDict):
    """One storage entry accepted by the ``storage_per_trial`` argument.

    Declared with class syntax so that the TypedDict's type name matches the
    name it is bound to (static type checkers reject a functional-form
    TypedDict whose first argument differs from the assigned variable name).
    """

    # Either a ready-made kubernetes.client.V1Volume, or a plain dict that
    # describes the volume (for example its "name" and "type").
    volume: Union[client.V1Volume, Dict[str, Any]]
    # Either the mount path inside the container as a string, or a ready-made
    # kubernetes.client.V1VolumeMount.
    mount_path: Union[str, client.V1VolumeMount]
40
+
33
41
34
42
class KatibClient (object ):
35
43
def __init__ (
@@ -186,6 +194,7 @@ def tune(
186
194
env_per_trial : Optional [
187
195
Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
188
196
] = None ,
197
+ storage_per_trial : Optional [List [TuneStoragePerTrialType ]] = None ,
189
198
algorithm_name : str = "random" ,
190
199
algorithm_settings : Union [
191
200
dict , List [models .V1beta1AlgorithmSetting ], None
@@ -276,6 +285,21 @@ class name in this argument.
276
285
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
277
286
or a kubernetes.client.models.V1EnvFromSource (documented here:
278
287
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
288
+ storage_per_trial: List of storage configurations for each trial container.
289
+ Each element in the list should be a dictionary with two keys:
290
+ - volume: Either a kubernetes.client.V1Volume object or a dictionary
291
+ containing volume configuration with required fields:
292
+ - name: Name of the volume
293
+ - type: One of "pvc", "secret", "config_map", or "empty_dir"
294
+ Additional fields based on volume type:
295
+ - For pvc: claim_name, read_only (optional)
296
+ - For secret: secret_name, items (optional), default_mode (optional),
297
+ optional (optional)
298
+ - For config_map: config_map_name, items (optional), default_mode
299
+ (optional), optional (optional)
300
+ - For empty_dir: medium (optional), size_limit (optional)
301
+ - mount_path: Either a kubernetes.client.V1VolumeMount object or a string
302
+ specifying the path where the volume should be mounted in the container
279
303
algorithm_name: Search algorithm for the HyperParameter tuning.
280
304
algorithm_settings: Settings for the search algorithm given.
281
305
For available fields, check this doc:
@@ -468,6 +492,103 @@ class name in this argument.
468
492
f"Incorrect value for env_per_trial: { env_per_trial } "
469
493
)
470
494
495
# Build the Kubernetes volumes and the matching container volume mounts
# described by storage_per_trial. Both lists stay empty when no storage is
# requested.
volumes: List[client.V1Volume] = []
volume_mounts: List[client.V1VolumeMount] = []
if storage_per_trial:
    # A single storage config is accepted as shorthand for a one-item list.
    if isinstance(storage_per_trial, dict):
        storage_per_trial = [storage_per_trial]
    for storage in storage_per_trial:
        volume_config = storage["volume"]
        if isinstance(volume_config, client.V1Volume):
            volume = volume_config
        elif isinstance(volume_config, dict):
            volume_name = volume_config.get("name")
            volume_type = volume_config.get("type")

            if not volume_name:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'name' key"
                )
            if not volume_type:
                raise ValueError(
                    "storage_per_trial['volume'] does not have a 'type' key"
                )

            if volume_type == "pvc":
                volume_claim_name = volume_config.get("claim_name")
                if not volume_claim_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'claim_name' key for type pvc"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
                        claim_name=volume_claim_name,
                        read_only=volume_config.get("read_only", False),
                    ),
                )
            elif volume_type == "secret":
                # Fail early, like the pvc branch does for claim_name: a
                # secret volume without a secret name cannot be mounted.
                secret_name = volume_config.get("secret_name")
                if not secret_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'secret_name' key for type secret"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    secret=client.V1SecretVolumeSource(
                        secret_name=secret_name,
                        items=volume_config.get("items", None),
                        default_mode=volume_config.get("default_mode", None),
                        optional=volume_config.get("optional", False),
                    ),
                )
            elif volume_type == "config_map":
                # Same early validation for the ConfigMap reference.
                config_map_name = volume_config.get("config_map_name")
                if not config_map_name:
                    raise ValueError(
                        "storage_per_trial['volume'] should have a "
                        "'config_map_name' key for type config_map"
                    )
                volume = client.V1Volume(
                    name=volume_name,
                    config_map=client.V1ConfigMapVolumeSource(
                        name=config_map_name,
                        items=volume_config.get("items", []),
                        default_mode=volume_config.get("default_mode", None),
                        optional=volume_config.get("optional", False),
                    ),
                )
            elif volume_type == "empty_dir":
                volume = client.V1Volume(
                    name=volume_name,
                    empty_dir=client.V1EmptyDirVolumeSource(
                        medium=volume_config.get("medium", None),
                        size_limit=volume_config.get("size_limit", None),
                    ),
                )
            else:
                # The value is a dict, so the problem is the unsupported
                # type, not the kind of object that was passed.
                raise ValueError(
                    "storage_per_trial['volume']['type'] must be one of "
                    "'pvc', 'secret', 'config_map' or 'empty_dir', "
                    f"got: {volume_type}"
                )
        else:
            raise ValueError(
                "storage_per_trial['volume'] must be a client.V1Volume or a dict"
            )

        volumes.append(volume)

        mount = storage["mount_path"]
        if isinstance(mount, client.V1VolumeMount):
            volume_mounts.append(mount)
        elif isinstance(mount, str):
            # Take the name from the resolved volume object: when the caller
            # passed a ready-made V1Volume there is no dict-derived name, and
            # a name left over from a previous loop iteration would bind the
            # mount to the wrong volume.
            volume_mounts.append(
                client.V1VolumeMount(name=volume.name, mount_path=mount)
            )
        else:
            raise ValueError(
                "storage_per_trial['mount_path'] must be a "
                "client.V1VolumeMount or a str"
            )
591
+
471
592
# Create Trial specification.
472
593
trial_spec = client .V1Job (
473
594
api_version = "batch/v1" ,
@@ -488,8 +609,12 @@ class name in this argument.
488
609
env = env if env else None ,
489
610
env_from = env_from if env_from else None ,
490
611
resources = resources_per_trial ,
612
+ volume_mounts = (
613
+ volume_mounts if volume_mounts else None
614
+ ),
491
615
)
492
616
],
617
+ volumes = volumes if volumes else None ,
493
618
),
494
619
)
495
620
),
@@ -576,7 +701,7 @@ class name in this argument.
576
701
f"It must also start and end with an alphanumeric character."
577
702
)
578
703
elif hasattr (e , "status" ) and e .status == 409 :
579
- print (f"PVC '{ name } ' already exists in namespace " f" { namespace } ." )
704
+ print (f"PVC '{ name } ' already exists in namespace { namespace } ." )
580
705
else :
581
706
raise RuntimeError (f"failed to create PVC. Error: { e } " )
582
707
0 commit comments