Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions packages/stompman/stompman/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,21 @@ class ConnectionManager:
_reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock)
_task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup)
_send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
_check_server_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)

async def __aenter__(self) -> Self:
await self._task_group.__aenter__()
self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
self._active_connection_state = await self._get_active_connection_state(is_initial_call=True)
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
self._send_heartbeat_task.cancel()
await asyncio.wait([self._send_heartbeat_task])
self._check_server_heartbeat_task.cancel()
await asyncio.wait([self._send_heartbeat_task, self._check_server_heartbeat_task])
await self._task_group.__aexit__(exc_type, exc_value, traceback)

if not self._active_connection_state:
Expand All @@ -77,18 +80,31 @@ async def __aexit__(
return
await self._active_connection_state.connection.close()

def _restart_heartbeat_task(self, server_heartbeat: Heartbeat) -> None:
def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
self._send_heartbeat_task.cancel()
self._check_server_heartbeat_task.cancel()
self._send_heartbeat_task = self._task_group.create_task(
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
)
self._check_server_heartbeat_task = self._task_group.create_task(
self._check_server_heartbeat_forever(server_heartbeat.will_send_interval_ms)
)

async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None:
send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
while True:
await self.write_heartbeat_reconnecting()
await asyncio.sleep(send_heartbeat_interval_seconds)

async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: int) -> None:
receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
while True:
await asyncio.sleep(receive_heartbeat_interval_seconds * self.check_server_alive_interval_factor)
if not self._active_connection_state:
continue
if not self._active_connection_state.is_alive(self.check_server_alive_interval_factor):
self._clear_active_connection_state()

async def _create_connection_to_one_server(
self, server: ConnectionParameters
) -> tuple[AbstractConnection, ConnectionParameters] | None:
Expand Down Expand Up @@ -119,7 +135,7 @@ async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionI
lifespan = self.lifespan_factory(
connection=connection,
connection_parameters=connection_parameters,
set_heartbeat_interval=self._restart_heartbeat_task,
set_heartbeat_interval=self._restart_heartbeat_tasks,
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def mock_sleep(delay: float) -> None:
async with EnrichedClient(connection_class=connection_class):
await real_sleep(0)

assert sleep_calls == [0, 1, 1, 1]
assert sleep_calls == [0, 0, 1, 3, 1, 3, 1, 3]
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call(), mock.call()]


Expand Down
4 changes: 2 additions & 2 deletions packages/stompman/test_stompman/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None:
mock.call(
connection=BaseMockConnection(),
connection_parameters=manager.servers[0],
set_heartbeat_interval=manager._restart_heartbeat_task,
set_heartbeat_interval=manager._restart_heartbeat_tasks,
),
mock.call(
connection=BaseMockConnection(),
connection_parameters=manager.servers[0],
set_heartbeat_interval=manager._restart_heartbeat_task,
set_heartbeat_interval=manager._restart_heartbeat_tasks,
),
]

Expand Down