Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async with stompman.Client(

# Handlers:
on_error_frame=lambda error_frame: print(error_frame.body),
on_heartbeat=lambda: print("Server sent a heartbeat"),
on_heartbeat=lambda: print("Server sent a heartbeat"), # also can be async

# SSL — can be either `None` (default), `True`, or `ssl.SSLContext'
ssl=None,
Expand Down
12 changes: 9 additions & 3 deletions stompman/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import inspect
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from functools import partial
import inspect
from ssl import SSLContext
from types import TracebackType
from typing import ClassVar, Literal, Self
Expand Down Expand Up @@ -31,7 +31,7 @@ class Client:

servers: list[ConnectionParameters] = field(kw_only=False)
on_error_frame: Callable[[ErrorFrame], None] | None = None
on_heartbeat: Callable[[], None] | None = None
on_heartbeat: Callable[[], None] | Callable[[], Awaitable[None]] | None = None

heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000))
ssl: Literal[True] | SSLContext | None = None
Expand All @@ -53,6 +53,7 @@ class Client:
_heartbeat_task: asyncio.Task[None] = field(init=False)
_listen_task: asyncio.Task[None] = field(init=False)
_task_group: asyncio.TaskGroup = field(init=False)
_on_heartbeat_is_async: bool = field(init=False)

def __post_init__(self) -> None:
self._connection_manager = ConnectionManager(
Expand All @@ -76,6 +77,7 @@ def __post_init__(self) -> None:
write_retry_attempts=self.write_retry_attempts,
ssl=self.ssl,
)
self._on_heartbeat_is_async = inspect.iscoroutinefunction(self.on_heartbeat) if self.on_heartbeat else False

async def __aenter__(self) -> Self:
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
Expand Down Expand Up @@ -116,7 +118,11 @@ async def _listen_to_frames(self) -> None:
if self.on_error_frame:
self.on_error_frame(frame)
case HeartbeatFrame():
if self.on_heartbeat:
if self.on_heartbeat is None:
pass
elif self._on_heartbeat_is_async:
task_group.create_task(self.on_heartbeat()) # type: ignore[arg-type]
else:
self.on_heartbeat()
case ConnectedFrame() | ReceiptFrame():
pass
Expand Down
54 changes: 54 additions & 0 deletions tests/test_connection_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator, Coroutine
from functools import partial
from typing import Any
from unittest import mock

Expand All @@ -17,6 +18,7 @@
DisconnectFrame,
ErrorFrame,
FailedAllConnectAttemptsError,
HeartbeatFrame,
ReceiptFrame,
UnsupportedProtocolVersion,
)
Expand Down Expand Up @@ -154,6 +156,58 @@ async def mock_sleep(delay: float) -> None:
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]


async def test_client_on_heartbeat_none(monkeypatch: pytest.MonkeyPatch) -> None:
real_sleep = asyncio.sleep
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
connection_class, _ = create_spying_connection(
*get_read_frames_with_lifespan(
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
)
)

async with EnrichedClient(connection_class=connection_class, on_heartbeat=None):
await real_sleep(0)
await real_sleep(0)
await real_sleep(0)


async def test_client_on_heartbeat_sync(monkeypatch: pytest.MonkeyPatch) -> None:
real_sleep = asyncio.sleep
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
connection_class, _ = create_spying_connection(
*get_read_frames_with_lifespan(
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
)
)
on_heartbeat_mock = mock.Mock()

async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock):
await real_sleep(0)
await real_sleep(0)
await real_sleep(0)

assert on_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]


async def test_client_on_heartbeat_async(monkeypatch: pytest.MonkeyPatch) -> None:
real_sleep = asyncio.sleep
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
connection_class, _ = create_spying_connection(
*get_read_frames_with_lifespan(
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
)
)
on_heartbeat_mock = mock.AsyncMock()

async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock):
await real_sleep(0)
await real_sleep(0)
await real_sleep(0)

assert on_heartbeat_mock.await_count == 3 # noqa: PLR2004
assert on_heartbeat_mock.mock_calls == [mock.call.__bool__(), mock.call(), mock.call(), mock.call()]


def test_make_receipt_id(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.undo()
stompman.connection_lifespan._make_receipt_id()