Skip to content

Commit c7191b2

Browse files
committed
Refactor heartbeat management by moving tasks to ConnectionManager and updating ConnectionLifespanFactory
1 parent 3c63e89 commit c7191b2

File tree

5 files changed

+61
-40
lines changed

5 files changed

+61
-40
lines changed

packages/stompman/stompman/client.py

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,6 @@ class Client:
5050
_active_subscriptions: ActiveSubscriptions = field(default_factory=dict, init=False)
5151
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
5252
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
53-
_send_heartbeat_task: asyncio.Task[None] = field(init=False)
54-
_check_server_heartbeat_task: asyncio.Task[None] = field(init=False)
5553
_listen_task: asyncio.Task[None] = field(init=False)
5654
_task_group: asyncio.TaskGroup = field(init=False)
5755

@@ -66,7 +64,6 @@ def __post_init__(self) -> None:
6664
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
6765
active_subscriptions=self._active_subscriptions,
6866
active_transactions=self._active_transactions,
69-
set_heartbeat_interval=self._restart_heartbeat_tasks,
7067
),
7168
connection_class=self.connection_class,
7269
connect_retry_attempts=self.connect_retry_attempts,
@@ -81,9 +78,6 @@ def __post_init__(self) -> None:
8178

8279
async def __aenter__(self) -> Self:
8380
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
84-
# TODO: move this tasks to lifespan
85-
self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
86-
self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
8781
await self._exit_stack.enter_async_context(self._connection_manager)
8882
self._listen_task = self._task_group.create_task(self._listen_to_frames())
8983
return self
@@ -96,36 +90,9 @@ async def __aexit__(
9690
await asyncio.Future()
9791
finally:
9892
self._listen_task.cancel()
99-
self._send_heartbeat_task.cancel()
100-
self._check_server_heartbeat_task.cancel()
101-
await asyncio.wait([self._listen_task, self._send_heartbeat_task, self._check_server_heartbeat_task])
93+
await asyncio.wait([self._listen_task])
10294
await self._exit_stack.aclose()
10395

104-
def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
105-
self._send_heartbeat_task.cancel()
106-
self._check_server_heartbeat_task.cancel()
107-
print(server_heartbeat)
108-
self._send_heartbeat_task = self._task_group.create_task(
109-
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
110-
)
111-
self._check_server_heartbeat_task = self._task_group.create_task(
112-
self._monitor_server_aliveness(server_heartbeat.will_send_interval_ms)
113-
)
114-
115-
async def _send_heartbeats_forever(self, client_heartbeat_interval_ms: int) -> None:
116-
while True:
117-
await self._connection_manager.write_heartbeat_reconnecting()
118-
await asyncio.sleep(client_heartbeat_interval_ms / 1000)
119-
120-
async def _monitor_server_aliveness(self, server_heartbeat_interval_ms: int) -> None:
121-
server_heartbeat_interval_seconds = server_heartbeat_interval_ms / 1000
122-
while True:
123-
await asyncio.sleep(server_heartbeat_interval_seconds * self.check_server_alive_interval_factor)
124-
if not self._connection_manager._active_connection_state:
125-
continue
126-
if not self._connection_manager._active_connection_state.is_alive():
127-
self._connection_manager._active_connection_state = None
128-
12996
async def _listen_to_frames(self) -> None:
13097
async with asyncio.TaskGroup() as task_group:
13198
async for frame in self._connection_manager.read_frames_reconnecting():

packages/stompman/stompman/connection_lifespan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,9 @@ def _make_receipt_id() -> str:
110110

111111
class ConnectionLifespanFactory(Protocol):
112112
def __call__(
113-
self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
113+
self,
114+
*,
115+
connection: AbstractConnection,
116+
connection_parameters: ConnectionParameters,
117+
set_heartbeat_interval: Callable[[Heartbeat], Any],
114118
) -> AbstractConnectionLifespan: ...

packages/stompman/stompman/connection_manager.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,26 @@ class ConnectionManager:
4949
check_server_alive_interval_factor: int
5050

5151
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
52-
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
52+
_reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock)
53+
_task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup)
54+
_send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
55+
_check_server_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
5356

5457
async def __aenter__(self) -> Self:
58+
await self._task_group.__aenter__()
59+
self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
60+
self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
5561
self._active_connection_state = await self._get_active_connection_state()
5662
return self
5763

