Skip to content

Commit 27b63e0

Browse files
authored
SmartRetryMiddleware (#451)
add: SmartRetryMiddleware rename: middleware file
1 parent fbcb1b8 commit 27b63e0

File tree

5 files changed

+199
-3
lines changed

5 files changed

+199
-3
lines changed

taskiq/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,19 @@
2525
)
2626
from taskiq.funcs import gather
2727
from taskiq.message import BrokerMessage, TaskiqMessage
28-
from taskiq.middlewares.prometheus_middleware import PrometheusMiddleware
29-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
28+
from taskiq.middlewares import (
29+
PrometheusMiddleware,
30+
SimpleRetryMiddleware,
31+
SmartRetryMiddleware,
32+
)
3033
from taskiq.result import TaskiqResult
3134
from taskiq.scheduler.scheduled_task import ScheduledTask
3235
from taskiq.scheduler.scheduler import TaskiqScheduler
3336
from taskiq.state import TaskiqState
3437
from taskiq.task import AsyncTaskiqTask
3538

3639
__version__ = version("taskiq")
40+
3741
__all__ = [
3842
"AckableMessage",
3943
"AsyncBroker",
@@ -52,6 +56,7 @@
5256
"SecurityError",
5357
"SendTaskError",
5458
"SimpleRetryMiddleware",
59+
"SmartRetryMiddleware",
5560
"TaskiqDepends",
5661
"TaskiqError",
5762
"TaskiqEvents",

taskiq/middlewares/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,12 @@
11
"""Taskiq middlewares."""
2+
3+
4+
from .prometheus_middleware import PrometheusMiddleware
5+
from .simple_retry_middleware import SimpleRetryMiddleware
6+
from .smart_retry_middleware import SmartRetryMiddleware
7+
8+
__all__ = (
9+
"PrometheusMiddleware",
10+
"SimpleRetryMiddleware",
11+
"SmartRetryMiddleware",
12+
)
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import datetime
2+
import random
3+
from logging import getLogger
4+
from typing import Any, Optional
5+
6+
from taskiq import ScheduleSource
7+
from taskiq.abc.middleware import TaskiqMiddleware
8+
from taskiq.exceptions import NoResultError
9+
from taskiq.kicker import AsyncKicker
10+
from taskiq.message import TaskiqMessage
11+
from taskiq.result import TaskiqResult
12+
13+
__all__ = ("SmartRetryMiddleware",)
14+
15+
_logger = getLogger("taskiq.smart_retry_middleware")
16+
17+
18+
class SmartRetryMiddleware(TaskiqMiddleware):
19+
"""Middleware to retry tasks delays.
20+
21+
This middleware retries failed tasks with support for:
22+
- max retries
23+
- delay
24+
- jitter
25+
- exponential backoff
26+
"""
27+
28+
def __init__(
29+
self,
30+
default_retry_count: int = 3,
31+
default_retry_label: bool = False,
32+
no_result_on_retry: bool = True,
33+
default_delay: float = 5,
34+
use_jitter: bool = False,
35+
use_delay_exponent: bool = False,
36+
max_delay_exponent: float = 60,
37+
schedule_source: Optional[ScheduleSource] = None,
38+
) -> None:
39+
"""
40+
Initialize retry middleware.
41+
42+
:param default_retry_count: Default max retries if not specified.
43+
:param default_retry_label: Whether to retry tasks by default.
44+
:param no_result_on_retry: Replace result with NoResultError on retry.
45+
:param default_delay: Delay in seconds before retrying.
46+
:param use_jitter: Add random jitter to retry delay.
47+
:param use_delay_exponent: Apply exponential backoff to delay.
48+
:param max_delay_exponent: Maximum allowed delay when using backoff.
49+
:param schedule_source: Schedule source to use for scheduling.
50+
If None, the default broker will be used.
51+
"""
52+
super().__init__()
53+
self.default_retry_count = default_retry_count
54+
self.default_retry_label = default_retry_label
55+
self.no_result_on_retry = no_result_on_retry
56+
self.default_delay = default_delay
57+
self.use_jitter = use_jitter
58+
self.use_delay_exponent = use_delay_exponent
59+
self.max_delay_exponent = max_delay_exponent
60+
self.schedule_source = schedule_source
61+
62+
if not isinstance(schedule_source, (ScheduleSource, type(None))):
63+
raise TypeError(
64+
"schedule_source must be an instance of ScheduleSource or None",
65+
)
66+
67+
def is_retry_on_error(self, message: TaskiqMessage) -> bool:
68+
"""
69+
Check if retry is enabled for this task.
70+
71+
Looks for `retry_on_error` label, falls back to default.
72+
73+
:param message: Original task message.
74+
:return: True if should retry on error.
75+
"""
76+
retry_on_error = message.labels.get("retry_on_error")
77+
if isinstance(retry_on_error, str):
78+
retry_on_error = retry_on_error.lower() == "true"
79+
if retry_on_error is None:
80+
retry_on_error = self.default_retry_label
81+
return retry_on_error
82+
83+
def make_delay(self, message: TaskiqMessage, retries: int) -> float:
84+
"""
85+
Calculate retry delay.
86+
87+
Includes jitter and exponential backoff if enabled.
88+
89+
:param message: Task message.
90+
:param retries: Current retry count.
91+
:return: Delay in seconds.
92+
"""
93+
delay = float(message.labels.get("delay", self.default_delay))
94+
if self.use_delay_exponent:
95+
delay = min(delay * retries, self.max_delay_exponent)
96+
97+
if self.use_jitter:
98+
delay += random.random() # noqa: S311
99+
100+
return delay
101+
102+
async def on_send(
103+
self,
104+
kicker: AsyncKicker[Any, Any],
105+
message: TaskiqMessage,
106+
delay: float,
107+
) -> None:
108+
"""Execute the task with a delay."""
109+
if self.schedule_source is None:
110+
await kicker.with_labels(delay=delay).kiq(*message.args, **message.kwargs)
111+
else:
112+
target_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
113+
seconds=delay,
114+
)
115+
await kicker.schedule_by_time(
116+
self.schedule_source,
117+
target_time,
118+
*message.args,
119+
**message.kwargs,
120+
)
121+
122+
async def on_error(
123+
self,
124+
message: TaskiqMessage,
125+
result: TaskiqResult[Any],
126+
exception: BaseException,
127+
) -> None:
128+
"""
129+
Retry on error.
130+
131+
If an error is raised during task execution,
132+
this middleware schedules the task to be retried
133+
after a calculated delay.
134+
135+
:param message: Message that caused the error.
136+
:param result: Execution result.
137+
:param exception: Caught exception.
138+
"""
139+
if isinstance(exception, NoResultError):
140+
return
141+
142+
retry_on_error = self.is_retry_on_error(message)
143+
144+
if not retry_on_error:
145+
return
146+
147+
retries = int(message.labels.get("_retries", 0)) + 1
148+
max_retries = int(message.labels.get("max_retries", self.default_retry_count))
149+
150+
if retries < max_retries:
151+
delay = self.make_delay(message, retries)
152+
153+
_logger.info(
154+
"Task %s failed. Retrying %d/%d in %.2f seconds.",
155+
message.task_name,
156+
retries,
157+
max_retries,
158+
delay,
159+
)
160+
161+
kicker: AsyncKicker[Any, Any] = (
162+
AsyncKicker(
163+
task_name=message.task_name,
164+
broker=self.broker,
165+
labels=message.labels,
166+
)
167+
.with_task_id(message.task_id)
168+
.with_labels(_retries=retries)
169+
)
170+
171+
await self.on_send(kicker, message, delay)
172+
173+
if self.no_result_on_retry:
174+
result.error = NoResultError()
175+
176+
else:
177+
_logger.warning(
178+
"Task '%s' invocation failed. Maximum retries count is reached.",
179+
message.task_name,
180+
)

tests/middlewares/test_simple_retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from taskiq.formatters.json_formatter import JSONFormatter
77
from taskiq.message import TaskiqMessage
8-
from taskiq.middlewares.retry_middleware import SimpleRetryMiddleware
8+
from taskiq.middlewares.simple_retry_middleware import SimpleRetryMiddleware
99
from taskiq.result import TaskiqResult
1010

1111

0 commit comments

Comments
 (0)