Skip to content

fix: #90 #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taskiq_faststream/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def dumps( # type: ignore[override]
:param message: message to send.
:return: Dumped message.
"""
labels = message.labels
labels = message.labels.copy()
labels.pop("schedule", None)
labels.pop("schedule_id", None)

Expand Down
8 changes: 8 additions & 0 deletions taskiq_faststream/kicker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from typing import Any

from taskiq.kicker import AsyncKicker, _FuncParams, _ReturnType
from taskiq.message import TaskiqMessage


class LabelRespectKicker(AsyncKicker[_FuncParams, _ReturnType]):
"""Patched kicker doesn't cast labels to str."""

def _prepare_message(self, *args: Any, **kwargs: Any) -> TaskiqMessage:
msg = super()._prepare_message(*args, **kwargs)
msg.labels = self.labels
return msg
33 changes: 33 additions & 0 deletions tests/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections.abc import AsyncIterator, Iterator

message = "Hi!"


def sync_callable_msg() -> str:
return message


async def async_callable_msg() -> str:
return message


async def async_generator_msg() -> AsyncIterator[str]:
yield message


def sync_generator_msg() -> Iterator[str]:
yield message


class _C:
def __call__(self) -> str:
return message


class _AC:
async def __call__(self) -> str:
return message


sync_callable_class_message = _C()
async_callable_class_message = _AC()
94 changes: 29 additions & 65 deletions tests/test_resolve_message.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,34 @@
from collections.abc import AsyncIterator, Iterator
import typing

import pytest
from faststream.types import SendableMessage

from taskiq_faststream.utils import resolve_msg


@pytest.mark.anyio
async def test_regular() -> None:
async for m in resolve_msg("msg"):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_callable() -> None:
async for m in resolve_msg(lambda: "msg"):
assert m == "msg"


from tests import messages


@pytest.mark.parametrize(
"msg",
[
messages.message, # regular msg
messages.sync_callable_msg, # sync callable
messages.async_callable_msg, # async callable
messages.sync_generator_msg, # sync generator
messages.async_generator_msg, # async generator
messages.sync_callable_class_message, # sync callable class
messages.async_callable_class_message, # async callable class
],
)
@pytest.mark.anyio
async def test_async_callable() -> None:
async def gen_msg() -> str:
return "msg"

async for m in resolve_msg(gen_msg):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_callable_class() -> None:
class C:
def __init__(self) -> None:
pass

def __call__(self) -> str:
return "msg"

async for m in resolve_msg(C()):
assert m == "msg"


@pytest.mark.anyio
async def test_async_callable_class() -> None:
class C:
def __init__(self) -> None:
pass

async def __call__(self) -> str:
return "msg"

async for m in resolve_msg(C()):
assert m == "msg"


@pytest.mark.anyio
async def test_async_generator() -> None:
async def get_msg() -> AsyncIterator[str]:
yield "msg"

async for m in resolve_msg(get_msg):
assert m == "msg"


@pytest.mark.anyio
async def test_sync_generator() -> None:
def get_msg() -> Iterator[str]:
yield "msg"

async for m in resolve_msg(get_msg):
assert m == "msg"
async def test_resolve_msg(
msg: typing.Union[
None,
SendableMessage,
typing.Callable[[], SendableMessage],
typing.Callable[[], typing.Awaitable[SendableMessage]],
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
],
) -> None:
async for m in resolve_msg(msg):
assert m == messages.message
37 changes: 31 additions & 6 deletions tests/testcase.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import asyncio
import typing
from datetime import datetime, timedelta, timezone
from typing import Any
from unittest.mock import MagicMock

import pytest
from faststream.types import SendableMessage
from faststream.utils.functions import timeout_scope
from freezegun import freeze_time
from taskiq import AsyncBroker, TaskiqScheduler
from taskiq import AsyncBroker
from taskiq.cli.scheduler.args import SchedulerArgs
from taskiq.cli.scheduler.run import run_scheduler
from taskiq.schedule_sources import LabelScheduleSource

from taskiq_faststream import BrokerWrapper, StreamScheduler
from tests import messages


@pytest.mark.anyio
Expand Down Expand Up @@ -54,7 +57,7 @@ async def handler(msg: str) -> None:
task = asyncio.create_task(
run_scheduler(
SchedulerArgs(
scheduler=TaskiqScheduler(
scheduler=StreamScheduler(
broker=taskiq_broker,
sources=[LabelScheduleSource(taskiq_broker)],
),
Expand All @@ -69,24 +72,44 @@ async def handler(msg: str) -> None:
mock.assert_called_once_with("Hi!")
task.cancel()

@pytest.mark.parametrize(
"msg",
[
messages.message, # regular msg
messages.sync_callable_msg, # sync callable
messages.async_callable_msg, # async callable
messages.sync_generator_msg, # sync generator
messages.async_generator_msg, # async generator
messages.sync_callable_class_message, # sync callable class
messages.async_callable_class_message, # async callable class
],
)
async def test_task_multiple_schedules_by_cron(
self,
subject: str,
broker: Any,
event: asyncio.Event,
msg: typing.Union[
None,
SendableMessage,
typing.Callable[[], SendableMessage],
typing.Callable[[], typing.Awaitable[SendableMessage]],
typing.Callable[[], typing.Generator[SendableMessage, None, None]],
typing.Callable[[], typing.AsyncGenerator[SendableMessage, None]],
],
) -> None:
"""Test cron runs twice via StreamScheduler."""
received_message = []

@broker.subscriber(subject)
async def handler(msg: str) -> None:
received_message.append(msg)
async def handler(message: str) -> None:
received_message.append(message)
event.set()

taskiq_broker = self.build_taskiq_broker(broker)

taskiq_broker.task(
"Hi!",
msg,
**{self.subj_name: subject},
schedule=[
{
Expand Down Expand Up @@ -116,4 +139,6 @@ async def handler(msg: str) -> None:

task.cancel()

assert received_message == ["Hi!", "Hi!"], received_message
assert received_message == [messages.message, messages.message], (
received_message
)