Skip to content

Commit 74d413f

Browse files
authored
Remove Client(..., on_heartbeat=...), add Client.is_alive(), improve FastStream health check (#109)
1 parent 1acadd9 commit 74d413f

File tree

12 files changed

+190
-141
lines changed

12 files changed

+190
-141
lines changed

Justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ test *args:
2525
run-artemis:
2626
#!/bin/bash
2727
trap 'echo; docker compose down --remove-orphans' EXIT
28-
docker compose run --service-ports artemis
28+
docker compose run --service-ports activemq-artemis
2929

3030
run-consumer:
3131
uv run examples/consumer.py

README.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ async with stompman.Client(
2424
stompman.ConnectionParameters(host="172.0.0.1", port=61616, login="user2", passcode="passcode2"),
2525
],
2626

27-
# Handlers:
28-
on_error_frame=lambda error_frame: print(error_frame.body),
29-
on_heartbeat=lambda: print("Server sent a heartbeat"), # also can be async
3027

3128
# SSL — can be either `None` (default), `True`, or `ssl.SSLContext'
3229
ssl=None,
3330

31+
# Error frame handler:
32+
on_error_frame=lambda error_frame: print(error_frame.body),
33+
3434
# Optional parameters with sensible defaults:
3535
heartbeat=stompman.Heartbeat(will_send_interval_ms=1000, want_to_receive_interval_ms=1000),
3636
connect_retry_attempts=3,
@@ -40,6 +40,7 @@ async with stompman.Client(
4040
disconnect_confirmation_timeout=2,
4141
read_timeout=2,
4242
write_retry_attempts=3,
43+
check_server_alive_interval_factor=3,
4344
) as client:
4445
...
4546
```
@@ -131,6 +132,7 @@ stompman takes care of cleaning up resources automatically. When you leave the c
131132
- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds. If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. In normal situation it doesn't need to be handled: tune retry and timeout parameters in `stompman.Client()` to your needs.
132133

133134
- When connection is lost, stompman will attempt to handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.FailedAllWriteAttemptsError` will be raised if connection succeeds but sending a frame or heartbeat lead to losing connection.
135+
- To implement health checks, use `stompman.Client.is_alive()` — it will return `True` if everything is OK and `False` if server is not responding.
134136

135137
### ...and caveats
136138

packages/faststream-stomp/faststream_stomp/broker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def ping(self, timeout: float | None = None) -> bool:
114114
if cancel_scope.cancel_called:
115115
return False
116116

117-
if self._connection._connection_manager._active_connection_state: # noqa: SLF001
117+
if self._connection.is_alive():
118118
return True
119119

120120
await anyio.sleep(sleep_time) # pragma: no cover

packages/stompman/stompman/client.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import inspect
32
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine
43
from contextlib import AsyncExitStack, asynccontextmanager
54
from dataclasses import dataclass, field
@@ -31,7 +30,6 @@ class Client:
3130

3231
servers: list[ConnectionParameters] = field(kw_only=False)
3332
on_error_frame: Callable[[ErrorFrame], Any] | None = None
34-
on_heartbeat: Callable[[], Any] | Callable[[], Awaitable[Any]] | None = None
3533

3634
heartbeat: Heartbeat = field(default=Heartbeat(1000, 1000))
3735
ssl: Literal[True] | SSLContext | None = None
@@ -43,17 +41,17 @@ class Client:
4341
write_retry_attempts: int = 3
4442
connection_confirmation_timeout: int = 2
4543
disconnect_confirmation_timeout: int = 2
44+
check_server_alive_interval_factor: int = 3
45+
"""Client will check if server alive `server heartbeat interval` times `interval factor`"""
4646

4747
connection_class: type[AbstractConnection] = Connection
4848

4949
_connection_manager: ConnectionManager = field(init=False)
5050
_active_subscriptions: ActiveSubscriptions = field(default_factory=dict, init=False)
5151
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
5252
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
53-
_heartbeat_task: asyncio.Task[None] = field(init=False)
5453
_listen_task: asyncio.Task[None] = field(init=False)
5554
_task_group: asyncio.TaskGroup = field(init=False)
56-
_on_heartbeat_is_async: bool = field(init=False)
5755

5856
def __post_init__(self) -> None:
5957
self._connection_manager = ConnectionManager(
@@ -66,7 +64,6 @@ def __post_init__(self) -> None:
6664
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
6765
active_subscriptions=self._active_subscriptions,
6866
active_transactions=self._active_transactions,
69-
set_heartbeat_interval=self._restart_heartbeat_task,
7067
),
7168
connection_class=self.connection_class,
7269
connect_retry_attempts=self.connect_retry_attempts,
@@ -75,13 +72,12 @@ def __post_init__(self) -> None:
7572
read_timeout=self.read_timeout,
7673
read_max_chunk_size=self.read_max_chunk_size,
7774
write_retry_attempts=self.write_retry_attempts,
75+
check_server_alive_interval_factor=self.check_server_alive_interval_factor,
7876
ssl=self.ssl,
7977
)
80-
self._on_heartbeat_is_async = inspect.iscoroutinefunction(self.on_heartbeat) if self.on_heartbeat else False
8178

