Skip to content

Commit c8fa21c

Browse files
authored
Allow async on_heartbeat callback (#83)
1 parent 8d4f4b4 commit c8fa21c

File tree

3 files changed

+64
-4
lines changed

3 files changed

+64
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async with stompman.Client(
2727

2828
# Handlers:
2929
on_error_frame=lambda error_frame: print(error_frame.body),
30-
on_heartbeat=lambda: print("Server sent a heartbeat"),
30+
on_heartbeat=lambda: print("Server sent a heartbeat"), # also can be async
3131

3232
# SSL — can be either `None` (default), `True`, or `ssl.SSLContext'
3333
ssl=None,

stompman/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import asyncio
2+
import inspect
23
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
34
from contextlib import AsyncExitStack, asynccontextmanager
45
from dataclasses import dataclass, field
56
from functools import partial
6-
import inspect
77
from ssl import SSLContext
88
from types import TracebackType
99
from typing import ClassVar, Literal, Self
@@ -31,7 +31,7 @@ class Client:
3131

3232
servers: list[ConnectionParameters] = field(kw_only=False)
3333
on_error_frame: Callable[[ErrorFrame], None] | None = None
34-
on_heartbeat: Callable[[], None] | None = None
34+
on_heartbeat: Callable[[], None] | Callable[[], Awaitable[None]] | None = None
3535

3636
heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000))
3737
ssl: Literal[True] | SSLContext | None = None
@@ -53,6 +53,7 @@ class Client:
5353
_heartbeat_task: asyncio.Task[None] = field(init=False)
5454
_listen_task: asyncio.Task[None] = field(init=False)
5555
_task_group: asyncio.TaskGroup = field(init=False)
56+
_on_heartbeat_is_async: bool = field(init=False)
5657

5758
def __post_init__(self) -> None:
5859
self._connection_manager = ConnectionManager(
@@ -76,6 +77,7 @@ def __post_init__(self) -> None:
7677
write_retry_attempts=self.write_retry_attempts,
7778
ssl=self.ssl,
7879
)
80+
self._on_heartbeat_is_async = inspect.iscoroutinefunction(self.on_heartbeat) if self.on_heartbeat else False
7981

8082
async def __aenter__(self) -> Self:
8183
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
@@ -116,7 +118,11 @@ async def _listen_to_frames(self) -> None:
116118
if self.on_error_frame:
117119
self.on_error_frame(frame)
118120
case HeartbeatFrame():
119-
if self.on_heartbeat:
121+
if self.on_heartbeat is None:
122+
pass
123+
elif self._on_heartbeat_is_async:
124+
task_group.create_task(self.on_heartbeat()) # type: ignore[arg-type]
125+
else:
120126
self.on_heartbeat()
121127
case ConnectedFrame() | ReceiptFrame():
122128
pass

tests/test_connection_lifespan.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Coroutine
3+
from functools import partial
34
from typing import Any
45
from unittest import mock
56

@@ -17,6 +18,7 @@
1718
DisconnectFrame,
1819
ErrorFrame,
1920
FailedAllConnectAttemptsError,
21+
HeartbeatFrame,
2022
ReceiptFrame,
2123
UnsupportedProtocolVersion,
2224
)
@@ -154,6 +156,58 @@ async def mock_sleep(delay: float) -> None:
154156
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]
155157

156158

159+
async def test_client_on_heartbeat_none(monkeypatch: pytest.MonkeyPatch) -> None:
160+
real_sleep = asyncio.sleep
161+
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
162+
connection_class, _ = create_spying_connection(
163+
*get_read_frames_with_lifespan(
164+
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
165+
)
166+
)
167+
168+
async with EnrichedClient(connection_class=connection_class, on_heartbeat=None):
169+
await real_sleep(0)
170+
await real_sleep(0)
171+
await real_sleep(0)
172+
173+
174+
async def test_client_on_heartbeat_sync(monkeypatch: pytest.MonkeyPatch) -> None:
175+
real_sleep = asyncio.sleep
176+
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
177+
connection_class, _ = create_spying_connection(
178+
*get_read_frames_with_lifespan(
179+
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
180+
)
181+
)
182+
on_heartbeat_mock = mock.Mock()
183+
184+
async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock):
185+
await real_sleep(0)
186+
await real_sleep(0)
187+
await real_sleep(0)
188+
189+
assert on_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]
190+
191+
192+
async def test_client_on_heartbeat_async(monkeypatch: pytest.MonkeyPatch) -> None:
193+
real_sleep = asyncio.sleep
194+
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
195+
connection_class, _ = create_spying_connection(
196+
*get_read_frames_with_lifespan(
197+
[build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame), build_dataclass(HeartbeatFrame)]
198+
)
199+
)
200+
on_heartbeat_mock = mock.AsyncMock()
201+
202+
async with EnrichedClient(connection_class=connection_class, on_heartbeat=on_heartbeat_mock):
203+
await real_sleep(0)
204+
await real_sleep(0)
205+
await real_sleep(0)
206+
207+
assert on_heartbeat_mock.await_count == 3 # noqa: PLR2004
208+
assert on_heartbeat_mock.mock_calls == [mock.call.__bool__(), mock.call(), mock.call(), mock.call()]
209+
210+
157211
def test_make_receipt_id(monkeypatch: pytest.MonkeyPatch) -> None:
158212
monkeypatch.undo()
159213
stompman.connection_lifespan._make_receipt_id()

0 commit comments

Comments
 (0)