Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions packages/stompman/stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class Client:
connection_confirmation_timeout: int = 2
disconnect_confirmation_timeout: int = 2
check_server_alive_interval_factor: int = 3
_ff_disable_server_heartbeat_check: bool = False
"""Client will check if server alive `server heartbeat interval` times `interval factor`"""

connection_class: type[AbstractConnection] = Connection
Expand Down Expand Up @@ -75,7 +74,6 @@ def __post_init__(self) -> None:
read_max_chunk_size=self.read_max_chunk_size,
write_retry_attempts=self.write_retry_attempts,
check_server_alive_interval_factor=self.check_server_alive_interval_factor,
_ff_disable_server_heartbeat_check=self._ff_disable_server_heartbeat_check,
ssl=self.ssl,
)

Expand Down
21 changes: 1 addition & 20 deletions packages/stompman/stompman/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,23 @@ class ConnectionManager:
read_max_chunk_size: int
write_retry_attempts: int
check_server_alive_interval_factor: int
_ff_disable_server_heartbeat_check: bool

_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
_reconnect_lock: asyncio.Lock = field(init=False, default_factory=asyncio.Lock)
_task_group: asyncio.TaskGroup = field(init=False, default_factory=asyncio.TaskGroup)
_send_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)
_check_server_heartbeat_task: asyncio.Task[None] = field(init=False, repr=False)

async def __aenter__(self) -> Self:
await self._task_group.__aenter__()
self._send_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
self._check_server_heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
self._active_connection_state = await self._get_active_connection_state(is_initial_call=True)
return self

async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
self._send_heartbeat_task.cancel()
self._check_server_heartbeat_task.cancel()
await asyncio.wait([self._send_heartbeat_task, self._check_server_heartbeat_task])
await asyncio.wait([self._send_heartbeat_task])
await self._task_group.__aexit__(exc_type, exc_value, traceback)

if not self._active_connection_state:
Expand All @@ -83,31 +79,16 @@ async def __aexit__(

def _restart_heartbeat_tasks(self, server_heartbeat: Heartbeat) -> None:
self._send_heartbeat_task.cancel()
self._check_server_heartbeat_task.cancel()
self._send_heartbeat_task = self._task_group.create_task(
self._send_heartbeats_forever(server_heartbeat.want_to_receive_interval_ms)
)
self._check_server_heartbeat_task = (
self._task_group.create_task(self._check_server_heartbeat_forever(server_heartbeat.will_send_interval_ms))
if not self._ff_disable_server_heartbeat_check
else self._task_group.create_task(asyncio.sleep(0))
)

async def _send_heartbeats_forever(self, send_heartbeat_interval_ms: int) -> None:
send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
while True:
await self.write_heartbeat_reconnecting()
await asyncio.sleep(send_heartbeat_interval_seconds)

async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: int) -> None:
receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
while True:
await asyncio.sleep(receive_heartbeat_interval_seconds * self.check_server_alive_interval_factor)
if not self._active_connection_state:
continue
if not self._active_connection_state.is_alive(self.check_server_alive_interval_factor):
self._clear_active_connection_state(ConnectionLostError(reason="server heartbeat timeout"))

async def _create_connection_to_one_server(
self, server: ConnectionParameters
) -> tuple[AbstractConnection, ConnectionParameters] | None:
Expand Down
1 change: 0 additions & 1 deletion packages/stompman/test_stompman/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class EnrichedConnectionManager(ConnectionManager):
write_retry_attempts: int = 3
ssl: Literal[True] | SSLContext | None = None
check_server_alive_interval_factor: int = 3
_ff_disable_server_heartbeat_check: bool = False


DataclassType = TypeVar("DataclassType")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def mock_sleep(delay: float) -> None:
async with EnrichedClient(connection_class=connection_class):
await real_sleep(0)

assert sleep_calls == [0, 0, 1, 3, 1, 3, 1, 3]
assert sleep_calls == [0, 1, 1, 1]
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call(), mock.call()]


Expand Down