Skip to content

Commit 8c61811

Browse files
committed
Add support for quorum queues and max_attempts_at_message
1 parent b6a8be3 commit 8c61811

File tree

4 files changed

+263
-24
lines changed

4 files changed

+263
-24
lines changed

README.md

+31-6
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,32 @@ async def main():
116116

117117
```
118118

119+
## Queue Types and Message Reliability
120+
121+
AioPikaBroker supports both classic and quorum queues. Quorum queues are a more modern queue type in RabbitMQ that provides better reliability and data safety guarantees.
122+
123+
```python
124+
from taskiq_aio_pika import AioPikaBroker, QueueType
125+
126+
broker = AioPikaBroker(
127+
queue_type=QueueType.QUORUM, # Use quorum queues for better reliability
128+
max_attempts_at_message=3 # Limit redelivery attempts
129+
)
130+
```
131+
132+
### Message Redelivery Control
133+
134+
When message processing fails due to consumer crashes (e.g. due to an OOM condition resulting in a SIGKILL), network issues, or other infrastructure problems, before the consumer has had the chance to acknowledge, positively or negatively, the message (and schedule a retry via taskiq's retry middleware), RabbitMQ will requeue the message to the front of the queue and it will be redelivered. With quorum queues, you can control how many times such a message will be redelivered:
135+
136+
- Set `max_attempts_at_message` to limit delivery attempts.
137+
- Set `max_attempts_at_message=None` for unlimited attempts.
138+
- This operates at the message delivery level, not application retry level. For application-level retries in case of exceptions that can be caught (e.g., temporary API failures), use taskiq's retry middleware instead.
139+
- After max attempts, the message is logged and discarded.
140+
- `max_attempts_at_message` requires using quorum queues (`queue_type=QueueType.QUORUM`).
141+
142+
This is particularly useful for preventing infinite loops of redeliveries of messages that consistently cause the consumer to crash ([poison messages](https://www.rabbitmq.com/docs/quorum-queues#poison-message-handling)) and can cause the queue to backup.
143+
144+
119145
## Configuration
120146

121147
AioPikaBroker parameters:
@@ -125,13 +151,12 @@ AioPikaBroker parameters:
125151
* `exchange_name` - name of exchange that used to send messages.
126152
* `exchange_type` - type of the exchange. Used only if `declare_exchange` is True.
127153
* `queue_name` - queue that used to get incoming messages.
154+
* `queue_type` - type of RabbitMQ queue to use: `classic` or `quorum`. Defaults to `classic`.
128155
* `routing_key` - routing key that is used to bind the queue to the exchange.
129156
* `declare_exchange` - whether you want to declare new exchange if it doesn't exist.
130157
* `max_priority` - maximum priority for messages.
131-
* `delay_queue_name` - custom delay queue name.
132-
This queue is used to deliver messages with delays.
133-
* `dead_letter_queue_name` - custom dead letter queue name.
134-
This queue is used to receive negatively acknowledged messages from the main queue.
158+
* `delay_queue_name` - custom delay queue name. This queue is used to deliver messages with delays.
159+
* `dead_letter_queue_name` - custom dead letter queue name. This queue is used to receive negatively acknowledged messages from the main queue.
135160
* `qos` - number of messages that worker can prefetch.
136-
* `declare_queues` - whether you want to declare queues even on
137-
client side. May be useful for message persistence.
161+
* `declare_queues` - whether you want to declare queues even on client side. May be useful for message persistence.
162+
* `max_attempts_at_message` - maximum number of attempts at processing the same message. Requires the queue type to be set to `QueueType.QUORUM`. Defaults to `20` for quorum queues and to `None` for classic queues. This is not the same as task retries. Pass `None` for unlimited attempts.

taskiq_aio_pika/broker.py

+105-12
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,41 @@
11
import asyncio
2+
import copy
23
from datetime import timedelta
4+
from enum import Enum
35
from logging import getLogger
4-
from typing import Any, AsyncGenerator, Callable, Dict, Optional, TypeVar
6+
from typing import (
7+
Any,
8+
AsyncGenerator,
9+
Callable,
10+
Dict,
11+
Literal,
12+
Optional,
13+
TypeVar,
14+
Union,
15+
override,
16+
)
517

618
from aio_pika import DeliveryMode, ExchangeType, Message, connect_robust
719
from aio_pika.abc import AbstractChannel, AbstractQueue, AbstractRobustConnection
8-
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage
20+
from taskiq import (
21+
AsyncBroker,
22+
AsyncResultBackend,
23+
BrokerMessage,
24+
)
25+
from taskiq.message import AckableNackableWrappedMessageWithMetadata, MessageMetadata
926

1027
_T = TypeVar("_T")
1128

1229
logger = getLogger("taskiq.aio_pika_broker")
1330

1431

32+
class QueueType(Enum):
33+
"""Type of RabbitMQ queue."""
34+
35+
CLASSIC = "classic"
36+
QUORUM = "quorum"
37+
38+
1539
def parse_val(
1640
parse_func: Callable[[str], _T],
1741
target: Optional[str] = None,
@@ -35,7 +59,7 @@ def parse_val(
3559
class AioPikaBroker(AsyncBroker):
3660
"""Broker that works with RabbitMQ."""
3761

38-
def __init__(
62+
def __init__( # noqa: PLR0912
3963
self,
4064
url: Optional[str] = None,
4165
result_backend: Optional[AsyncResultBackend[_T]] = None,
@@ -44,6 +68,7 @@ def __init__(
4468
loop: Optional[asyncio.AbstractEventLoop] = None,
4569
exchange_name: str = "taskiq",
4670
queue_name: str = "taskiq",
71+
queue_type: QueueType = QueueType.CLASSIC,
4772
dead_letter_queue_name: Optional[str] = None,
4873
delay_queue_name: Optional[str] = None,
4974
declare_exchange: bool = True,
@@ -54,6 +79,7 @@ def __init__(
5479
delayed_message_exchange_plugin: bool = False,
5580
declare_exchange_kwargs: Optional[Dict[Any, Any]] = None,
5681
declare_queues_kwargs: Optional[Dict[Any, Any]] = None,
82+
max_attempts_at_message: Union[Optional[int], Literal["default"]] = "default",
5783
**connection_kwargs: Any,
5884
) -> None:
5985
"""
@@ -62,12 +88,13 @@ def __init__(
6288
:param url: url to rabbitmq. If None,
6389
the default "amqp://guest:guest@localhost:5672" is used.
6490
:param result_backend: custom result backend.
65-
6691
:param task_id_generator: custom task_id generator.
6792
:param qos: number of messages that worker can prefetch.
6893
:param loop: specific event loop.
6994
:param exchange_name: name of exchange that used to send messages.
7095
:param queue_name: queue that used to get incoming messages.
96+
:param queue_type: type of RabbitMQ queue to use: `classic` or `quorum`.
97+
defaults to `classic`.
7198
:param dead_letter_queue_name: custom name for dead-letter queue.
7299
by default it set to {queue_name}.dead_letter.
73100
:param delay_queue_name: custom name for queue that used to
@@ -86,6 +113,11 @@ def __init__(
86113
:param declare_queues_kwargs: additional from AbstractChannel.declare_queue
87114
:param connection_kwargs: additional keyword arguments,
88115
for connect_robust method of aio-pika.
116+
:param max_attempts_at_message: maximum number of attempts at processing
117+
the same message. requires the queue type to be set to `QueueType.QUORUM`.
118+
defaults to `20` for quorum queues and to `None` for classic queues.
119+
is not the same as task retries. pass `None` for unlimited attempts.
120+
:raises ValueError: if inappropriate arguments were passed.
89121
"""
90122
super().__init__(result_backend, task_id_generator)
91123

@@ -104,6 +136,52 @@ def __init__(
104136
self._max_priority = max_priority
105137
self._delayed_message_exchange_plugin = delayed_message_exchange_plugin
106138

139+
if self._declare_queues_kwargs.get("arguments", {}).get(
140+
"x-queue-type",
141+
) or self._declare_queues_kwargs.get("arguments", {}).get("x-delivery-limit"):
142+
raise ValueError(
143+
"Use the `queue_type` and `max_attempts_at_message` parameters of "
144+
"`AioPikaBroker.__init__` instead of `x-queue-type` and "
145+
"`x-delivery-limit`",
146+
)
147+
if queue_type == QueueType.QUORUM:
148+
self._declare_queues_kwargs.setdefault("arguments", {})[
149+
"x-queue-type"
150+
] = "quorum"
151+
self._declare_queues_kwargs["durable"] = True
152+
else:
153+
self._declare_queues_kwargs.setdefault("arguments", {})[
154+
"x-queue-type"
155+
] = "classic"
156+
157+
if queue_type != QueueType.QUORUM and max_attempts_at_message not in (
158+
"default",
159+
None,
160+
):
161+
raise ValueError(
162+
"`max_attempts_at_message` requires `queue_type` to be set to "
163+
"`QueueType.QUORUM`.",
164+
)
165+
166+
if max_attempts_at_message == "default":
167+
if queue_type == QueueType.QUORUM:
168+
self.max_attempts_at_message = 20
169+
else:
170+
self.max_attempts_at_message = None
171+
else:
172+
self.max_attempts_at_message = max_attempts_at_message
173+
174+
if queue_type == QueueType.QUORUM:
175+
if self.max_attempts_at_message is None:
176+
# no limit
177+
self._declare_queues_kwargs["arguments"]["x-delivery-limit"] = "-1"
178+
else:
179+
# the final attempt will be handled in `taskiq.Receiver`
180+
# to generate visible logs
181+
self._declare_queues_kwargs["arguments"]["x-delivery-limit"] = (
182+
self.max_attempts_at_message + 1
183+
)
184+
107185
self._dead_letter_queue_name = f"{queue_name}.dead_letter"
108186
if dead_letter_queue_name:
109187
self._dead_letter_queue_name = dead_letter_queue_name
@@ -183,9 +261,15 @@ async def declare_queues(
183261
:param channel: channel to used for declaration.
184262
:return: main queue instance.
185263
"""
264+
declare_queues_kwargs_ex_arguments = copy.copy(self._declare_queues_kwargs)
265+
declare_queue_arguments = declare_queues_kwargs_ex_arguments.pop(
266+
"arguments",
267+
{},
268+
)
186269
await channel.declare_queue(
187270
self._dead_letter_queue_name,
188-
**self._declare_queues_kwargs,
271+
**declare_queues_kwargs_ex_arguments,
272+
arguments=declare_queue_arguments,
189273
)
190274
args: "Dict[str, Any]" = {
191275
"x-dead-letter-exchange": "",
@@ -195,8 +279,8 @@ async def declare_queues(
195279
args["x-max-priority"] = self._max_priority
196280
queue = await channel.declare_queue(
197281
self._queue_name,
198-
arguments=args,
199-
**self._declare_queues_kwargs,
282+
arguments=args | declare_queue_arguments,
283+
**declare_queues_kwargs_ex_arguments,
200284
)
201285
if self._delayed_message_exchange_plugin:
202286
await queue.bind(
@@ -209,8 +293,9 @@ async def declare_queues(
209293
arguments={
210294
"x-dead-letter-exchange": "",
211295
"x-dead-letter-routing-key": self._queue_name,
212-
},
213-
**self._declare_queues_kwargs,
296+
}
297+
| declare_queue_arguments,
298+
**declare_queues_kwargs_ex_arguments,
214299
)
215300

216301
await queue.bind(
@@ -275,7 +360,10 @@ async def kick(self, message: BrokerMessage) -> None:
275360
routing_key=self._delay_queue_name,
276361
)
277362

278-
async def listen(self) -> AsyncGenerator[AckableMessage, None]:
363+
@override
364+
async def listen(
365+
self,
366+
) -> AsyncGenerator[AckableNackableWrappedMessageWithMetadata, None]:
279367
"""
280368
Listen to queue.
281369
@@ -291,7 +379,12 @@ async def listen(self) -> AsyncGenerator[AckableMessage, None]:
291379
queue = await self.declare_queues(self.read_channel)
292380
async with queue.iterator() as iterator:
293381
async for message in iterator:
294-
yield AckableMessage(
295-
data=message.body,
382+
delivery_count: Optional[int] = message.headers.get("x-delivery-count") # type: ignore[assignment]
383+
yield AckableNackableWrappedMessageWithMetadata(
384+
message=message.body,
385+
metadata=MessageMetadata(
386+
delivery_count=delivery_count,
387+
),
296388
ack=message.ack,
389+
nack=message.nack,
297390
)

tests/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from contextlib import suppress
23
from typing import AsyncGenerator
34
from uuid import uuid4
45

@@ -229,3 +230,13 @@ async def broker_with_delayed_message_plugin(
229230
if_empty=False,
230231
if_unused=False,
231232
)
233+
234+
235+
@pytest.fixture(autouse=True, scope="function")
236+
async def cleanup_rabbitmq(test_channel: Channel) -> AsyncGenerator[None, None]:
237+
yield
238+
239+
for queue_name in ["taskiq", "taskiq.dead_letter", "taskiq.delay"]:
240+
with suppress(Exception):
241+
queue = await test_channel.get_queue(queue_name, ensure=False)
242+
await queue.delete(if_unused=False, if_empty=False)

0 commit comments

Comments
 (0)