Skip to content

feature: retry decorator improvements #2390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from metaflow.plugins.kubernetes.kube_utils import qos_requests_and_limits

from metaflow.plugins.kubernetes.kubernetes_jobsets import KubernetesArgoJobSet
from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from metaflow.user_configs.config_options import ConfigInput
from metaflow.util import (
Expand Down Expand Up @@ -1520,11 +1521,13 @@ def _container_templates(self):
max_user_code_retries = 0
max_error_retries = 0
minutes_between_retries = "2"
retry_conditions = []
for decorator in node.decorators:
if decorator.name == "retry":
minutes_between_retries = decorator.attributes.get(
"minutes_between_retries", minutes_between_retries
)
retry_conditions = decorator.attributes["only_on"]
user_code_retries, error_retries = decorator.step_task_retry_count()
max_user_code_retries = max(max_user_code_retries, user_code_retries)
max_error_retries = max(max_error_retries, error_retries)
Expand All @@ -1546,6 +1549,21 @@ def _container_templates(self):

minutes_between_retries = int(minutes_between_retries)

# Translate RetryEvents to expressions for Argo
event_to_expr = {
RetryEvents.STEP: "asInt(lastRetry.exitCode) == 1",
RetryEvents.PREEMPT: "asInt(lastRetry.exitCode) == %s"
% PLATFORM_EVICTED_EXITCODE,
}
retry_expr = None
if retry_conditions:
retry_expressions = [
expr
for event, expr in event_to_expr.items()
if event.value in retry_conditions
]
retry_expr = "||".join(retry_expressions)

# Configure log capture.
mflog_expr = export_mflog_env_vars(
datastore_type=self.flow_datastore.TYPE,
Expand Down Expand Up @@ -2137,6 +2155,7 @@ def _container_templates(self):
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
expression=retry_expr,
)
)
else:
Expand All @@ -2156,6 +2175,7 @@ def _container_templates(self):
.retry_strategy(
times=total_retries,
minutes_between_retries=minutes_between_retries,
expression=retry_expr,
)
.metadata(
ObjectMeta()
Expand Down Expand Up @@ -3661,13 +3681,17 @@ def service_account_name(self, service_account_name):
self.payload["serviceAccountName"] = service_account_name
return self

def retry_strategy(self, times, minutes_between_retries):
def retry_strategy(self, times, minutes_between_retries, expression=None):
    """Attach an Argo Workflows retryStrategy to this template.

    Parameters
    ----------
    times : int
        Retry limit. If 0 or negative, no retryStrategy is added.
    minutes_between_retries : int or str
        Backoff duration between retries, rendered as "<value>m".
    expression : str, optional
        An Argo retry expression (e.g. on lastRetry.exitCode). When
        provided it is added alongside retryPolicy "Always".
        NOTE(review): Argo evaluates retryPolicy AND expression together
        when both are set — confirm this is the intended gating.

    Returns
    -------
    self, for fluent chaining.
    """
    if times > 0:
        # Build the strategy once; the original assigned retryPolicy
        # "Always" here and then redundantly re-assigned it when no
        # expression was given.
        strategy = {
            "retryPolicy": "Always",
            "limit": times,
            "backoff": {"duration": "%sm" % minutes_between_retries},
        }
        if expression is not None:
            strategy["expression"] = expression
        self.payload["retryStrategy"] = strategy
    return self

def empty_dir_volume(self, name, medium=None, size_limit=None):
Expand Down
9 changes: 9 additions & 0 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class BatchKilledException(MetaflowException):
headline = "AWS Batch task killed"


class BatchSpotInstanceTerminated(MetaflowException):
    """Raised when a crashed AWS Batch task reports status code 234,
    which this plugin treats as a spot-instance termination so the CLI
    can decide retry eligibility separately from ordinary failures."""

    headline = "Spot Instance has been terminated"


class Batch(object):
def __init__(self, metadata, environment):
self.metadata = metadata
Expand Down Expand Up @@ -482,6 +486,11 @@ def wait_for_launch(job, child_jobs):
# to Amazon S3.

if self.job.is_crashed:

# Custom exception for spot instance terminations
if self.job.status_code == 234:
raise BatchSpotInstanceTerminated()

msg = next(
msg
for msg in [
Expand Down
25 changes: 23 additions & 2 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@

from metaflow import util
from metaflow import R
from metaflow.exception import CommandException, METAFLOW_EXIT_DISALLOW_RETRY
from metaflow.exception import (
METAFLOW_EXIT_ALLOW_RETRY,
CommandException,
METAFLOW_EXIT_DISALLOW_RETRY,
)
from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from .batch import Batch, BatchKilledException
from .batch import (
Batch,
BatchException,
BatchKilledException,
BatchSpotInstanceTerminated,
)
from metaflow.plugins.retry_decorator import RetryEvents


@click.group()
Expand Down Expand Up @@ -283,6 +293,7 @@ def echo(msg, stream="stderr", batch_id=None, **kwargs):
if split_vars:
env.update(split_vars)

retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else []
if retry_count:
ctx.obj.echo_always(
"Sleeping %d minutes before the next AWS Batch retry"
Expand Down Expand Up @@ -356,5 +367,15 @@ def _sync_metadata():
# don't retry killed tasks
traceback.print_exc()
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
except BatchSpotInstanceTerminated:
traceback.print_exc()
if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions:
sys.exit(METAFLOW_EXIT_ALLOW_RETRY)
else:
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
except BatchException:
if not retry_conditions or RetryEvents.STEP.value in retry_conditions:
raise
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
finally:
_sync_metadata()
7 changes: 6 additions & 1 deletion metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import time
import hashlib
from typing import Dict, List, Optional

try:
unicode
Expand Down Expand Up @@ -630,8 +631,12 @@ def parameter(self, key, value):
self.payload["parameters"][key] = str(value)
return self

def attempts(self, attempts):
def attempts(self, attempts, evaluate_on_exit: Optional[List[Dict]] = None):
    """Set the AWS Batch retryStrategy attempt count and, optionally,
    custom EvaluateOnExit conditions.

    evaluate_on_exit is required when specifying custom retry strategies.
    ref: https://docs.aws.amazon.com/batch/latest/APIReference/API_EvaluateOnExit.html

    Returns self, for fluent chaining.
    """
    strategy = self.payload["retryStrategy"]
    strategy["attempts"] = attempts
    if evaluate_on_exit is not None:
        strategy["evaluateOnExit"] = evaluate_on_exit
    return self


Expand Down
2 changes: 2 additions & 0 deletions metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def task_pre_step(
self._save_logs_sidecar.start()

# Start spot termination monitor sidecar.
# TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
os.environ["MF_MAIN_PID"] = str(os.getpid())
current._update_env(
{"spot_termination_notice": "/tmp/spot_termination_notice"}
)
Expand Down
32 changes: 31 additions & 1 deletion metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
)
from metaflow.parameters import deploy_time_eval
from metaflow.plugins.retry_decorator import PLATFORM_EVICTED_EXITCODE, RetryEvents
from metaflow.user_configs.config_options import ConfigInput
from metaflow.util import dict_to_cli_options, to_pascalcase

Expand Down Expand Up @@ -824,9 +825,38 @@ def _batch(self, node):
batch_deco = [deco for deco in node.decorators if deco.name == "batch"][0]
resources = {}
resources.update(batch_deco.attributes)

# Resolve retry strategy.
user_code_retries, total_retries = self._get_retries(node)

# retry conditions mapping
retry_deco = next(
(deco for deco in node.decorators if deco.name == "retry"), None
)
retry_conditions = (
retry_deco.attributes["only_on"] if retry_deco is not None else []
)

# Translate RetryEvents to expressions for SFN
event_to_expr = {
RetryEvents.STEP: {"action": "RETRY", "onExitCode": "1"},
RetryEvents.PREEMPT: {
"action": "RETRY",
"onExitCode": str(PLATFORM_EVICTED_EXITCODE),
},
}
retry_expr = None
# NOTE: AWS only allows 5 distinct EvaluateOnExit conditions, so any more than this will require combining them.
if retry_conditions:
retry_expr = [
expr
for event, expr in event_to_expr.items()
if event.value in retry_conditions
]
# We need to append a catch-all exit condition, because when no condition
# matches, AWS Batch's default behavior is to retry the job.
# Retry conditions are evaluated only for non-zero exit codes, so the
# wildcard is safe here.
retry_expr.append({"action": "EXIT", "onExitCode": "*"})

task_spec = {
"flow_name": attrs["metaflow.flow_name"],
"step_name": attrs["metaflow.step_name"],
Expand Down Expand Up @@ -875,7 +905,7 @@ def _batch(self, node):
log_driver=resources["log_driver"],
log_options=resources["log_options"],
)
.attempts(total_retries + 1)
.attempts(attempts=total_retries + 1, evaluate_on_exit=retry_expr)
)

def _get_retries(self, node):
Expand Down
7 changes: 7 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class KubernetesKilledException(MetaflowException):
headline = "Kubernetes Batch job killed"


class KubernetesSpotInstanceTerminated(MetaflowException):
    """Raised when a Kubernetes task exits with code 234, which this
    plugin treats as a spot-instance (node) termination so the CLI can
    decide retry eligibility separately from ordinary failures."""

    headline = "Kubernetes node spot instance has been terminated"


class Kubernetes(object):
def __init__(
self,
Expand Down Expand Up @@ -764,6 +768,9 @@ def _has_updates():
)
if int(exit_code) == 134:
raise KubernetesException("%s (exit code %s)" % (msg, exit_code))
if int(exit_code) == 234:
# NOTE: Kubernetes reports container exit codes modulo 256.
raise KubernetesSpotInstanceTerminated()
else:
msg = "%s (exit code %s)" % (msg, exit_code)
raise KubernetesException(
Expand Down
20 changes: 19 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import metaflow.tracing as tracing
from metaflow import JSONTypeClass, util
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, MetaflowException
from metaflow.exception import (
METAFLOW_EXIT_DISALLOW_RETRY,
METAFLOW_EXIT_ALLOW_RETRY,
MetaflowException,
)
from metaflow.metadata_provider.util import sync_local_metadata_from_datastore
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.mflog import TASK_LOG_SOURCE
Expand All @@ -21,7 +25,9 @@
Kubernetes,
KubernetesException,
KubernetesKilledException,
KubernetesSpotInstanceTerminated,
)
from metaflow.plugins.retry_decorator import RetryEvents


@click.group()
Expand Down Expand Up @@ -221,6 +227,7 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
minutes_between_retries = int(
retry_deco[0].attributes.get("minutes_between_retries", 2)
)
retry_conditions = retry_deco[0].attributes["only_on"] if retry_deco else []
if retry_count:
ctx.obj.echo_always(
"Sleeping %d minutes before the next retry" % minutes_between_retries
Expand Down Expand Up @@ -330,6 +337,17 @@ def _sync_metadata():
# don't retry killed tasks
traceback.print_exc()
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
except KubernetesSpotInstanceTerminated:
traceback.print_exc()
if not retry_conditions or RetryEvents.PREEMPT.value in retry_conditions:
sys.exit(METAFLOW_EXIT_ALLOW_RETRY)
else:
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)
except KubernetesException:
if not retry_conditions or RetryEvents.STEP.value in retry_conditions:
raise
sys.exit(METAFLOW_EXIT_DISALLOW_RETRY)

finally:
_sync_metadata()

Expand Down
2 changes: 2 additions & 0 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,8 @@ def task_pre_step(
self._save_logs_sidecar.start()

# Start spot termination monitor sidecar.
# TODO: A nicer way to pass the main process id to a Sidecar, in order to allow sidecars to send signals back to the main process.
os.environ["MF_MAIN_PID"] = str(os.getpid())
current._update_env(
{"spot_termination_notice": "/tmp/spot_termination_notice"}
)
Expand Down
5 changes: 4 additions & 1 deletion metaflow/plugins/kubernetes/spot_monitor_sidecar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __init__(self):
self._token = None
self._token_expiry = 0

# Due to nesting, os.getppid is not reliable for fetching the main task pid
self.main_pid = int(os.getenv("MF_MAIN_PID", os.getppid()))

if self._is_aws_spot_instance():
self._process = Process(target=self._monitor_loop)
self._process.start()
Expand Down Expand Up @@ -71,7 +74,7 @@ def _monitor_loop(self):
if response.status_code == 200:
termination_time = response.text
self._emit_termination_metadata(termination_time)
os.kill(os.getppid(), signal.SIGTERM)
os.kill(self.main_pid, signal.SIGUSR1)
break
except (requests.exceptions.RequestException, requests.exceptions.Timeout):
pass
Expand Down
Loading
Loading