1616import  logging 
1717import  multiprocessing 
1818import  time 
19- from  typing  import  Any , Callable , Dict , List , Optional , Union 
19+ from  typing  import  Any , Callable , Dict , List , Optional , Union ,  TypedDict 
2020
2121import  grpc 
2222import  kubeflow .katib .katib_api_pb2  as  katib_api_pb2 
3030
3131logger  =  logging .getLogger (__name__ )
3232
33+ TuneStoragePerTrialType  =  TypedDict (
34+     "TuneStoragePerTrial" ,
35+     {
36+         "volume" : Union [client .V1Volume , Dict [str , Any ]],
37+         "mount_path" : Union [str , client .V1VolumeMount ],
38+     },
39+ )
40+ 
3341
3442class  KatibClient (object ):
3543    def  __init__ (
@@ -186,6 +194,7 @@ def tune(
186194        env_per_trial : Optional [
187195            Union [Dict [str , str ], List [Union [client .V1EnvVar , client .V1EnvFromSource ]]]
188196        ] =  None ,
197+         storage_per_trial : Optional [List [TuneStoragePerTrialType ]] =  None ,
189198        algorithm_name : str  =  "random" ,
190199        algorithm_settings : Union [
191200            dict , List [models .V1beta1AlgorithmSetting ], None 
@@ -276,6 +285,21 @@ class name in this argument.
276285                https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md) 
277286                or a kubernetes.client.models.V1EnvFromSource (documented here: 
278287                https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md) 
288+             storage_per_trial: List of storage configurations for each trial container. 
289+                 Each element in the list should be a dictionary with two keys: 
290+                 - volume: Either a kubernetes.client.V1Volume object or a dictionary 
291+                   containing volume configuration with required fields: 
292+                   - name: Name of the volume 
293+                   - type: One of "pvc", "secret", "config_map", or "empty_dir" 
294+                   Additional fields based on volume type: 
295+                   - For pvc: claim_name, read_only (optional) 
296+                   - For secret: secret_name, items (optional), default_mode (optional), 
297+                     optional (optional) 
298+                   - For config_map: config_map_name, items (optional), default_mode 
299+                     (optional), optional (optional) 
300+                   - For empty_dir: medium (optional), size_limit (optional) 
301+                 - mount_path: Either a kubernetes.client.V1VolumeMount object or a string 
302+                   specifying the path where the volume should be mounted in the container 
279303            algorithm_name: Search algorithm for the HyperParameter tuning. 
280304            algorithm_settings: Settings for the search algorithm given. 
281305                For available fields, check this doc: 
@@ -468,6 +492,101 @@ class name in this argument.
468492                            f"Incorrect value for env_per_trial: { env_per_trial }  " 
469493                        )
470494
495+             volumes : List [client .V1Volume ] =  []
496+             volume_mounts : List [client .V1VolumeMount ] =  []
497+             if  storage_per_trial :
498+                 if  isinstance (storage_per_trial , dict ):
499+                     storage_per_trial  =  [storage_per_trial ]
500+                 for  storage  in  storage_per_trial :
501+                     print (f"storage: { storage }  " )
502+                     volume  =  None 
503+                     if  isinstance (storage ["volume" ], client .V1Volume ):
504+                         volume  =  storage ["volume" ]
505+                     elif  isinstance (storage ["volume" ], dict ):
506+                         volume_name  =  storage ["volume" ].get ("name" )
507+                         volume_type  =  storage ["volume" ].get ("type" )
508+ 
509+                         if  not  volume_name :
510+                             raise  ValueError (
511+                                 "storage_per_trial['volume'] does not have a 'name' key" 
512+                             )
513+                         if  not  volume_type :
514+                             raise  ValueError (
515+                                 "storage_per_trial['volume'] does not have a 'type' key" 
516+                             )
517+ 
518+                         if  volume_type  ==  "pvc" :
519+                             volume_claim_name  =  storage ["volume" ].get ("claim_name" )
520+                             if  not  volume_claim_name :
521+                                 raise  ValueError (
522+                                     "storage_per_trial['volume'] should have a 'claim_name' key for type pvc" 
523+                                 )
524+                             volume  =  client .V1Volume (
525+                                 name = volume_name ,
526+                                 persistent_volume_claim = client .V1PersistentVolumeClaimVolumeSource (
527+                                     claim_name = volume_claim_name ,
528+                                     read_only = storage ["volume" ].get ("read_only" , False ),
529+                                 ),
530+                             )
531+                         elif  volume_type  ==  "secret" :
532+                             volume  =  client .V1Volume (
533+                                 name = volume_name ,
534+                                 secret = client .V1SecretVolumeSource (
535+                                     secret_name = storage ["volume" ].get ("secret_name" ),
536+                                     items = storage ["volume" ].get ("items" , None ),
537+                                     default_mode = storage ["volume" ].get (
538+                                         "default_mode" , None 
539+                                     ),
540+                                     optional = storage ["volume" ].get ("optional" , False ),
541+                                 ),
542+                             )
543+                         elif  volume_type  ==  "config_map" :
544+                             volume  =  client .V1Volume (
545+                                 name = volume_name ,
546+                                 config_map = client .V1ConfigMapVolumeSource (
547+                                     name = storage ["volume" ].get ("config_map_name" ),
548+                                     items = storage ["volume" ].get ("items" , []),
549+                                     default_mode = storage ["volume" ].get (
550+                                         "default_mode" , None 
551+                                     ),
552+                                     optional = storage ["volume" ].get ("optional" , False ),
553+                                 ),
554+                             )
555+                         elif  volume_type  ==  "empty_dir" :
556+                             volume  =  client .V1Volume (
557+                                 name = volume_name ,
558+                                 empty_dir = client .V1EmptyDirVolumeSource (
559+                                     medium = storage ["volume" ].get ("medium" , None ),
560+                                     size_limit = storage ["volume" ].get (
561+                                         "size_limit" , None 
562+                                     ),
563+                                 ),
564+                             )
565+                         else :
566+                             raise  ValueError (
567+                                 "storage_per_trial['volume'] must be a client.V1Volume or a dict" 
568+                             )
569+ 
570+                     else :
571+                         raise  ValueError (
572+                             "storage_per_trial['volume'] must be a client.V1Volume or a dict" 
573+                         )
574+ 
575+                     volumes .append (volume )
576+ 
577+                     if  isinstance (storage ["mount_path" ], client .V1VolumeMount ):
578+                         volume_mounts .append (storage ["mount_path" ])
579+                     elif  isinstance (storage ["mount_path" ], str ):
580+                         volume_mounts .append (
581+                             client .V1VolumeMount (
582+                                 name = volume_name , mount_path = storage ["mount_path" ]
583+                             )
584+                         )
585+                     else :
586+                         raise  ValueError (
587+                             "storage_per_trial['mount_path'] must be a client.V1VolumeMount or a str" 
588+                         )
589+ 
471590            # Create Trial specification. 
472591            trial_spec  =  client .V1Job (
473592                api_version = "batch/v1" ,
@@ -488,8 +607,12 @@ class name in this argument.
488607                                    env = env  if  env  else  None ,
489608                                    env_from = env_from  if  env_from  else  None ,
490609                                    resources = resources_per_trial ,
610+                                     volume_mounts = (
611+                                         volume_mounts  if  volume_mounts  else  None 
612+                                     ),
491613                                )
492614                            ],
615+                             volumes = volumes  if  volumes  else  None ,
493616                        ),
494617                    )
495618                ),
@@ -576,7 +699,7 @@ class name in this argument.
576699                        f"It must also start and end with an alphanumeric character." 
577700                    )
578701                elif  hasattr (e , "status" ) and  e .status  ==  409 :
579-                     print (f"PVC '{ name }  ' already exists in namespace "    f" { namespace }  ." )
702+                     print (f"PVC '{ name }  ' already exists in namespace { namespace }  ." )
580703                else :
581704                    raise  RuntimeError (f"failed to create PVC. Error: { e }  " )
582705
0 commit comments