@@ -69,16 +69,23 @@ class K8sIntegration(Worker):
69
69
'echo "ldconfig" >> /etc/profile' ,
70
70
"/usr/sbin/sshd -p {port}" ]
71
71
72
- CONTAINER_BASH_SCRIPT = [
72
+ _CONTAINER_APT_SCRIPT_SECTION = [
73
73
"export DEBIAN_FRONTEND='noninteractive'" ,
74
74
"echo 'Binary::apt::APT::Keep-Downloaded-Packages \" true\" ;' > /etc/apt/apt.conf.d/docker-clean" ,
75
75
"chown -R root /root/.cache/pip" ,
76
76
"apt-get update" ,
77
77
"apt-get install -y git libsm6 libxext6 libxrender-dev libglib2.0-0" ,
78
+ ]
79
+
80
+ CONTAINER_BASH_SCRIPT = [
81
+ * (
82
+ '[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || {}' .format (line )
83
+ for line in _CONTAINER_APT_SCRIPT_SECTION
84
+ ),
78
85
"declare LOCAL_PYTHON" ,
79
86
"[ ! -z $LOCAL_PYTHON ] || for i in {{15..5}}; do which python3.$i && python3.$i -m pip --version && "
80
87
"export LOCAL_PYTHON=$(which python3.$i) && break ; done" ,
81
- " [ ! -z $LOCAL_PYTHON ] || apt-get install -y python3-pip" ,
88
+ '[ ! -z "$CLEARML_AGENT_SKIP_CONTAINER_APT" ] || [ ! -z " $LOCAL_PYTHON" ] || apt-get install -y python3-pip' ,
82
89
"[ ! -z $LOCAL_PYTHON ] || export LOCAL_PYTHON=python3" ,
83
90
"{extra_bash_init_cmd}" ,
84
91
"[ ! -z $CLEARML_AGENT_NO_UPDATE ] || $LOCAL_PYTHON -m pip install clearml-agent{agent_install_args}" ,
@@ -100,6 +107,7 @@ def __init__(
100
107
num_of_services = 20 ,
101
108
base_pod_num = 1 ,
102
109
user_props_cb = None ,
110
+ runtime_cb = None ,
103
111
overrides_yaml = None ,
104
112
template_yaml = None ,
105
113
clearml_conf_file = None ,
@@ -127,6 +135,7 @@ def __init__(
127
135
:param callable user_props_cb: An Optional callable allowing additional user properties to be specified
128
136
when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
129
137
a dictionary of user properties (name and value). Signature is [[Optional[int]], Dict[str,str]]
138
+ :param callable runtime_cb: An Optional callable allowing additional task runtime to be specified (see user_props_cb)
130
139
:param str overrides_yaml: YAML file containing the overrides for the pod (optional)
131
140
:param str template_yaml: YAML file containing the template for the pod (optional).
132
141
If provided the pod is scheduled with kubectl apply and overrides are ignored, otherwise with kubectl run.
@@ -161,6 +170,7 @@ def __init__(
161
170
self .base_pod_num = base_pod_num
162
171
self ._edit_hyperparams_support = None
163
172
self ._user_props_cb = user_props_cb
173
+ self ._runtime_cb = runtime_cb
164
174
self .conf_file_content = None
165
175
self .overrides_json_string = None
166
176
self .template_dict = None
@@ -198,6 +208,10 @@ def __init__(
198
208
self ._session .feature_set != "basic" and self ._session .check_min_server_version ("3.22.3" )
199
209
)
200
210
211
+ @property
212
+ def agent_label (self ):
213
+ return self ._get_agent_label ()
214
+
201
215
def _create_daemon_instance (self , cls_ , ** kwargs ):
202
216
return cls_ (agent = self , ** kwargs )
203
217
@@ -430,6 +444,9 @@ def resource_applied(self, resource_name: str, namespace: str, task_id: str, ses
430
444
""" Called when a resource (pod/job) was applied """
431
445
pass
432
446
447
+ def ports_mode_supported_for_task (self , task_id : str , task_data ):
448
+ return self .ports_mode
449
+
433
450
def run_one_task (self , queue : Text , task_id : Text , worker_args = None , task_session = None , ** _ ):
434
451
print ('Pulling task {} launching on kubernetes cluster' .format (task_id ))
435
452
session = task_session or self ._session
@@ -501,8 +518,10 @@ def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_sessio
501
518
)
502
519
)
503
520
504
- if self .ports_mode :
521
+ ports_mode = False
522
+ if self .ports_mode_supported_for_task (task_id , task_data ):
505
523
print ("Kubernetes looking for available pod to use" )
524
+ ports_mode = True
506
525
507
526
# noinspection PyBroadException
508
527
try :
@@ -513,12 +532,12 @@ def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_sessio
513
532
# Search for a free pod number
514
533
pod_count = 0
515
534
pod_number = self .base_pod_num
516
- while self . ports_mode or self .max_pods_limit :
535
+ while ports_mode or self .max_pods_limit :
517
536
pod_number = self .base_pod_num + pod_count
518
537
519
538
try :
520
539
items_count = self ._get_pod_count (
521
- extra_labels = [self .limit_pod_label .format (pod_number = pod_number )] if self . ports_mode else None ,
540
+ extra_labels = [self .limit_pod_label .format (pod_number = pod_number )] if ports_mode else None ,
522
541
msg = "Looking for a free pod/port"
523
542
)
524
543
except GetPodCountError :
@@ -568,11 +587,11 @@ def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_sessio
568
587
break
569
588
pod_count += 1
570
589
571
- labels = self ._get_pod_labels (queue , queue_name )
572
- if self . ports_mode :
590
+ labels = self ._get_pod_labels (queue , queue_name , task_data )
591
+ if ports_mode :
573
592
labels .append (self .limit_pod_label .format (pod_number = pod_number ))
574
593
575
- if self . ports_mode :
594
+ if ports_mode :
576
595
print ("Kubernetes scheduling task id={} on pod={} (pod_count={})" .format (task_id , pod_number , pod_count ))
577
596
else :
578
597
print ("Kubernetes scheduling task id={}" .format (task_id ))
@@ -611,40 +630,95 @@ def run_one_task(self, queue: Text, task_id: Text, worker_args=None, task_sessio
611
630
send_log = "Running kubectl encountered an error: {}" .format (error )
612
631
self .log .error (send_log )
613
632
self .send_logs (task_id , send_log .splitlines ())
633
+
634
+ # Make sure to remove the task from our k8s pending queue
635
+ self ._session .api_client .queues .remove_task (
636
+ task = task_id ,
637
+ queue = self .k8s_pending_queue_id ,
638
+ )
639
+ # Set task as failed
640
+ session .api_client .tasks .failed (task_id , force = True )
614
641
return
615
642
616
643
if pod_name :
617
644
self .resource_applied (
618
645
resource_name = pod_name , namespace = namespace , task_id = task_id , session = session
619
646
)
620
647
648
+ self .set_task_info (
649
+ task_id = task_id , task_session = task_session , queue_name = queue_name , ports_mode = ports_mode ,
650
+ pod_number = pod_number , pod_count = pod_count , task_data = task_data
651
+ )
652
+
653
+ def set_task_info (
654
+ self , task_id : str , task_session , task_data , queue_name : str , ports_mode : bool , pod_number , pod_count
655
+ ):
621
656
user_props = {"k8s-queue" : str (queue_name )}
622
- if self . ports_mode :
623
- user_props . update (
624
- {
625
- "k8s-pod-number" : pod_number ,
626
- "k8s-pod-label " : labels [ 0 ] ,
627
- "k8s-internal- pod-count " : pod_count ,
628
- "k8s-agent " : self . _get_agent_label () ,
629
- }
630
- )
657
+ runtime = {}
658
+ if ports_mode :
659
+ agent_label = self . _get_agent_label ()
660
+ user_props . update ({
661
+ "k8s-pod-number " : pod_number ,
662
+ "k8s-pod-label " : agent_label , # backwards-compatibility / legacy
663
+ "k8s-internal-pod-count " : pod_count ,
664
+ "k8s-agent" : agent_label ,
665
+ } )
631
666
632
667
if self ._user_props_cb :
633
668
# noinspection PyBroadException
634
669
try :
635
- custom_props = self ._user_props_cb (pod_number ) if self . ports_mode else self ._user_props_cb ()
670
+ custom_props = self ._user_props_cb (pod_number ) if ports_mode else self ._user_props_cb ()
636
671
user_props .update (custom_props )
637
672
except Exception :
638
673
pass
639
674
675
+ if self ._runtime_cb :
676
+ # noinspection PyBroadException
677
+ try :
678
+ custom_runtime = self ._runtime_cb (pod_number ) if ports_mode else self ._runtime_cb ()
679
+ runtime .update (custom_runtime )
680
+ except Exception :
681
+ pass
682
+
640
683
if user_props :
641
684
self ._set_task_user_properties (
642
685
task_id = task_id ,
643
686
task_session = task_session ,
644
687
** user_props
645
688
)
646
689
647
- def _get_pod_labels (self , queue , queue_name ):
690
+ if runtime :
691
+ task_runtime = self ._get_task_runtime (task_id ) or {}
692
+ task_runtime .update (runtime )
693
+
694
+ try :
695
+ res = task_session .send_request (
696
+ service = 'tasks' , action = 'edit' , method = Request .def_method ,
697
+ json = {
698
+ "task" : task_id , "force" : True , "runtime" : task_runtime
699
+ },
700
+ )
701
+ if not res .ok :
702
+ raise Exception ("failed setting runtime property" )
703
+ except Exception as ex :
704
+ print ("WARNING: failed setting custom runtime properties for task '{}': {}" .format (task_id , ex ))
705
+
706
+ def _get_task_runtime (self , task_id ) -> Optional [dict ]:
707
+ try :
708
+ res = self ._session .send_request (
709
+ service = 'tasks' , action = 'get_by_id' , method = Request .def_method ,
710
+ json = {"task" : task_id , "only_fields" : ["runtime" ]},
711
+ )
712
+ if not res .ok :
713
+ raise ValueError (f"request returned { res .status_code } " )
714
+ data = res .json ().get ("data" )
715
+ if not data or "task" not in data :
716
+ raise ValueError ("empty data in result" )
717
+ return data ["task" ].get ("runtime" , {})
718
+ except Exception as ex :
719
+ print (f"ERROR: Failed getting runtime properties for task { task_id } : { ex } " )
720
+
721
+ def _get_pod_labels (self , queue , queue_name , task_data ):
648
722
return [
649
723
self ._get_agent_label (),
650
724
"{}={}" .format (self .QUEUE_LABEL , self ._safe_k8s_label_value (queue )),
@@ -1012,6 +1086,9 @@ def _cleanup_old_pods(self, namespaces, extra_msg=None):
1012
1086
1013
1087
return deleted_pods
1014
1088
1089
+ def check_if_suspended (self ) -> bool :
1090
+ pass
1091
+
1015
1092
def run_tasks_loop (self , queues : List [Text ], worker_params , ** kwargs ):
1016
1093
"""
1017
1094
:summary: Pull and run tasks from queues.
@@ -1061,6 +1138,11 @@ def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):
1061
1138
# delete old completed / failed pods
1062
1139
self ._cleanup_old_pods (namespaces , extra_msg = "Cleanup cycle {cmd}" )
1063
1140
1141
+ if self .check_if_suspended ():
1142
+ print ("Agent is suspended, sleeping for {:.1f} seconds" .format (self ._polling_interval ))
1143
+ sleep (self ._polling_interval )
1144
+ break
1145
+
1064
1146
# get next task in queue
1065
1147
try :
1066
1148
# print(f"debug> getting tasks for queue {queue}")
0 commit comments