From 1d906c42f0d6e48ac0b3cc4b7d74741cc64658d2 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 15 Apr 2025 12:46:44 +0300 Subject: [PATCH 01/15] add only_on attribute to retry deco --- metaflow/plugins/retry_decorator.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index 00d40b3ac0c..4a56567ed57 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -2,6 +2,8 @@ from metaflow.exception import MetaflowException from metaflow.metaflow_config import MAX_ATTEMPTS +SUPPORTED_RETRY_EVENTS = ["all", "spot-termination"] + class RetryDecorator(StepDecorator): """ @@ -22,10 +24,13 @@ class RetryDecorator(StepDecorator): Number of times to retry this task. minutes_between_retries : int, default 2 Number of minutes between retries. + only_on : List[str], default None + List of failure events to retry on. Accepted values are + 'all', 'spot-termination' """ name = "retry" - defaults = {"times": "3", "minutes_between_retries": "2"} + defaults = {"times": "3", "minutes_between_retries": "2", "only_on": None} def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger): # The total number of attempts must not exceed MAX_ATTEMPTS. @@ -36,5 +41,20 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge "@retry(times=%d)." % (MAX_ATTEMPTS - 2) ) + if self.attributes["only_on"] is not None: + if not isinstance(self.attributes["only_on"], list): + raise MetaflowException("'only_on=' must be a list of values") + + unsupported_events = [ + event + for event in self.attributes["only_on"] + if event not in SUPPORTED_RETRY_EVENTS + ] + if unsupported_events: + raise MetaflowException( + "The event(s) %s are not supported for only_on=" + % ", ".join("*%s*" % event for event in unsupported_events) + ) + def step_task_retry_count(self): return int(self.attributes["times"]), 0 From 87f095acb46082a83afbf882c6841c051b4f739f Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 23 Apr 2025 00:33:19 +0300 Subject: [PATCH 02/15] wip: try out metadata based retry conditions for spot termination --- metaflow/plugins/kubernetes/kubernetes_cli.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 4645e25c34f..f438326f2a7 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -222,6 +222,28 @@ def echo(msg, stream="stderr", job_id=None, **kwargs): retry_deco[0].attributes.get("minutes_between_retries", 2) ) if retry_count: + retry_conditions = retry_deco[0].attributes["only_on"] + if retry_conditions: + print("retrying only on: %s" % retry_conditions) + # check if last failure reason matches the retry condition + # init the datastore for the previous known attempt so we can read metadata + previous_attempt_ds = ctx.obj.flow_datastore.get_task_datastore( + mode="r", + run_id=kwargs["run_id"], + step_name=step_name, + task_id=kwargs["task_id"], + attempt=int(retry_count) - 1, + ) + spot_termination = previous_attempt_ds.has_metadata( + "spot-termination-received-at" + ) + print("has spot termination: %s" % spot_termination) + if "spot-termination" in retry_conditions and not spot_termination: + ctx.obj.echo_always( + "Task failed due to an error that will not be retried." + ) + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + ctx.obj.echo_always( "Sleeping %d minutes before the next retry" % minutes_between_retries ) From 24d6350d2b9daaa9f13eae38a76b8cb6b92f450a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 23 Apr 2025 17:30:41 +0300 Subject: [PATCH 03/15] wip: handle spot termination via exit code in kubernetes --- metaflow/plugins/kubernetes/kubernetes.py | 7 ++++ metaflow/plugins/kubernetes/kubernetes_cli.py | 36 +++++++------------ 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 6625047395a..acfeeb83e98 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -72,6 +72,10 @@ class KubernetesKilledException(MetaflowException): headline = "Kubernetes Batch job killed" +class KubernetesSpotInstanceTerminated(MetaflowException): + headline = "Kubernetes node spot instance has been terminated" + + class Kubernetes(object): def __init__( self, @@ -764,6 +768,9 @@ def _has_updates(): ) if int(exit_code) == 134: raise KubernetesException("%s (exit code %s)" % (msg, exit_code)) + if int(exit_code) == 154: + # NOTE. K8S exit codes are mod 256 + raise KubernetesSpotInstanceTerminated() else: msg = "%s (exit code %s)" % (msg, exit_code) raise KubernetesException( diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index f438326f2a7..81bfb9958b2 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -11,7 +11,11 @@ import metaflow.tracing as tracing from metaflow import JSONTypeClass, util from metaflow._vendor import click -from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException +from metaflow.exception import ( + METAFLOW_EXIT_DISALLOW_RETRY, + METAFLOW_EXIT_ALLOW_RETRY, + MetaflowException, +) from metaflow.metadata_provider.util import sync_local_metadata_from_datastore from metaflow.metaflow_config import DATASTORE_LOCAL_DIR from metaflow.mflog import TASK_LOG_SOURCE @@ -21,6 +25,7 @@ Kubernetes, KubernetesException, KubernetesKilledException, + KubernetesSpotInstanceTerminated, ) @@ -221,29 +226,8 @@ def echo(msg, stream="stderr", job_id=None, **kwargs): minutes_between_retries = int( retry_deco[0].attributes.get("minutes_between_retries", 2) ) + retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else None if retry_count: - retry_conditions = retry_deco[0].attributes["only_on"] - if retry_conditions: - print("retrying only on: %s" % retry_conditions) - # check if last failure reason matches the retry condition - # init the datastore for the previous known attempt so we can read metadata - previous_attempt_ds = ctx.obj.flow_datastore.get_task_datastore( - mode="r", - run_id=kwargs["run_id"], - step_name=step_name, - task_id=kwargs["task_id"], - attempt=int(retry_count) - 1, - ) - spot_termination = previous_attempt_ds.has_metadata( - "spot-termination-received-at" - ) - print("has spot termination: %s" % spot_termination) - if "spot-termination" in retry_conditions and not spot_termination: - ctx.obj.echo_always( - "Task failed due to an error that will not be retried." - ) - sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) - ctx.obj.echo_always( "Sleeping %d minutes before the next retry" % minutes_between_retries ) @@ -352,6 +336,12 @@ def _sync_metadata(): # don't retry killed tasks traceback.print_exc() sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + except KubernetesSpotInstanceTerminated: + traceback.print_exc() + if retry_conditions is not None and "spot-termination" in retry_conditions: + sys.exit(METAFLOW_EXIT_ALLOW_RETRY) + else: + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) finally: _sync_metadata() From eb4840ddcc14a08f90c22c914f011eca34310d87 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 24 Apr 2025 01:16:47 +0300 Subject: [PATCH 04/15] rework spot_monitor_sidecar. add main process PID to env. add signal handlers to retry deco --- .../kubernetes/kubernetes_decorator.py | 2 + .../kubernetes/spot_monitor_sidecar.py | 42 +++++++++++++++---- metaflow/plugins/retry_decorator.py | 42 +++++++++++++++++++ 3 files changed, 77 insertions(+), 9 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 697d5c055b1..58c65f4fd22 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -549,6 +549,8 @@ def task_pre_step( self._save_logs_sidecar.start() # Start spot termination monitor sidecar. + # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. + os.environ["MF_MAIN_PID"] = str(os.getpid()) current._update_env( {"spot_termination_notice": "/tmp/spot_termination_notice"} ) diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py index 59f821f885e..1acb3dc7310 100644 --- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py +++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py @@ -20,6 +20,7 @@ def __init__(self): self._process = None self._token = None self._token_expiry = 0 + self.termination_time = None if self._is_aws_spot_instance(): self._process = Process(target=self._monitor_loop) @@ -64,17 +65,40 @@ def _is_aws_spot_instance(self): except (requests.exceptions.RequestException, requests.exceptions.Timeout): return False + def _try_sending_termination_signal(self): + # wait for 20 seconds before the promised termination time of the spot instance before sending a SIGALRM + if (datetime.now() + self.POLL_INTERVAL - 20) < self.termination_time: + return False + else: + os.kill(int(os.getenv("MF_MAIN_PID")), signal.SIGALRM) + return True + + def _monitor_spot_termination(self): + if self.termination_time is not None: + return self.termination_time + + try: + response = self._make_ec2_request(url=self.METADATA_URL, timeout=1) + if response.status_code == 200: + self.termination_time = response.text + self._emit_termination_metadata(self.termination_time) + # TODO: Verify. This doesn't actually work, as getppid() does not return the main Metaflow task process id due to sidecar nesting. + os.kill(os.getppid(), signal.SIGTERM) + except (requests.exceptions.RequestException, requests.exceptions.Timeout): + pass + + return self.termination_time + def _monitor_loop(self): while self.is_alive: - try: - response = self._make_ec2_request(url=self.METADATA_URL, timeout=1) - if response.status_code == 200: - termination_time = response.text - self._emit_termination_metadata(termination_time) - os.kill(os.getppid(), signal.SIGTERM) - break - except (requests.exceptions.RequestException, requests.exceptions.Timeout): - pass + terminates_at = self._monitor_spot_termination() + + sent_signal = False + if terminates_at is not None: + sent_signal = self._try_sending_termination_signal() + + if sent_signal: + break time.sleep(self.POLL_INTERVAL) def _emit_termination_metadata(self, termination_time): diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index 4a56567ed57..8e2e2f3e7c3 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -1,6 +1,11 @@ +import os +import signal +import sys + from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException from metaflow.metaflow_config import MAX_ATTEMPTS +from metaflow import current SUPPORTED_RETRY_EVENTS = ["all", "spot-termination"] @@ -56,5 +61,42 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge % ", ".join("*%s*" % event for event in unsupported_events) ) + def task_pre_step( + self, + step_name, + task_datastore, + metadata, + run_id, + task_id, + flow, + graph, + retry_count, + max_user_code_retries, + ubf_context, + inputs, + ): + # Bind signal handlers for user-code scope + self._old_alarm_signal_handler = signal.getsignal(signal.SIGALRM) + if "spot-termination" in self.attributes["only_on"]: + has_custom_signal_handler = ( + signal.getsignal(signal.SIGALRM) != signal.SIG_DFL + ) + + def _spot_handler(*args, **kwargs): + # call the custom signal handler first just in case it is of importance + if has_custom_signal_handler: + self._old_alarm_signal_handler(*args, **kwargs) + # custom exit code in case of Spot termination + if os.path.isfile(current.spot_termination_notice): + sys.exit(154) + + signal.signal(signal.SIGALRM, _spot_handler) + + def task_post_step( + self, step_name, flow, graph, retry_count, max_user_code_retries + ): + # Unbind the signal handlers as we are exiting user-code scope + signal.signal(signal.SIGALRM, self._old_alarm_signal_handler) + def step_task_retry_count(self): return int(self.attributes["times"]), 0 From a19e000cf2d28e9f8816245f292646483c6e897b Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 24 Apr 2025 10:33:49 +0300 Subject: [PATCH 05/15] default case fix --- metaflow/plugins/retry_decorator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index 8e2e2f3e7c3..82ae93129c9 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -77,7 +77,10 @@ def task_pre_step( ): # Bind signal handlers for user-code scope self._old_alarm_signal_handler = signal.getsignal(signal.SIGALRM) - if "spot-termination" in self.attributes["only_on"]: + if ( + self.attributes["only_on"] is not None + and "spot-termination" in self.attributes["only_on"] + ): has_custom_signal_handler = ( signal.getsignal(signal.SIGALRM) != signal.SIG_DFL ) From b5b03595b6114a6f33bbc232738abf6617d91b61 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 24 Apr 2025 11:23:06 +0300 Subject: [PATCH 06/15] cleanup main PID handling --- metaflow/plugins/aws/batch/batch_decorator.py | 2 ++ metaflow/plugins/kubernetes/spot_monitor_sidecar.py | 8 +++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 5c61d3f8f03..84f5cea9526 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -299,6 +299,8 @@ def task_pre_step( self._save_logs_sidecar.start() # Start spot termination monitor sidecar. + # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. + os.environ["MF_MAIN_PID"] = str(os.getpid()) current._update_env( {"spot_termination_notice": "/tmp/spot_termination_notice"} ) diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py index 1acb3dc7310..18b32bce130 100644 --- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py +++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py @@ -20,7 +20,10 @@ def __init__(self): self._process = None self._token = None self._token_expiry = 0 + self.termination_time = None + # Due to nesting, os.getppid is not reliable for fetching the main task pid + self.main_pid = int(os.getenv("MF_MAIN_PID", os.getppid())) if self._is_aws_spot_instance(): self._process = Process(target=self._monitor_loop) @@ -70,7 +73,7 @@ def _try_sending_termination_signal(self): if (datetime.now() + self.POLL_INTERVAL - 20) < self.termination_time: return False else: - os.kill(int(os.getenv("MF_MAIN_PID")), signal.SIGALRM) + os.kill(self.main_pid, signal.SIGALRM) return True def _monitor_spot_termination(self): @@ -82,8 +85,7 @@ def _monitor_spot_termination(self): if response.status_code == 200: self.termination_time = response.text self._emit_termination_metadata(self.termination_time) - # TODO: Verify. This doesn't actually work, as getppid() does not return the main Metaflow task process id due to sidecar nesting. - os.kill(os.getppid(), signal.SIGTERM) + os.kill(self.main_pid, signal.SIGTERM) except (requests.exceptions.RequestException, requests.exceptions.Timeout): pass From 7d0470cec849af1e61bfdb6818ca574f74e7e934 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 25 Apr 2025 18:20:44 +0300 Subject: [PATCH 07/15] fix termination time handling --- metaflow/plugins/kubernetes/spot_monitor_sidecar.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py index 18b32bce130..c527afd698c 100644 --- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py +++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py @@ -5,7 +5,7 @@ import requests import subprocess from multiprocessing import Process -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from metaflow.sidecar import MessageTypes @@ -69,8 +69,10 @@ def _is_aws_spot_instance(self): return False def _try_sending_termination_signal(self): - # wait for 20 seconds before the promised termination time of the spot instance before sending a SIGALRM - if (datetime.now() + self.POLL_INTERVAL - 20) < self.termination_time: + # wait for 100 seconds before the promised termination time of the spot instance before sending a SIGALRM + if ( + datetime.now(timezone.utc) + timedelta(0, self.POLL_INTERVAL + 100) + ) < datetime.strptime(self.termination_time, "%Y-%m-%dT%H:%M:%SZ"): return False else: os.kill(self.main_pid, signal.SIGALRM) From 77da5f0ce713a5b116f59263e8d08c86aed0c198 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 28 Apr 2025 21:21:09 +0300 Subject: [PATCH 08/15] use SIGUSR1 as the spot termination signal. revert other spot termination monitor sidecar changes --- .../kubernetes/spot_monitor_sidecar.py | 45 +++++-------------- metaflow/plugins/retry_decorator.py | 41 ++++++++--------- 2 files changed, 29 insertions(+), 57 deletions(-) diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py index c527afd698c..115b815a6e4 100644 --- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py +++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py @@ -5,7 +5,7 @@ import requests import subprocess from multiprocessing import Process -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone from metaflow.sidecar import MessageTypes @@ -21,7 +21,6 @@ def __init__(self): self._token = None self._token_expiry = 0 - self.termination_time = None # Due to nesting, os.getppid is not reliable for fetching the main task pid self.main_pid = int(os.getenv("MF_MAIN_PID", os.getppid())) @@ -68,41 +67,17 @@ def _is_aws_spot_instance(self): except (requests.exceptions.RequestException, requests.exceptions.Timeout): return False - def _try_sending_termination_signal(self): - # wait for 100 seconds before the promised termination time of the spot instance before sending a SIGALRM - if ( - datetime.now(timezone.utc) + timedelta(0, self.POLL_INTERVAL + 100) - ) < datetime.strptime(self.termination_time, "%Y-%m-%dT%H:%M:%SZ"): - return False - else: - os.kill(self.main_pid, signal.SIGALRM) - return True - - def _monitor_spot_termination(self): - if self.termination_time is not None: - return self.termination_time - - try: - response = self._make_ec2_request(url=self.METADATA_URL, timeout=1) - if response.status_code == 200: - self.termination_time = response.text - self._emit_termination_metadata(self.termination_time) - os.kill(self.main_pid, signal.SIGTERM) - except (requests.exceptions.RequestException, requests.exceptions.Timeout): - pass - - return self.termination_time - def _monitor_loop(self): while self.is_alive: - terminates_at = self._monitor_spot_termination() - - sent_signal = False - if terminates_at is not None: - sent_signal = self._try_sending_termination_signal() - - if sent_signal: - break + try: + response = self._make_ec2_request(url=self.METADATA_URL, timeout=1) + if response.status_code == 200: + termination_time = response.text + self._emit_termination_metadata(termination_time) + os.kill(self.main_pid, signal.SIGUSR1) + break + except (requests.exceptions.RequestException, requests.exceptions.Timeout): + pass time.sleep(self.POLL_INTERVAL) def _emit_termination_metadata(self, termination_time): diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index 82ae93129c9..9d11776d776 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -1,6 +1,8 @@ import os import signal import sys +import threading +from time import sleep from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException @@ -75,31 +77,26 @@ def task_pre_step( ubf_context, inputs, ): - # Bind signal handlers for user-code scope - self._old_alarm_signal_handler = signal.getsignal(signal.SIGALRM) - if ( - self.attributes["only_on"] is not None - and "spot-termination" in self.attributes["only_on"] - ): - has_custom_signal_handler = ( - signal.getsignal(signal.SIGALRM) != signal.SIG_DFL - ) + pid = os.getpid() + + def _termination_timer(): + sleep(30) + os.kill(pid, signal.SIGALRM) - def _spot_handler(*args, **kwargs): - # call the custom signal handler first just in case it is of importance - if has_custom_signal_handler: - self._old_alarm_signal_handler(*args, **kwargs) - # custom exit code in case of Spot termination - if os.path.isfile(current.spot_termination_notice): - sys.exit(154) + def _spot_term_signal_handler(*args, **kwargs): + if os.path.isfile(current.spot_termination_notice): + print( + "Spot instance termination detected. Starting a timer to end the Metaflow task." + ) + timer_thread = threading.Thread(target=_termination_timer, daemon=True) + timer_thread.start() - signal.signal(signal.SIGALRM, _spot_handler) + def _curtain_call(*args, **kwargs): + # custom exit code in case of Spot termination + sys.exit(154) - def task_post_step( - self, step_name, flow, graph, retry_count, max_user_code_retries - ): - # Unbind the signal handlers as we are exiting user-code scope - signal.signal(signal.SIGALRM, self._old_alarm_signal_handler) + signal.signal(signal.SIGUSR1, _spot_term_signal_handler) + signal.signal(signal.SIGALRM, _curtain_call) def step_task_retry_count(self): return int(self.attributes["times"]), 0 From 964479a1a57f6e4f7c89b0009ed6bfb62f07deeb Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 28 Apr 2025 23:26:03 +0300 Subject: [PATCH 09/15] disjoint retry cases for kubernetes --- metaflow/plugins/kubernetes/kubernetes.py | 2 +- metaflow/plugins/kubernetes/kubernetes_cli.py | 5 +++++ metaflow/plugins/retry_decorator.py | 8 +++++--- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index acfeeb83e98..18799bf730f 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -768,7 +768,7 @@ def _has_updates(): ) if int(exit_code) == 134: raise KubernetesException("%s (exit code %s)" % (msg, exit_code)) - if int(exit_code) == 154: + if int(exit_code) == 234: # NOTE. K8S exit codes are mod 256 raise KubernetesSpotInstanceTerminated() else: diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 81bfb9958b2..bbea9f6495b 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -342,6 +342,11 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_ALLOW_RETRY) else: sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + except KubernetesException: + if not retry_conditions or "step" in retry_conditions: + raise + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + finally: _sync_metadata() diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index 9d11776d776..f1489ef466b 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -9,7 +9,9 @@ from metaflow.metaflow_config import MAX_ATTEMPTS from metaflow import current -SUPPORTED_RETRY_EVENTS = ["all", "spot-termination"] +SUPPORTED_RETRY_EVENTS = ["step", "spot-termination"] + +PLATFORM_EVICTED_EXITCODE = 234 class RetryDecorator(StepDecorator): @@ -33,7 +35,7 @@ class RetryDecorator(StepDecorator): Number of minutes between retries. only_on : List[str], default None List of failure events to retry on. Accepted values are - 'all', 'spot-termination' + 'step', 'spot-termination' """ name = "retry" @@ -93,7 +95,7 @@ def _spot_term_signal_handler(*args, **kwargs): def _curtain_call(*args, **kwargs): # custom exit code in case of Spot termination - sys.exit(154) + sys.exit(PLATFORM_EVICTED_EXITCODE) signal.signal(signal.SIGUSR1, _spot_term_signal_handler) signal.signal(signal.SIGALRM, _curtain_call) From c32b7d98922dda693f898293bf61f4fd07442875 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Mon, 28 Apr 2025 23:56:09 +0300 Subject: [PATCH 10/15] add disjoint retry handling to batch --- metaflow/plugins/aws/batch/batch.py | 9 +++++++++ metaflow/plugins/aws/batch/batch_cli.py | 24 ++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py index 16ce9a06cef..c03ce507745 100644 --- a/metaflow/plugins/aws/batch/batch.py +++ b/metaflow/plugins/aws/batch/batch.py @@ -52,6 +52,10 @@ class BatchKilledException(MetaflowException): headline = "AWS Batch task killed" +class BatchSpotInstanceTerminated(MetaflowException): + headline = "Spot Instance has been terminated" + + class Batch(object): def __init__(self, metadata, environment): self.metadata = metadata @@ -482,6 +486,11 @@ def wait_for_launch(job, child_jobs): # to Amazon S3. if self.job.is_crashed: + + # Custom exception for spot instance terminations + if self.job.status_code == 234: + raise BatchSpotInstanceTerminated() + msg = next( msg for msg in [ diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index a2d2199e2e6..9a36acb0d55 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -6,12 +6,21 @@ from metaflow import util from metaflow import R -from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY +from metaflow.exception import ( + METAFLOW_EXIT_ALLOW_RETRY, + CommandException, + METAFLOW_EXIT_DISALLOW_RETRY, +) from metaflow.metadata_provider.util import sync_local_metadata_from_datastore from metaflow.metaflow_config import DATASTORE_LOCAL_DIR from metaflow.mflog import TASK_LOG_SOURCE from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK -from .batch import Batch, BatchKilledException +from .batch import ( + Batch, + BatchException, + BatchKilledException, + BatchSpotInstanceTerminated, +) @click.group() @@ -283,6 +292,7 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs): if split_vars: env.update(split_vars) + retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else None if retry_count: ctx.obj.echo_always( "Sleeping %d minutes before the next AWS Batch retry" @@ -356,5 +366,15 @@ def _sync_metadata(): # don't retry killed tasks traceback.print_exc() sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + except BatchSpotInstanceTerminated: + traceback.print_exc() + if retry_conditions is not None and "spot-termination" in retry_conditions: + sys.exit(METAFLOW_EXIT_ALLOW_RETRY) + else: + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) + except BatchException: + if not retry_conditions or "step" in retry_conditions: + raise + sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) finally: _sync_metadata() From fa5a890dfbfdf2d98bae9217083c6e318f888f0b Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 29 Apr 2025 00:11:07 +0300 Subject: [PATCH 11/15] rename spot-termination retry case --- metaflow/plugins/aws/batch/batch_cli.py | 2 +- metaflow/plugins/kubernetes/kubernetes_cli.py | 2 +- metaflow/plugins/retry_decorator.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index 9a36acb0d55..d0d27c83bf2 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -368,7 +368,7 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except BatchSpotInstanceTerminated: traceback.print_exc() - if retry_conditions is not None and "spot-termination" in retry_conditions: + if retry_conditions is not None and "instance-preemption" in retry_conditions: sys.exit(METAFLOW_EXIT_ALLOW_RETRY) else: sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index bbea9f6495b..01f141da73a 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -338,7 +338,7 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except KubernetesSpotInstanceTerminated: traceback.print_exc() - if retry_conditions is not None and "spot-termination" in retry_conditions: + if retry_conditions is not None and "instance-preemption" in retry_conditions: sys.exit(METAFLOW_EXIT_ALLOW_RETRY) else: sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index f1489ef466b..d1a8f4a6882 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -9,7 +9,7 @@ from metaflow.metaflow_config import MAX_ATTEMPTS from metaflow import current -SUPPORTED_RETRY_EVENTS = ["step", "spot-termination"] +SUPPORTED_RETRY_EVENTS = ["step", "instance-preemption"] PLATFORM_EVICTED_EXITCODE = 234 @@ -35,7 +35,7 @@ class RetryDecorator(StepDecorator): Number of minutes between retries. only_on : List[str], default None List of failure events to retry on. Accepted values are - 'step', 'spot-termination' + 'step', 'instance-preemption' """ name = "retry" From f426302cd11a39ce38de623d910411deac72d3ec Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 29 Apr 2025 00:44:16 +0300 Subject: [PATCH 12/15] use an enum for RetryEvents naming. add coverage for 'all' case by omitting only_on attribute --- metaflow/plugins/aws/batch/batch_cli.py | 7 ++++--- metaflow/plugins/kubernetes/kubernetes_cli.py | 7 ++++--- metaflow/plugins/retry_decorator.py | 21 ++++++++++++++----- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index d0d27c83bf2..96aff8d774f 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -21,6 +21,7 @@ BatchKilledException, BatchSpotInstanceTerminated, ) +from metaflow.plugins.retry_decorator import RetryEvents @click.group() @@ -292,7 +293,7 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs): if split_vars: env.update(split_vars) - retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else None + retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else [] if retry_count: ctx.obj.echo_always( "Sleeping %d minutes before the next AWS Batch retry" @@ -368,12 +369,12 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except BatchSpotInstanceTerminated: traceback.print_exc() - if retry_conditions is not None and "instance-preemption" in retry_conditions: + if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions: sys.exit(METAFLOW_EXIT_ALLOW_RETRY) else: sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except BatchException: - if not retry_conditions or "step" in retry_conditions: + if not retry_conditions or RetryEvents.STEP.value in retry_conditions: raise sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) finally: diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 01f141da73a..ae87224d0ff 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -27,6 +27,7 @@ KubernetesKilledException, KubernetesSpotInstanceTerminated, ) +from metaflow.plugins.retry_decorator import RetryEvents @click.group() @@ -226,7 +227,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs): minutes_between_retries = int( retry_deco[0].attributes.get("minutes_between_retries", 2) ) - retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else None + retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else [] if retry_count: ctx.obj.echo_always( "Sleeping %d minutes before the next retry" % minutes_between_retries @@ -338,12 +339,12 @@ def _sync_metadata(): sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except KubernetesSpotInstanceTerminated: traceback.print_exc() - if retry_conditions is not None and "instance-preemption" in retry_conditions: + if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions: sys.exit(METAFLOW_EXIT_ALLOW_RETRY) else: sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) except KubernetesException: - if not retry_conditions or "step" in retry_conditions: + if not retry_conditions or RetryEvents.STEP.value in retry_conditions: raise sys.exit(METAFLOW_EXIT_DISALLOW_RETRY) diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index d1a8f4a6882..ac5a7fd3801 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -1,3 +1,4 @@ +from enum import Enum import os import signal import sys @@ -9,7 +10,11 @@ from metaflow.metaflow_config import MAX_ATTEMPTS from metaflow import current -SUPPORTED_RETRY_EVENTS = ["step", "instance-preemption"] + +class RetryEvents(Enum): + STEP = "step" + PREEMPT = "instance-preemption" + PLATFORM_EVICTED_EXITCODE = 234 @@ -34,7 +39,8 @@ class RetryDecorator(StepDecorator): minutes_between_retries : int, default 2 Number of minutes between retries. only_on : List[str], default None - List of failure events to retry on. Accepted values are + List of failure events to retry on. Not specifying values will retry on all known events. + Accepted values are 'step', 'instance-preemption' """ @@ -54,10 +60,15 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge if not isinstance(self.attributes["only_on"], list): raise MetaflowException("'only_on=' must be a list of values") + def _known_event(event: str): + try: + RetryEvents(event) + return True + except ValueError: + return False + unsupported_events = [ - event - for event in self.attributes["only_on"] - if event not in SUPPORTED_RETRY_EVENTS + event for event in self.attributes["only_on"] if not _known_event(event) ] if unsupported_events: raise MetaflowException( From 04603da30d49f0508c31168ed9bb3abb1a134dde Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 29 Apr 2025 16:39:21 +0300 Subject: [PATCH 13/15] disjoint retry support for argo workflows --- metaflow/plugins/argo/argo_workflows.py | 28 +++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 1d7f511282e..fbe08fa39ff 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -54,6 +54,7 @@ from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet +from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.user_configs.config_options import ConfigInput from metaflow.util import ( @@ -1520,11 +1521,13 @@ def _container_templates(self): max_user_code_retries = 0 max_error_retries = 0 minutes_between_retries = "2" + retry_conditions = [] for decorator in node.decorators: if decorator.name == "retry": minutes_between_retries = decorator.attributes.get( "minutes_between_retries", minutes_between_retries ) + retry_conditions = decorator.attributes["only_on"] user_code_retries, error_retries = decorator.step_task_retry_count() max_user_code_retries = max(max_user_code_retries, user_code_retries) max_error_retries = max(max_error_retries, error_retries) @@ -1546,6 +1549,21 @@ def _container_templates(self): minutes_between_retries = int(minutes_between_retries) + # Translate RetryEvents to expressions for Argo + event_to_expr = { + RetryEvents.STEP: "asInt(lastRetry.exitCode) == 1", + RetryEvents.PREEMPT: "asInt(lastRetry.exitCode) == %s" + % PLATFORM_EVICTED_EXITCODE, + } + retry_expr = None + if retry_conditions: + retry_expressions = [ + expr + for event, expr in event_to_expr.items() + if event.value in retry_conditions + ] + retry_expr = "||".join(retry_expressions) + # Configure log capture. mflog_expr = export_mflog_env_vars( datastore_type=self.flow_datastore.TYPE, @@ -2137,6 +2155,7 @@ def _container_templates(self): .retry_strategy( times=total_retries, minutes_between_retries=minutes_between_retries, + expression=retry_expr, ) ) else: @@ -2156,6 +2175,7 @@ def _container_templates(self): .retry_strategy( times=total_retries, minutes_between_retries=minutes_between_retries, + expression=retry_expr, ) .metadata( ObjectMeta() @@ -3661,13 +3681,17 @@ def service_account_name(self, service_account_name): self.payload["serviceAccountName"] = service_account_name return self - def retry_strategy(self, times, minutes_between_retries): + def retry_strategy(self, times, minutes_between_retries, expression=None): if times > 0: self.payload["retryStrategy"] = { - "retryPolicy": "Always", "limit": times, "backoff": {"duration": "%sm" % minutes_between_retries}, } + if expression is None: + self.payload["retryStrategy"]["retryPolicy"] = "Always" + else: + self.payload["retryStrategy"]["expression"] = expression + return self def empty_dir_volume(self, name, medium=None, size_limit=None): From 9c72cffe4e682c381d240296ece109c1fff4af2e Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 30 Apr 2025 16:20:15 +0300 Subject: [PATCH 14/15] add conditional retry support to step functions --- metaflow/plugins/aws/batch/batch_client.py | 7 +++- .../aws/step_functions/step_functions.py | 32 ++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py index bf0f6a824e7..44c719e5439 100644 --- a/metaflow/plugins/aws/batch/batch_client.py +++ b/metaflow/plugins/aws/batch/batch_client.py @@ -4,6 +4,7 @@ import random import time import hashlib +from typing import Dict, List, Optional try: unicode @@ -630,8 +631,12 @@ def parameter(self, key, value): self.payload["parameters"][key] = str(value) return self - def attempts(self, attempts): + def attempts(self, attempts, evaluate_on_exit: Optional[List[Dict]] = None): self.payload["retryStrategy"]["attempts"] = attempts + if evaluate_on_exit is not None: + # required for specifying custom retry strategies + # ref: https://docs.aws.amazon.com/batch/latest/APIReference/API_EvaluateOnExit.html + self.payload["retryStrategy"]["evaluateOnExit"] = evaluate_on_exit return self diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index ccf22b4fd35..2553ef3e6b5 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -18,6 +18,7 @@ SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH, ) from metaflow.parameters import deploy_time_eval +from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents from metaflow.user_configs.config_options import ConfigInput from metaflow.util import dict_to_cli_options, to_pascalcase @@ -824,9 +825,38 @@ def _batch(self, node): batch_deco = [deco for deco in node.decorators if deco.name == "batch"][0] resources = {} resources.update(batch_deco.attributes) + # Resolve retry strategy. user_code_retries, total_retries = self._get_retries(node) + # retry conditions mapping + retry_deco = next( + (deco for deco in node.decorators if deco.name == "retry"), None + ) + retry_conditions = ( + retry_deco.attributes["only_on"] if retry_deco is not None else [] + ) + + # Translate RetryEvents to expressions for SFN + event_to_expr = { + RetryEvents.STEP: {"action": "RETRY", "onExitCode": "1"}, + RetryEvents.PREEMPT: { + "action": "RETRY", + "onExitCode": str(PLATFORM_EVICTED_EXITCODE), + }, + } + retry_expr = None + # NOTE: AWS only allows 5 distinct EvaluateOnExit conditions, so any more than this will require combining them. + if retry_conditions: + retry_expr = [ + expr + for event, expr in event_to_expr.items() + if event.value in retry_conditions + ] + # we need to append a catch-all exit condition, as for no matches the default behavior with Batch is to retry the job. + # retry conditions are only evaluated for non-zero exit codes, so the wildcard is fine here. + retry_expr.append({"action": "EXIT", "onExitCode": "*"}) + task_spec = { "flow_name": attrs["metaflow.flow_name"], "step_name": attrs["metaflow.step_name"], @@ -875,7 +905,7 @@ def _batch(self, node): log_driver=resources["log_driver"], log_options=resources["log_options"], ) - .attempts(total_retries + 1) + .attempts(attempts=total_retries + 1, evaluate_on_exit=retry_expr) ) def _get_retries(self, node): From 206b9601565001efc6d613b3ee2fa3140e595b2c Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Fri, 2 May 2025 14:37:02 +0300 Subject: [PATCH 15/15] move custom exit code and signal handlers away from retry decorator --- metaflow/plugins/argo/argo_workflows.py | 5 ++- metaflow/plugins/aws/batch/batch.py | 4 +- metaflow/plugins/aws/batch/batch_decorator.py | 27 +++++++++++- .../aws/step_functions/step_functions.py | 6 +-- metaflow/plugins/kubernetes/kubernetes.py | 4 +- .../kubernetes/kubernetes_decorator.py | 27 +++++++++++- metaflow/plugins/retry_decorator.py | 44 ------------------- 7 files changed, 64 insertions(+), 53 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index fbe08fa39ff..ebbe74b506d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -54,7 +54,8 @@ from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet -from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents +from metaflow.plugins.kubernetes.kubernetes import SPOT_INTERRUPT_EXITCODE +from metaflow.plugins.retry_decorator import RetryEvents from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK from metaflow.user_configs.config_options import ConfigInput from metaflow.util import ( @@ -1553,7 +1554,7 @@ def _container_templates(self): event_to_expr = { RetryEvents.STEP: "asInt(lastRetry.exitCode) == 1", RetryEvents.PREEMPT: "asInt(lastRetry.exitCode) == %s" - % PLATFORM_EVICTED_EXITCODE, + % SPOT_INTERRUPT_EXITCODE, } retry_expr = None if retry_conditions: diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py index c03ce507745..54422024fab 100644 --- a/metaflow/plugins/aws/batch/batch.py +++ b/metaflow/plugins/aws/batch/batch.py @@ -43,6 +43,8 @@ STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE) STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE) +SPOT_INTERRUPT_EXITCODE = 234 + class BatchException(MetaflowException): headline = "AWS Batch error" @@ -488,7 +490,7 @@ def wait_for_launch(job, child_jobs): if self.job.is_crashed: # Custom exception for spot instance terminations - if self.job.status_code == 234: + if self.job.status_code == SPOT_INTERRUPT_EXITCODE: raise BatchSpotInstanceTerminated() msg = next( diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 84f5cea9526..1f81dbc2ee6 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -1,6 +1,8 @@ import os import platform +import signal import sys +import threading import time from metaflow import R, current @@ -24,7 +26,7 @@ get_docker_registry, get_ec2_instance_metadata, ) -from .batch import BatchException +from .batch import SPOT_INTERRUPT_EXITCODE, BatchException class BatchDecorator(StepDecorator): @@ -298,6 +300,29 @@ def task_pre_step( self._save_logs_sidecar = Sidecar("save_logs_periodically") self._save_logs_sidecar.start() + # Set up signal handling for spot termination + main_pid = os.getpid() + + def _termination_timer(): + time.sleep(30) + os.kill(main_pid, signal.SIGALRM) + + def _spot_term_signal_handler(*args, **kwargs): + if os.path.isfile(current.spot_termination_notice): + print( + "Spot instance termination detected. Starting a timer to end the Metaflow task." + ) + timer_thread = threading.Thread( + target=_termination_timer, daemon=True + ) + timer_thread.start() + + def _curtain_call(*args, **kwargs): + # custom exit code in case of Spot termination + sys.exit(SPOT_INTERRUPT_EXITCODE) + + signal.signal(signal.SIGUSR1, _spot_term_signal_handler) + signal.signal(signal.SIGALRM, _curtain_call) # Start spot termination monitor sidecar. # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. os.environ["MF_MAIN_PID"] = str(os.getpid()) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 2553ef3e6b5..50c709dd643 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -18,11 +18,11 @@ SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH, ) from metaflow.parameters import deploy_time_eval -from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents +from metaflow.plugins.retry_decorator import RetryEvents from metaflow.user_configs.config_options import ConfigInput from metaflow.util import dict_to_cli_options, to_pascalcase -from ..batch.batch import Batch +from ..batch.batch import Batch, SPOT_INTERRUPT_EXITCODE from .event_bridge_client import EventBridgeClient from .step_functions_client import StepFunctionsClient @@ -842,7 +842,7 @@ def _batch(self, node): RetryEvents.STEP: {"action": "RETRY", "onExitCode": "1"}, RetryEvents.PREEMPT: { "action": "RETRY", - "onExitCode": str(PLATFORM_EVICTED_EXITCODE), + "onExitCode": str(SPOT_INTERRUPT_EXITCODE), }, } retry_expr = None diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 18799bf730f..06371c6f6f8 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -63,6 +63,8 @@ "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}" ) +SPOT_INTERRUPT_EXITCODE = 234 + class KubernetesException(MetaflowException): headline = "Kubernetes error" @@ -768,7 +770,7 @@ def _has_updates(): ) if int(exit_code) == 134: raise KubernetesException("%s (exit code %s)" % (msg, exit_code)) - if int(exit_code) == 234: + if int(exit_code) == SPOT_INTERRUPT_EXITCODE: # NOTE. K8S exit codes are mod 256 raise KubernetesSpotInstanceTerminated() else: diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 58c65f4fd22..fe209465c3f 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -1,7 +1,9 @@ import json import os import platform +import signal import sys +import threading import time from metaflow import current @@ -37,7 +39,7 @@ from metaflow.unbounded_foreach import UBF_CONTROL from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata -from .kubernetes import KubernetesException +from .kubernetes import KubernetesException, SPOT_INTERRUPT_EXITCODE from .kube_utils import validate_kube_labels, parse_kube_keyvalue_list try: @@ -548,6 +550,29 @@ def task_pre_step( self._save_logs_sidecar = Sidecar("save_logs_periodically") self._save_logs_sidecar.start() + # Set up signal handling for spot termination + main_pid = os.getpid() + + def _termination_timer(): + time.sleep(30) + os.kill(main_pid, signal.SIGALRM) + + def _spot_term_signal_handler(*args, **kwargs): + if os.path.isfile(current.spot_termination_notice): + print( + "Spot instance termination detected. Starting a timer to end the Metaflow task." + ) + timer_thread = threading.Thread( + target=_termination_timer, daemon=True + ) + timer_thread.start() + + def _curtain_call(*args, **kwargs): + # custom exit code in case of Spot termination + sys.exit(SPOT_INTERRUPT_EXITCODE) + + signal.signal(signal.SIGUSR1, _spot_term_signal_handler) + signal.signal(signal.SIGALRM, _curtain_call) # Start spot termination monitor sidecar. # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process. os.environ["MF_MAIN_PID"] = str(os.getpid()) diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py index ac5a7fd3801..c08d60b6379 100644 --- a/metaflow/plugins/retry_decorator.py +++ b/metaflow/plugins/retry_decorator.py @@ -1,14 +1,8 @@ from enum import Enum -import os -import signal -import sys -import threading -from time import sleep from metaflow.decorators import StepDecorator from metaflow.exception import MetaflowException from metaflow.metaflow_config import MAX_ATTEMPTS -from metaflow import current class RetryEvents(Enum): @@ -16,9 +10,6 @@ class RetryEvents(Enum): PREEMPT = "instance-preemption" -PLATFORM_EVICTED_EXITCODE = 234 - - class RetryDecorator(StepDecorator): """ Specifies the number of times the task corresponding @@ -76,40 +67,5 @@ def _known_event(event: str): % ", ".join("*%s*" % event for event in unsupported_events) ) - def task_pre_step( - self, - step_name, - task_datastore, - metadata, - run_id, - task_id, - flow, - graph, - retry_count, - max_user_code_retries, - ubf_context, - inputs, - ): - pid = os.getpid() - - def _termination_timer(): - sleep(30) - os.kill(pid, signal.SIGALRM) - - def _spot_term_signal_handler(*args, **kwargs): - if os.path.isfile(current.spot_termination_notice): - print( - "Spot instance termination detected. Starting a timer to end the Metaflow task." - ) - timer_thread = threading.Thread(target=_termination_timer, daemon=True) - timer_thread.start() - - def _curtain_call(*args, **kwargs): - # custom exit code in case of Spot termination - sys.exit(PLATFORM_EVICTED_EXITCODE) - - signal.signal(signal.SIGUSR1, _spot_term_signal_handler) - signal.signal(signal.SIGALRM, _curtain_call) - def step_task_retry_count(self): return int(self.attributes["times"]), 0