5864
async def __aexit__(
5965
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
6066
) -> None:
67+
self._send_heartbeat_task.cancel()
68+
self._check_server_heartbeat_task.cancel()
69+
await asyncio.wait([self._send_heartbeat_task, self._check_server_heartbeat_task])
70+
await self._task_group.__aexit__(exc_type, exc_value, traceback)
71+
6172
if not self._active_connection_state:
6273
return
6374
try:
@@ -66,6 +77,31 @@ async def __aexit__(
6677
return
6778
await self._active_connection_state.connection.close()
6879

80+
def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
81+
self._send_heartbeat_task.cancel()
82+
self._check_server_heartbeat_task.cancel()
83+
self._send_heartbeat_task = self._task_group.create_task(
84+
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
85+
)
86+
self._check_server_heartbeat_task = self._task_group.create_task(
87+
self._check_server_heartbeat_forever(server_heartbeat.will_send_interval_ms)
88+
)
89+
90+
async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None:
91+
send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
92+
while True:
93+
await self.write_heartbeat_reconnecting()
94+
await asyncio.sleep(send_heartbeat_interval_seconds)
95+
96+
async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: int) -> None:
97+
receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
98+
while True:
99+
await asyncio.sleep(receive_heartbeat_interval_seconds * self.check_server_alive_interval_factor)
100+
if not self._active_connection_state:
101+
continue
102+
if not self._active_connection_state.is_alive():
103+
self._active_connection_state = None
104+
69105
async def _create_connection_to_one_server(
70106
self, server: ConnectionParameters
71107
) -> tuple[AbstractConnection, ConnectionParameters] | None:
@@ -94,7 +130,11 @@ async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionI
94130
if not (connection_and_server := await self._create_connection_to_any_server()):
95131
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
96132
connection, connection_parameters = connection_and_server
97-
lifespan = self.lifespan_factory(connection=connection, connection_parameters=connection_parameters)
133+
lifespan = self.lifespan_factory(
134+
connection=connection,
135+
connection_parameters=connection_parameters,
136+
set_heartbeat_interval=self._restart_heartbeat_tasks,
137+
)
98138

99139
try:
100140
connection_result = await lifespan.enter()

packages/stompman/test_stompman/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator
2+
from collections.abc import AsyncGenerator, Callable
33
from dataclasses import dataclass, field
44
from ssl import SSLContext
55
from typing import Any, Literal, Self, TypeVar
66

77
import pytest
88
import stompman
99
from polyfactory.factories.dataclass_factory import DataclassFactory
10+
from stompman.config import Heartbeat
1011
from stompman.connection import AbstractConnection
1112
from stompman.connection_lifespan import AbstractConnectionLifespan, EstablishedConnectionResult
1213
from stompman.connection_manager import ConnectionManager
@@ -58,6 +59,7 @@ class EnrichedClient(stompman.Client):
5859
class NoopLifespan(AbstractConnectionLifespan):
5960
connection: AbstractConnection
6061
connection_parameters: stompman.ConnectionParameters
62+
set_heartbeat_interval: Callable[[Heartbeat], Any]
6163

6264
async def enter(self) -> EstablishedConnectionResult | stompman.StompProtocolConnectionIssue:
6365
return EstablishedConnectionResult(server_heartbeat=stompman.Heartbeat(1000, 1000))

packages/stompman/test_stompman/test_connection_manager.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,16 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None:
140140

141141
assert enter.mock_calls == [mock.call(), mock.call()]
142142
assert lifespan_factory.mock_calls == [
143-
mock.call(connection=BaseMockConnection(), connection_parameters=manager.servers[0]),
144-
mock.call(connection=BaseMockConnection(), connection_parameters=manager.servers[0]),
143+
mock.call(
144+
connection=BaseMockConnection(),
145+
connection_parameters=manager.servers[0],
146+
set_heartbeat_interval=manager._restart_heartbeat_tasks,
147+
),
148+
mock.call(
149+
connection=BaseMockConnection(),
150+
connection_parameters=manager.servers[0],
151+
set_heartbeat_interval=manager._restart_heartbeat_tasks,
152+
),
145153
]
146154

147155

0 commit comments

Comments
 (0)