8279
async def __aenter__(self) -> Self:
8380
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
84-
self._heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
8581
await self._exit_stack.enter_async_context(self._connection_manager)
8682
self._listen_task = self._task_group.create_task(self._listen_to_frames())
8783
return self
@@ -94,19 +90,9 @@ async def __aexit__(
9490
await asyncio.Future()
9591
finally:
9692
self._listen_task.cancel()
97-
self._heartbeat_task.cancel()
98-
await asyncio.wait([self._listen_task, self._heartbeat_task])
93+
await asyncio.wait([self._listen_task])
9994
await self._exit_stack.aclose()
10095

101-
def _restart_heartbeat_task(self, interval: float) -> None:
102-
self._heartbeat_task.cancel()
103-
self._heartbeat_task = self._task_group.create_task(self._send_heartbeats_forever(interval))
104-
105-
async def _send_heartbeats_forever(self, interval: float) -> None:
106-
while True:
107-
await self._connection_manager.write_heartbeat_reconnecting()
108-
await asyncio.sleep(interval)
109-
11096
async def _listen_to_frames(self) -> None:
11197
async with asyncio.TaskGroup() as task_group:
11298
async for frame in self._connection_manager.read_frames_reconnecting():
@@ -125,14 +111,7 @@ async def _listen_to_frames(self) -> None:
125111
case ErrorFrame():
126112
if self.on_error_frame:
127113
self.on_error_frame(frame)
128-
case HeartbeatFrame():
129-
if self.on_heartbeat is None:
130-
pass
131-
elif self._on_heartbeat_is_async:
132-
task_group.create_task(self.on_heartbeat()) # type: ignore[arg-type]
133-
else:
134-
self.on_heartbeat()
135-
case ConnectedFrame() | ReceiptFrame():
114+
case HeartbeatFrame() | ConnectedFrame() | ReceiptFrame():
136115
pass
137116

138117
async def send(
@@ -192,3 +171,8 @@ async def subscribe_with_manual_ack(
192171
)
193172
await subscription._subscribe() # noqa: SLF001
194173
return subscription
174+
175+
def is_alive(self) -> bool:
176+
return (
177+
self._connection_manager._active_connection_state or False # noqa: SLF001
178+
) and self._connection_manager._active_connection_state.is_alive() # noqa: SLF001

packages/stompman/stompman/connection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
22
import socket
3+
import time
34
from collections.abc import AsyncGenerator, Generator, Iterator
45
from contextlib import contextmanager, suppress
5-
from dataclasses import dataclass
6+
from dataclasses import dataclass, field
67
from ssl import SSLContext
78
from typing import Literal, Protocol, Self, cast
89

@@ -13,6 +14,8 @@
1314

1415
@dataclass(kw_only=True)
1516
class AbstractConnection(Protocol):
17+
last_read_time: float | None = field(init=False, default=None)
18+
1619
@classmethod
1720
async def connect(
1821
cls,
@@ -100,6 +103,7 @@ async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]:
100103
raw_frames = await asyncio.wait_for(
101104
self._read_non_empty_bytes(self.read_max_chunk_size), timeout=self.read_timeout
102105
)
106+
self.last_read_time = time.time()
103107

104108
for frame in cast("Iterator[AnyServerFrame]", parser.parse_frames_from_chunk(raw_frames)):
105109
yield frame

packages/stompman/stompman/connection_lifespan.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
from stompman.transaction import ActiveTransactions, commit_pending_transactions
2323

2424

25+
@dataclass(frozen=True, kw_only=True, slots=True)
26+
class EstablishedConnectionResult:
27+
server_heartbeat: Heartbeat
28+
29+
2530
class AbstractConnectionLifespan(Protocol):
26-
async def enter(self) -> StompProtocolConnectionIssue | None: ...
31+
async def enter(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue: ...
2732
async def exit(self) -> None: ...
2833

2934

@@ -37,9 +42,9 @@ class ConnectionLifespan(AbstractConnectionLifespan):
3742
disconnect_confirmation_timeout: int
3843
active_subscriptions: ActiveSubscriptions
3944
active_transactions: ActiveTransactions
40-
set_heartbeat_interval: Callable[[float], Any]
45+
set_heartbeat_interval: Callable[[Heartbeat], Any]
4146

42-
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
47+
async def _establish_connection(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue:
4348
await self.connection.write_frame(
4449
ConnectFrame(
4550
headers={
@@ -74,19 +79,17 @@ async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
7479
)
7580

7681
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
77-
self.set_heartbeat_interval(
78-
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
79-
)
80-
return None
81-
82-
async def enter(self) -> StompProtocolConnectionIssue | None:
83-
if connection_issue := await self._establish_connection():
84-
return connection_issue
85-
await resubscribe_to_active_subscriptions(
86-
connection=self.connection, active_subscriptions=self.active_subscriptions
87-
)
88-
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
89-
return None
82+
self.set_heartbeat_interval(server_heartbeat)
83+
return EstablishedConnectionResult(server_heartbeat=server_heartbeat)
84+
85+
async def enter(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue:
86+
connection_result = await self._establish_connection()
87+
if isinstance(connection_result, EstablishedConnectionResult):
88+
await resubscribe_to_active_subscriptions(
89+
connection=self.connection, active_subscriptions=self.active_subscriptions
90+
)
91+
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
92+
return connection_result
9093

9194
async def _take_receipt_frame(self) -> None:
9295
async for frame in self.connection.read_frames():
@@ -107,5 +110,9 @@ def _make_receipt_id() -> str:
107110

108111
class ConnectionLifespanFactory(Protocol):
109112
def __call__(
110-
self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
113+
self,
114+
*,
115+
connection: AbstractConnection,
116+
connection_parameters: ConnectionParameters,
117+
set_heartbeat_interval: Callable[[Heartbeat], Any],
111118
) -> AbstractConnectionLifespan: ...

packages/stompman/stompman/connection_manager.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
2+
import time
23
from collections.abc import AsyncGenerator
34
from dataclasses import dataclass, field
45
from ssl import SSLContext
56
from types import TracebackType
67
from typing import TYPE_CHECKING, Literal, Self
78

8-
from stompman.config import ConnectionParameters
9+
from stompman.config import ConnectionParameters, Heartbeat
910
from stompman.connection import AbstractConnection
1011
from stompman.errors import (
1112
AllServersUnavailable,
@@ -25,6 +26,12 @@
2526
class ActiveConnectionState:
2627
connection: AbstractConnection
2728
lifespan: "AbstractConnectionLifespan"
29+
server_heartbeat: Heartbeat
30+
31+
def is_alive(self) -> bool:
32+
if not (last_read_time := self.connection.last_read_time):
33+
return True
34+
return (self.server_heartbeat.will_send_interval_ms / 1000) > (time.time() - last_read_time)
2835

2936

3037
@dataclass(kw_only=True, slots=True)
@@ -39,17 +46,29 @@ class ConnectionManager:
3946
read_timeout: int
4047
read_max_chunk_size: int
4148
write_retry_attempts: int
49+
check_server_alive_interval_factor: int
4250

4351
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
44-
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
52+
_reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock)
53+
_task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup)
54+
_send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
55+
_check_server_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
4556

4657
async def __aenter__(self) -> Self:
58+
await self._task_group.__aenter__()
59+
self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
60+
self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
4761
self._active_connection_state = await self._get_active_connection_state()
4862
return self
4963

5064
async def __aexit__(
5165
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
5266
) -> None:
67+
self._send_heartbeat_task.cancel()
68+
self._check_server_heartbeat_task.cancel()
69+
await asyncio.wait([self._send_heartbeat_task, self._check_server_heartbeat_task])
70+
await self._task_group.__aexit__(exc_type, exc_value, traceback)
71+
5372
if not self._active_connection_state:
5473
return
5574
try:
@@ -58,7 +77,34 @@ async def __aexit__(
5877
return
5978
await self._active_connection_state.connection.close()
6079

61-
async def _create_connection_to_one_server(self, server: ConnectionParameters) -> ActiveConnectionState | None:
80+
def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
81+
self._send_heartbeat_task.cancel()
82+
self._check_server_heartbeat_task.cancel()
83+
self._send_heartbeat_task = self._task_group.create_task(
84+
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
85+
)
86+
self._check_server_heartbeat_task = self._task_group.create_task(
87+
self._check_server_heartbeat_forever(server_heartbeat.will_send_interval_ms)
88+
)
89+
90+
async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None:
91+
send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
92+
while True:
93+
await self.write_heartbeat_reconnecting()
94+
await asyncio.sleep(send_heartbeat_interval_seconds)
95+
96+
async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: int) -> None:
97+
receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
98+
while True:
99+
await asyncio.sleep(receive_heartbeat_interval_seconds * self.check_server_alive_interval_factor)
100+
if not self._active_connection_state:
101+
continue
102+
if not self._active_connection_state.is_alive():
103+
self._active_connection_state = None
104+
105+
async def _create_connection_to_one_server(
106+
self, server: ConnectionParameters
107+
) -> tuple[AbstractConnection, ConnectionParameters] | None:
62108
if connection := await self.connection_class.connect(
63109
host=server.host,
64110
port=server.port,
@@ -67,31 +113,41 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
67113
read_timeout=self.read_timeout,
68114
ssl=self.ssl,
69115
):
70-
return ActiveConnectionState(
71-
connection=connection,
72-
lifespan=self.lifespan_factory(connection=connection, connection_parameters=server),
73-
)
116+
return (connection, server)
74117
return None
75118

76-
async def _create_connection_to_any_server(self) -> ActiveConnectionState | None:
119+
async def _create_connection_to_any_server(self) -> tuple[AbstractConnection, ConnectionParameters] | None:
77120
for maybe_connection_future in asyncio.as_completed(
78121
[self._create_connection_to_one_server(server) for server in self.servers]
79122
):
80-
if connection_state := await maybe_connection_future:
81-
return connection_state
123+
if connection_and_server := await maybe_connection_future:
124+
return connection_and_server
82125
return None
83126

84127
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
85-
if not (active_connection_state := await self._create_connection_to_any_server()):
128+
from stompman.connection_lifespan import EstablishedConnectionResult # noqa: PLC0415
129+
130+
if not (connection_and_server := await self._create_connection_to_any_server()):
86131
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
132+
connection, connection_parameters = connection_and_server
133+
lifespan = self.lifespan_factory(
134+
connection=connection,
135+
connection_parameters=connection_parameters,
136+
set_heartbeat_interval=self._restart_heartbeat_tasks,
137+
)
87138

88139
try:
89-
if connection_issue := await active_connection_state.lifespan.enter():
90-
return connection_issue
140+
connection_result = await lifespan.enter()
91141
except ConnectionLostError:
92142
return ConnectionLost()
93143

94-
return active_connection_state
144+
return (
145+
ActiveConnectionState(
146+
connection=connection, lifespan=lifespan, server_heartbeat=connection_result.server_heartbeat
147+
)
148+
if isinstance(connection_result, EstablishedConnectionResult)
149+
else connection_result
150+
)
95151

96152
async def _get_active_connection_state(self) -> ActiveConnectionState:
97153
if self._active_connection_state:

0 commit comments

Comments
 (0)