diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py
index 1d7f511282e..ebbe74b506d 100644
--- a/metaflow/plugins/argo/argo_workflows.py
+++ b/metaflow/plugins/argo/argo_workflows.py
@@ -54,6 +54,8 @@
 from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits
 from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet
+from metaflow.plugins.kubernetes.kubernetes import SPOT_INTERRUPT_EXITCODE
+from metaflow.plugins.retry_decorator import RetryEvents
 from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
 from metaflow.user_configs.config_options import ConfigInput
 from metaflow.util import (
@@ -1520,11 +1522,13 @@ def _container_templates(self):
             max_user_code_retries = 0
             max_error_retries = 0
             minutes_between_retries = "2"
+            retry_conditions = []
             for decorator in node.decorators:
                 if decorator.name == "retry":
                     minutes_between_retries = decorator.attributes.get(
                         "minutes_between_retries", minutes_between_retries
                     )
+                    retry_conditions = decorator.attributes["only_on"]
                     user_code_retries, error_retries = decorator.step_task_retry_count()
                     max_user_code_retries = max(max_user_code_retries, user_code_retries)
                     max_error_retries = max(max_error_retries, error_retries)
@@ -1546,6 +1550,21 @@ def _container_templates(self):
 
             minutes_between_retries = int(minutes_between_retries)
 
+            # Translate RetryEvents to expressions for Argo
+            event_to_expr = {
+                RetryEvents.STEP: "asInt(lastRetry.exitCode) == 1",
+                RetryEvents.PREEMPT: "asInt(lastRetry.exitCode) == %s"
+                % SPOT_INTERRUPT_EXITCODE,
+            }
+            retry_expr = None
+            if retry_conditions:
+                retry_expressions = [
+                    expr
+                    for event, expr in event_to_expr.items()
+                    if event.value in retry_conditions
+                ]
+                retry_expr = "||".join(retry_expressions)
+
             # Configure log capture.
             mflog_expr = export_mflog_env_vars(
                 datastore_type=self.flow_datastore.TYPE,
@@ -2137,6 +2156,7 @@ def _container_templates(self):
                     .retry_strategy(
                         times=total_retries,
                         minutes_between_retries=minutes_between_retries,
+                        expression=retry_expr,
                     )
                 )
             else:
@@ -2156,6 +2176,7 @@ def _container_templates(self):
                 .retry_strategy(
                     times=total_retries,
                     minutes_between_retries=minutes_between_retries,
+                    expression=retry_expr,
                 )
                 .metadata(
                     ObjectMeta()
@@ -3661,13 +3682,17 @@ def service_account_name(self, service_account_name):
         self.payload["serviceAccountName"] = service_account_name
         return self
 
-    def retry_strategy(self, times, minutes_between_retries):
+    def retry_strategy(self, times, minutes_between_retries, expression=None):
         if times > 0:
             self.payload["retryStrategy"] = {
-                "retryPolicy": "Always",
                 "limit": times,
                 "backoff": {"duration": "%sm" % minutes_between_retries},
             }
+            if expression is None:
+                self.payload["retryStrategy"]["retryPolicy"] = "Always"
+            else:
+                self.payload["retryStrategy"]["expression"] = expression
+
         return self
 
     def empty_dir_volume(self, name, medium=None, size_limit=None):
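For orientation, here is roughly what the rendered Argo `retryStrategy` looks like for a step annotated with `@retry(times=3, only_on=["instance-preemption"])` — a minimal sketch inferred from the `retry_strategy()` method above, not output captured from a live deployment:

```python
# Sketch of the retryStrategy payload produced by retry_strategy() above
# for @retry(times=3, only_on=["instance-preemption"]); values illustrative.
retry_strategy = {
    "limit": 3,
    "backoff": {"duration": "2m"},
    # With only_on set, the blanket 'retryPolicy: Always' is replaced by an
    # expression that retries only when the last attempt exited with the
    # spot-interrupt sentinel code.
    "expression": "asInt(lastRetry.exitCode) == 234",
}
```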
diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py
index 16ce9a06cef..54422024fab 100644
--- a/metaflow/plugins/aws/batch/batch.py
+++ b/metaflow/plugins/aws/batch/batch.py
@@ -43,6 +43,8 @@
 STDOUT_PATH = os.path.join(LOGS_DIR, STDOUT_FILE)
 STDERR_PATH = os.path.join(LOGS_DIR, STDERR_FILE)
 
+SPOT_INTERRUPT_EXITCODE = 234
+
 
 class BatchException(MetaflowException):
     headline = "AWS Batch error"
@@ -52,6 +54,10 @@ class BatchKilledException(MetaflowException):
     headline = "AWS Batch task killed"
 
 
+class BatchSpotInstanceTerminated(MetaflowException):
+    headline = "Spot Instance has been terminated"
+
+
 class Batch(object):
     def __init__(self, metadata, environment):
         self.metadata = metadata
@@ -482,6 +488,11 @@ def wait_for_launch(job, child_jobs):
         # to Amazon S3.
         if self.job.is_crashed:
+
+            # Custom exception for spot instance terminations
+            if self.job.status_code == SPOT_INTERRUPT_EXITCODE:
+                raise BatchSpotInstanceTerminated()
+
             msg = next(
                 msg
                 for msg in [
diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py
index a2d2199e2e6..96aff8d774f 100644
--- a/metaflow/plugins/aws/batch/batch_cli.py
+++ b/metaflow/plugins/aws/batch/batch_cli.py
@@ -6,12 +6,22 @@
 from metaflow import util
 from metaflow import R
-from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY
+from metaflow.exception import (
+    METAFLOW_EXIT_ALLOW_RETRY,
+    CommandException,
+    METAFLOW_EXIT_DISALLOW_RETRY,
+)
 from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
 from metaflow.mflog import TASK_LOG_SOURCE
 from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
-from .batch import Batch, BatchKilledException
+from .batch import (
+    Batch,
+    BatchException,
+    BatchKilledException,
+    BatchSpotInstanceTerminated,
+)
+from metaflow.plugins.retry_decorator import RetryEvents
 
 
 @click.group()
@@ -283,6 +293,7 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
     if split_vars:
         env.update(split_vars)
 
+    retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else []
     if retry_count:
         ctx.obj.echo_always(
             "Sleeping %d minutes before the next AWS Batch retry"
@@ -356,5 +367,15 @@ def _sync_metadata():
         # don't retry killed tasks
         traceback.print_exc()
         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+    except BatchSpotInstanceTerminated:
+        traceback.print_exc()
+        if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions:
+            sys.exit(METAFLOW_EXIT_ALLOW_RETRY)
+        else:
+            sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+    except BatchException:
+        if not retry_conditions or RetryEvents.STEP.value in retry_conditions:
+            raise
+        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
     finally:
         _sync_metadata()
diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py
index bf0f6a824e7..44c719e5439 100644
--- a/metaflow/plugins/aws/batch/batch_client.py
+++ b/metaflow/plugins/aws/batch/batch_client.py
@@ -4,6 +4,7 @@
 import random
 import time
 import hashlib
+from typing import Dict, List, Optional
 
 try:
     unicode
@@ -630,8 +631,12 @@ def parameter(self, key, value):
         self.payload["parameters"][key] = str(value)
         return self
 
-    def attempts(self, attempts):
+    def attempts(self, attempts, evaluate_on_exit: Optional[List[Dict]] = None):
         self.payload["retryStrategy"]["attempts"] = attempts
+        if evaluate_on_exit is not None:
+            # required for specifying custom retry strategies
+            # ref: https://docs.aws.amazon.com/batch/latest/APIReference/API_EvaluateOnExit.html
+            self.payload["retryStrategy"]["evaluateOnExit"] = evaluate_on_exit
         return self
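The CLI branches above reduce to one rule: retry a failure event unless `only_on` is set and excludes it. A minimal distillation of that branch logic, assuming the `RetryEvents` values introduced by this PR (illustrative helper, not part of the change):

```python
# Distilled decision logic of the exception handlers above (illustrative).
def should_retry(event: str, retry_conditions: list) -> bool:
    # An unset/empty only_on means "retry on every known event".
    return not retry_conditions or event in retry_conditions

assert should_retry("instance-preemption", [])
assert not should_retry("instance-preemption", ["step"])
assert not should_retry("step", ["instance-preemption"])
```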
diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py
index 5c61d3f8f03..1f81dbc2ee6 100644
--- a/metaflow/plugins/aws/batch/batch_decorator.py
+++ b/metaflow/plugins/aws/batch/batch_decorator.py
@@ -1,6 +1,8 @@
 import os
 import platform
+import signal
 import sys
+import threading
 import time
 
 from metaflow import R, current
@@ -24,7 +26,7 @@
     get_docker_registry,
     get_ec2_instance_metadata,
 )
-from .batch import BatchException
+from .batch import SPOT_INTERRUPT_EXITCODE, BatchException
 
 
 class BatchDecorator(StepDecorator):
@@ -298,7 +300,32 @@ def task_pre_step(
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
 
+        # Set up signal handling for spot termination
+        main_pid = os.getpid()
+
+        def _termination_timer():
+            time.sleep(30)
+            os.kill(main_pid, signal.SIGALRM)
+
+        def _spot_term_signal_handler(*args, **kwargs):
+            if os.path.isfile(current.spot_termination_notice):
+                print(
+                    "Spot instance termination detected. Starting a timer to end the Metaflow task."
+                )
+                timer_thread = threading.Thread(
+                    target=_termination_timer, daemon=True
+                )
+                timer_thread.start()
+
+        def _curtain_call(*args, **kwargs):
+            # custom exit code in case of Spot termination
+            sys.exit(SPOT_INTERRUPT_EXITCODE)
+
+        signal.signal(signal.SIGUSR1, _spot_term_signal_handler)
+        signal.signal(signal.SIGALRM, _curtain_call)
         # Start spot termination monitor sidecar.
+        # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
+        os.environ["MF_MAIN_PID"] = str(os.getpid())
         current._update_env(
             {"spot_termination_notice": "/tmp/spot_termination_notice"}
         )
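Read in isolation, the signal choreography above is: the sidecar sends SIGUSR1, the handler verifies the termination-notice file and arms a 30-second timer, and the timer fires SIGALRM, which exits with the spot sentinel code. A self-contained sketch of the same flow, outside of Metaflow, assuming a notice file written by some monitor process:

```python
import os
import signal
import sys
import threading
import time

SPOT_INTERRUPT_EXITCODE = 234
NOTICE = "/tmp/spot_termination_notice"  # assumed: written by a monitor process

def _termination_timer():
    # Grace period for the task to wind down before forcing an exit.
    time.sleep(30)
    os.kill(os.getpid(), signal.SIGALRM)

def _on_sigusr1(*args):
    # Sent by the spot monitor once a termination notice is detected.
    if os.path.isfile(NOTICE):
        threading.Thread(target=_termination_timer, daemon=True).start()

def _on_sigalrm(*args):
    # Exit with the sentinel code so orchestrators can classify the failure.
    sys.exit(SPOT_INTERRUPT_EXITCODE)

signal.signal(signal.SIGUSR1, _on_sigusr1)
signal.signal(signal.SIGALRM, _on_sigalrm)
```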
diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py
index ccf22b4fd35..50c709dd643 100644
--- a/metaflow/plugins/aws/step_functions/step_functions.py
+++ b/metaflow/plugins/aws/step_functions/step_functions.py
@@ -18,10 +18,11 @@
     SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
+from metaflow.plugins.retry_decorator import RetryEvents
 from metaflow.user_configs.config_options import ConfigInput
 from metaflow.util import dict_to_cli_options, to_pascalcase
 
-from ..batch.batch import Batch
+from ..batch.batch import Batch, SPOT_INTERRUPT_EXITCODE
 from .event_bridge_client import EventBridgeClient
 from .step_functions_client import StepFunctionsClient
 
@@ -824,9 +825,38 @@ def _batch(self, node):
         batch_deco = [deco for deco in node.decorators if deco.name == "batch"][0]
         resources = {}
         resources.update(batch_deco.attributes)
+
         # Resolve retry strategy.
         user_code_retries, total_retries = self._get_retries(node)
 
+        # retry conditions mapping
+        retry_deco = next(
+            (deco for deco in node.decorators if deco.name == "retry"), None
+        )
+        retry_conditions = (
+            retry_deco.attributes["only_on"] if retry_deco is not None else []
+        )
+
+        # Translate RetryEvents to expressions for SFN
+        event_to_expr = {
+            RetryEvents.STEP: {"action": "RETRY", "onExitCode": "1"},
+            RetryEvents.PREEMPT: {
+                "action": "RETRY",
+                "onExitCode": str(SPOT_INTERRUPT_EXITCODE),
+            },
+        }
+        retry_expr = None
+        # NOTE: AWS only allows 5 distinct EvaluateOnExit conditions; any more than that will require combining them.
+        if retry_conditions:
+            retry_expr = [
+                expr
+                for event, expr in event_to_expr.items()
+                if event.value in retry_conditions
+            ]
+            # Append a catch-all exit condition: when no condition matches, Batch's default behavior is to retry the job.
+            # Retry conditions are only evaluated for non-zero exit codes, so the wildcard is safe here.
+            retry_expr.append({"action": "EXIT", "onExitCode": "*"})
+
         task_spec = {
             "flow_name": attrs["metaflow.flow_name"],
             "step_name": attrs["metaflow.step_name"],
@@ -875,7 +905,7 @@ def _batch(self, node):
                 log_driver=resources["log_driver"],
                 log_options=resources["log_options"],
             )
-            .attempts(total_retries + 1)
+            .attempts(attempts=total_retries + 1, evaluate_on_exit=retry_expr)
         )
 
     def _get_retries(self, node):
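For the same `@retry(times=3, only_on=["instance-preemption"])` example, the translation above should yield a Batch retry strategy along these lines — a sketch assembled from the mapping and the `attempts()` call, not captured output:

```python
# Illustrative retryStrategy submitted to AWS Batch by Step Functions for
# @retry(times=3, only_on=["instance-preemption"]).
retry_strategy = {
    "attempts": 4,  # total_retries + 1
    "evaluateOnExit": [
        {"action": "RETRY", "onExitCode": "234"},  # spot interruption: retry
        {"action": "EXIT", "onExitCode": "*"},  # catch-all: fail fast
    ],
}
```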
diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py
index 6625047395a..06371c6f6f8 100644
--- a/metaflow/plugins/kubernetes/kubernetes.py
+++ b/metaflow/plugins/kubernetes/kubernetes.py
@@ -63,6 +63,8 @@
     "{METAFLOW_PARALLEL_STEP_CLI_OPTIONS_TEMPLATE}"
 )
 
+SPOT_INTERRUPT_EXITCODE = 234
+
 
 class KubernetesException(MetaflowException):
     headline = "Kubernetes error"
@@ -72,6 +74,10 @@ class KubernetesKilledException(MetaflowException):
     headline = "Kubernetes Batch job killed"
 
 
+class KubernetesSpotInstanceTerminated(MetaflowException):
+    headline = "Kubernetes node spot instance has been terminated"
+
+
 class Kubernetes(object):
     def __init__(
         self,
@@ -764,6 +770,9 @@ def _has_updates():
                 )
                 if int(exit_code) == 134:
                     raise KubernetesException("%s (exit code %s)" % (msg, exit_code))
+                if int(exit_code) == SPOT_INTERRUPT_EXITCODE:
+                    # NOTE: Kubernetes reports exit codes mod 256.
+                    raise KubernetesSpotInstanceTerminated()
             else:
                 msg = "%s (exit code %s)" % (msg, exit_code)
             raise KubernetesException(
diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py
index 4645e25c34f..ae87224d0ff 100644
--- a/metaflow/plugins/kubernetes/kubernetes_cli.py
+++ b/metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -11,7 +11,11 @@
 import metaflow.tracing as tracing
 from metaflow import JSONTypeClass, util
 from metaflow._vendor import click
-from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException
+from metaflow.exception import (
+    METAFLOW_EXIT_DISALLOW_RETRY,
+    METAFLOW_EXIT_ALLOW_RETRY,
+    MetaflowException,
+)
 from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
 from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
 from metaflow.mflog import TASK_LOG_SOURCE
@@ -21,7 +25,9 @@
     Kubernetes,
     KubernetesException,
     KubernetesKilledException,
+    KubernetesSpotInstanceTerminated,
 )
+from metaflow.plugins.retry_decorator import RetryEvents
 
 
 @click.group()
@@ -221,6 +227,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
     minutes_between_retries = int(
         retry_deco[0].attributes.get("minutes_between_retries", 2)
     )
+    retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else []
     if retry_count:
         ctx.obj.echo_always(
             "Sleeping %d minutes before the next retry" % minutes_between_retries
@@ -330,6 +337,17 @@ def _sync_metadata():
         # don't retry killed tasks
         traceback.print_exc()
         sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+    except KubernetesSpotInstanceTerminated:
+        traceback.print_exc()
+        if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions:
+            sys.exit(METAFLOW_EXIT_ALLOW_RETRY)
+        else:
+            sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+    except KubernetesException:
+        if not retry_conditions or RetryEvents.STEP.value in retry_conditions:
+            raise
+        sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
+
     finally:
         _sync_metadata()
diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py
index 697d5c055b1..fe209465c3f 100644
--- a/metaflow/plugins/kubernetes/kubernetes_decorator.py
+++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py
@@ -1,7 +1,9 @@
 import json
 import os
 import platform
+import signal
 import sys
+import threading
 import time
 
 from metaflow import current
@@ -37,7 +39,7 @@
 from metaflow.unbounded_foreach import UBF_CONTROL
 
 from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
-from .kubernetes import KubernetesException
+from .kubernetes import KubernetesException, SPOT_INTERRUPT_EXITCODE
 from .kube_utils import validate_kube_labels, parse_kube_keyvalue_list
 
 try:
@@ -548,7 +550,32 @@ def task_pre_step(
         self._save_logs_sidecar = Sidecar("save_logs_periodically")
         self._save_logs_sidecar.start()
 
+        # Set up signal handling for spot termination
+        main_pid = os.getpid()
+
+        def _termination_timer():
+            time.sleep(30)
+            os.kill(main_pid, signal.SIGALRM)
+
+        def _spot_term_signal_handler(*args, **kwargs):
+            if os.path.isfile(current.spot_termination_notice):
+                print(
+                    "Spot instance termination detected. Starting a timer to end the Metaflow task."
+                )
+                timer_thread = threading.Thread(
+                    target=_termination_timer, daemon=True
+                )
+                timer_thread.start()
+
+        def _curtain_call(*args, **kwargs):
+            # custom exit code in case of Spot termination
+            sys.exit(SPOT_INTERRUPT_EXITCODE)
+
+        signal.signal(signal.SIGUSR1, _spot_term_signal_handler)
+        signal.signal(signal.SIGALRM, _curtain_call)
         # Start spot termination monitor sidecar.
+        # TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
+        os.environ["MF_MAIN_PID"] = str(os.getpid())
         current._update_env(
             {"spot_termination_notice": "/tmp/spot_termination_notice"}
         )
diff --git a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py
index 59f821f885e..115b815a6e4 100644
--- a/metaflow/plugins/kubernetes/spot_monitor_sidecar.py
+++ b/metaflow/plugins/kubernetes/spot_monitor_sidecar.py
@@ -21,6 +21,9 @@ def __init__(self):
         self._token = None
         self._token_expiry = 0
 
+        # Due to nesting, os.getppid() is not reliable for fetching the main task pid.
+        self.main_pid = int(os.getenv("MF_MAIN_PID", os.getppid()))
+
         if self._is_aws_spot_instance():
             self._process = Process(target=self._monitor_loop)
             self._process.start()
@@ -71,7 +74,7 @@ def _monitor_loop(self):
                 if response.status_code == 200:
                     termination_time = response.text
                     self._emit_termination_metadata(termination_time)
-                    os.kill(os.getppid(), signal.SIGTERM)
+                    os.kill(self.main_pid, signal.SIGUSR1)
                     break
             except (requests.exceptions.RequestException, requests.exceptions.Timeout):
                 pass
diff --git a/metaflow/plugins/retry_decorator.py b/metaflow/plugins/retry_decorator.py
index 00d40b3ac0c..c08d60b6379 100644
--- a/metaflow/plugins/retry_decorator.py
+++ b/metaflow/plugins/retry_decorator.py
@@ -1,8 +1,15 @@
+from enum import Enum
+
 from metaflow.decorators import StepDecorator
 from metaflow.exception import MetaflowException
 from metaflow.metaflow_config import MAX_ATTEMPTS
 
 
+class RetryEvents(Enum):
+    STEP = "step"
+    PREEMPT = "instance-preemption"
+
+
 class RetryDecorator(StepDecorator):
     """
     Specifies the number of times the task corresponding
@@ -22,10 +29,14 @@ class RetryDecorator(StepDecorator):
         Number of times to retry this task.
     minutes_between_retries : int, default 2
         Number of minutes between retries.
+    only_on : List[str], default None
+        List of failure events to retry on. If no values are specified, the
+        task is retried on all known failure events. Accepted values are
+        'step' and 'instance-preemption'.
     """
 
     name = "retry"
-    defaults = {"times": "3", "minutes_between_retries": "2"}
+    defaults = {"times": "3", "minutes_between_retries": "2", "only_on": None}
 
     def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
         # The total number of attempts must not exceed MAX_ATTEMPTS.
@@ -36,5 +47,25 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logger):
                 "@retry(times=%d)." % (MAX_ATTEMPTS - 2)
             )
 
+        if self.attributes["only_on"] is not None:
+            if not isinstance(self.attributes["only_on"], list):
+                raise MetaflowException("'only_on=' must be a list of values")
+
+            def _known_event(event: str):
+                try:
+                    RetryEvents(event)
+                    return True
+                except ValueError:
+                    return False
+
+            unsupported_events = [
+                event for event in self.attributes["only_on"] if not _known_event(event)
+            ]
+            if unsupported_events:
+                raise MetaflowException(
+                    "The event(s) %s are not supported for only_on="
+                    % ", ".join("*%s*" % event for event in unsupported_events)
+                )
+
     def step_task_retry_count(self):
         return int(self.attributes["times"]), 0
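End to end, the feature is opted into at the step level. A sketch of a flow that retries only on spot preemptions — flow, step, and queue names are made up for illustration:

```python
from metaflow import FlowSpec, batch, retry, step


class SpotFlow(FlowSpec):
    @retry(times=3, only_on=["instance-preemption"])
    @batch(queue="spot-queue")  # hypothetical spot-backed job queue
    @step
    def start(self):
        # A spot interruption exits with code 234 and is retried; an ordinary
        # step failure (exit code 1) now fails the run without retrying.
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    SpotFlow()
```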