From e335493027f84ad5928d8d1a109491ab215fe0a5 Mon Sep 17 00:00:00 2001 From: GefMar Date: Sat, 19 Apr 2025 18:02:27 +0200 Subject: [PATCH 1/4] add: SmartRetryMiddleware rename: middleware file --- taskiq/__init__.py | 9 +- taskiq/middlewares/__init__.py | 11 ++ ...ddleware.py => simple_retry_middleware.py} | 0 taskiq/middlewares/smart_retry_middleware.py | 177 ++++++++++++++++++ tests/middlewares/test_simple_retry.py | 2 +- 5 files changed, 196 insertions(+), 3 deletions(-) rename taskiq/middlewares/{retry_middleware.py => simple_retry_middleware.py} (100%) create mode 100644 taskiq/middlewares/smart_retry_middleware.py diff --git a/taskiq/__init__.py b/taskiq/__init__.py index 87a5c4f1..b2695e4b 100644 --- a/taskiq/__init__.py +++ b/taskiq/__init__.py @@ -25,8 +25,11 @@ ) from taskiq.funcs import gather from taskiq.message import BrokerMessage, TaskiqMessage -from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware -from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware +from taskiq.middlewares import ( + PrometheusMiddleware, + SimpleRetryMiddleware, + SmartRetryMiddleware, +) from taskiq.result import TaskiqResult from taskiq.scheduler.scheduled_task import ScheduledTask from taskiq.scheduler.scheduler import TaskiqScheduler @@ -34,6 +37,7 @@ from taskiq.task import AsyncTaskiqTask __version__ = version("taskiq") + __all__ = [ "AckableMessage", "AsyncBroker", @@ -52,6 +56,7 @@ "SecurityError", "SendTaskError", "SimpleRetryMiddleware", + "SmartRetryMiddleware", "TaskiqDepends", "TaskiqError", "TaskiqEvents", diff --git a/taskiq/middlewares/__init__.py b/taskiq/middlewares/__init__.py index c3be92a6..f236a3eb 100644 --- a/taskiq/middlewares/__init__.py +++ b/taskiq/middlewares/__init__.py @@ -1 +1,12 @@ """Taskiq middlewares.""" + + +from .prometheus_middleware import PrometheusMiddleware +from .simple_retry_middleware import SimpleRetryMiddleware +from .smart_retry_middleware import SmartRetryMiddleware + +__all__ = ( + "PrometheusMiddleware", + "SimpleRetryMiddleware", + "SmartRetryMiddleware", +) diff --git a/taskiq/middlewares/retry_middleware.py b/taskiq/middlewares/simple_retry_middleware.py similarity index 100% rename from taskiq/middlewares/retry_middleware.py rename to taskiq/middlewares/simple_retry_middleware.py diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py new file mode 100644 index 00000000..1dc4e4e8 --- /dev/null +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import datetime +import random +from logging import getLogger +from typing import Any + +from taskiq import ScheduleSource +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.exceptions import NoResultError +from taskiq.kicker import AsyncKicker +from taskiq.message import TaskiqMessage +from taskiq.result import TaskiqResult + +__all__ = ("SmartRetryMiddleware",) + +_logger = getLogger("taskiq.smart_retry_middleware") + + +class SmartRetryMiddleware(TaskiqMiddleware): + """Middleware to retry tasks delays. + + This middleware retries failed tasks with support for: + - max retries + - delay + - jitter + - exponential backoff + """ + + def __init__( + self, + default_retry_count: int = 3, + default_retry_label: bool = False, + no_result_on_retry: bool = True, + default_delay: float = 5, + use_jitter: bool = False, + use_delay_exponent: bool = False, + max_delay_exponent: float = 60, + schedule_source: ScheduleSource | None = None, + ) -> None: + """ + Initialize retry middleware. + + :param default_retry_count: Default max retries if not specified. + :param default_retry_label: Whether to retry tasks by default. + :param no_result_on_retry: Replace result with NoResultError on retry. + :param default_delay: Delay in seconds before retrying. + :param use_jitter: Add random jitter to retry delay. + :param use_delay_exponent: Apply exponential backoff to delay. + :param max_delay_exponent: Maximum allowed delay when using backoff. + :param schedule_source: Schedule source to use for scheduling. + If None, the default broker will be used. + """ + super().__init__() + self.default_retry_count = default_retry_count + self.default_retry_label = default_retry_label + self.no_result_on_retry = no_result_on_retry + self.default_delay = default_delay + self.use_jitter = use_jitter + self.use_delay_exponent = use_delay_exponent + self.max_delay_exponent = max_delay_exponent + self.schedule_source = schedule_source + + def is_retry_on_error(self, message: TaskiqMessage) -> bool: + """ + Check if retry is enabled for this task. + + Looks for `retry_on_error` label, falls back to default. + + :param message: Original task message. + :return: True if should retry on error. + """ + retry_on_error = message.labels.get("retry_on_error") + if isinstance(retry_on_error, str): + retry_on_error = retry_on_error.lower() == "true" + if retry_on_error is None: + retry_on_error = self.default_retry_label + return retry_on_error + + def make_delay(self, message: TaskiqMessage, retries: int) -> float: + """ + Calculate retry delay. + + Includes jitter and exponential backoff if enabled. + + :param message: Task message. + :param retries: Current retry count. + :return: Delay in seconds. + """ + delay = float(message.labels.get("delay", self.default_delay)) + if self.use_delay_exponent: + delay = min(delay * retries, self.max_delay_exponent) + + if self.use_jitter: + delay += random.random() # noqa: S311 + + return delay + + async def on_send( + self, + kicker: AsyncKicker[Any, Any], + message: TaskiqMessage, + delay: float, + ) -> None: + """Execute the task with a delay.""" + if isinstance(self.schedule_source, ScheduleSource): + target_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta( + seconds=delay, + ) + await kicker.schedule_by_time( + self.schedule_source, + target_time, + *message.args, + **message.kwargs, + ) + else: + await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs) + + async def on_error( + self, + message: TaskiqMessage, + result: TaskiqResult[Any], + exception: BaseException, + ) -> None: + """ + Retry on error. + + If an error is raised during task execution, + this middleware schedules the task to be retried + after a calculated delay. + + :param message: Message that caused the error. + :param result: Execution result. + :param exception: Caught exception. + """ + if isinstance(exception, NoResultError): + return + + retry_on_error = self.is_retry_on_error(message) + + if not retry_on_error: + return + + retries = int(message.labels.get("_retries", 0)) + 1 + max_retries = int(message.labels.get("max_retries", self.default_retry_count)) + + if retries < max_retries: + delay = self.make_delay(message, retries) + + _logger.info( + "Task %s failed. Retrying %d/%d in %.2f seconds.", + message.task_name, + retries, + max_retries, + delay, + ) + + kicker: AsyncKicker[Any, Any] = ( + AsyncKicker( + task_name=message.task_name, + broker=self.broker, + labels=message.labels, + ) + .with_task_id(message.task_id) + .with_labels(_retries=retries) + ) + + await self.on_send(kicker, message, delay) + + if self.no_result_on_retry: + result.error = NoResultError() + + else: + _logger.warning( + "Task '%s' invocation failed. Maximum retries count is reached.", + message.task_name, + ) diff --git a/tests/middlewares/test_simple_retry.py b/tests/middlewares/test_simple_retry.py index 2192f221..7a0c12d3 100644 --- a/tests/middlewares/test_simple_retry.py +++ b/tests/middlewares/test_simple_retry.py @@ -5,7 +5,7 @@ from taskiq.formatters.json_formatter import JSONFormatter from taskiq.message import TaskiqMessage -from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware +from taskiq.middlewares.simple_retry_middleware import SimpleRetryMiddleware from taskiq.result import TaskiqResult From e6fe9e059f77b35f9ba283ead27d1d1171656d0a Mon Sep 17 00:00:00 2001 From: GefMar Date: Sat, 19 Apr 2025 22:33:47 +0200 Subject: [PATCH 2/4] fix: check types fix: condition --- taskiq/middlewares/smart_retry_middleware.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index 1dc4e4e8..a81d2dc6 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -3,6 +3,7 @@ import datetime import random from logging import getLogger +from types import NoneType from typing import Any from taskiq import ScheduleSource @@ -61,6 +62,11 @@ def __init__( self.max_delay_exponent = max_delay_exponent self.schedule_source = schedule_source + if not isinstance(schedule_source, (ScheduleSource, NoneType)): + raise TypeError( + "schedule_source must be an instance of ScheduleSource or None", + ) + def is_retry_on_error(self, message: TaskiqMessage) -> bool: """ Check if retry is enabled for this task. @@ -103,7 +109,9 @@ async def on_send( delay: float, ) -> None: """Execute the task with a delay.""" - if isinstance(self.schedule_source, ScheduleSource): + if self.schedule_source is None: + await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs) + else: target_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta( seconds=delay, ) @@ -113,8 +121,6 @@ async def on_send( *message.args, **message.kwargs, ) - else: - await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs) async def on_error( self, From 6b08205f3e970e6f6dae8dde9716e296834649ea Mon Sep 17 00:00:00 2001 From: GefMar Date: Sat, 19 Apr 2025 22:45:33 +0200 Subject: [PATCH 3/4] fix: check types --- taskiq/middlewares/smart_retry_middleware.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index a81d2dc6..a24eea39 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -3,7 +3,6 @@ import datetime import random from logging import getLogger -from types import NoneType from typing import Any from taskiq import ScheduleSource @@ -62,7 +61,7 @@ def __init__( self.max_delay_exponent = max_delay_exponent self.schedule_source = schedule_source - if not isinstance(schedule_source, (ScheduleSource, NoneType)): + if not isinstance(schedule_source, (ScheduleSource, type(None))): raise TypeError( "schedule_source must be an instance of ScheduleSource or None", ) From 01b3f8f607f6004e3645224b92ee0398df9a91d9 Mon Sep 17 00:00:00 2001 From: GefMar Date: Sat, 19 Apr 2025 22:47:11 +0200 Subject: [PATCH 4/4] fix: type annotation --- taskiq/middlewares/smart_retry_middleware.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/taskiq/middlewares/smart_retry_middleware.py b/taskiq/middlewares/smart_retry_middleware.py index a24eea39..d58c7611 100644 --- a/taskiq/middlewares/smart_retry_middleware.py +++ b/taskiq/middlewares/smart_retry_middleware.py @@ -1,9 +1,7 @@ -from __future__ import annotations - import datetime import random from logging import getLogger -from typing import Any +from typing import Any, Optional from taskiq import ScheduleSource from taskiq.abc.middleware import TaskiqMiddleware @@ -36,7 +34,7 @@ def __init__( use_jitter: bool = False, use_delay_exponent: bool = False, max_delay_exponent: float = 60, - schedule_source: ScheduleSource | None = None, + schedule_source: Optional[ScheduleSource] = None, ) -> None: """ Initialize retry middleware.