Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/stompman/stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _reraise_connection_lost(*causes: type[Exception]) -> Generator[None, None,
try:
yield
except causes as exception:
raise ConnectionLostError from exception
raise ConnectionLostError(reason=exception) from exception


@dataclass(kw_only=True, slots=True)
Expand Down Expand Up @@ -86,7 +86,7 @@ async def write_frame(self, frame: AnyClientFrame) -> None:

async def _read_non_empty_bytes(self, max_chunk_size: int) -> bytes:
if (chunk := await self.reader.read(max_chunk_size)) == b"":
raise ConnectionLostError
raise ConnectionLostError(reason="eof")
return chunk

async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]:
Expand Down
25 changes: 13 additions & 12 deletions packages/stompman/stompman/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _check_server_heartbeat_forever(self, receive_heartbeat_interval_ms: i
if not self._active_connection_state:
continue
if not self._active_connection_state.is_alive(self.check_server_alive_interval_factor):
self._clear_active_connection_state()
self._clear_active_connection_state(ConnectionLostError(reason="server heartbeat timeout"))

async def _create_connection_to_one_server(
self, server: ConnectionParameters
Expand Down Expand Up @@ -168,7 +168,7 @@ async def _get_active_connection_state(self, *, is_initial_call: bool = False) -
self._active_connection_state = connection_result
if not is_initial_call:
LOGGER.warning(
"reconnected after failure connection failure. connection_parameters: %s",
"reconnected after connection failure. connection_parameters: %s",
connection_result.lifespan.connection_parameters,
)
return connection_result
Expand All @@ -178,11 +178,12 @@ async def _get_active_connection_state(self, *, is_initial_call: bool = False) -

raise FailedAllConnectAttemptsError(retry_attempts=self.connect_retry_attempts, issues=connection_issues)

def _clear_active_connection_state(self) -> None:
def _clear_active_connection_state(self, error_reason: ConnectionLostError) -> None:
if not self._active_connection_state:
return
LOGGER.warning(
"connection lost. connection_parameters: %s",
"connection lost. reason: %s, connection_parameters: %s",
error_reason.reason,
self._active_connection_state.lifespan.connection_parameters,
)
self._active_connection_state = None
Expand All @@ -192,8 +193,8 @@ async def write_heartbeat_reconnecting(self) -> None:
connection_state = await self._get_active_connection_state()
try:
return connection_state.connection.write_heartbeat()
except ConnectionLostError:
self._clear_active_connection_state()
except ConnectionLostError as error:
self._clear_active_connection_state(error)

raise FailedAllWriteAttemptsError(retry_attempts=self.write_retry_attempts)

Expand All @@ -202,8 +203,8 @@ async def write_frame_reconnecting(self, frame: AnyClientFrame) -> None:
connection_state = await self._get_active_connection_state()
try:
return await connection_state.connection.write_frame(frame)
except ConnectionLostError:
self._clear_active_connection_state()
except ConnectionLostError as error:
self._clear_active_connection_state(error)

raise FailedAllWriteAttemptsError(retry_attempts=self.write_retry_attempts)

Expand All @@ -213,15 +214,15 @@ async def read_frames_reconnecting(self) -> AsyncGenerator[AnyServerFrame, None]
try:
async for frame in connection_state.connection.read_frames():
yield frame
except ConnectionLostError:
self._clear_active_connection_state()
except ConnectionLostError as error:
self._clear_active_connection_state(error)

async def maybe_write_frame(self, frame: AnyClientFrame) -> bool:
if not self._active_connection_state:
return False
try:
await self._active_connection_state.connection.write_frame(frame)
except ConnectionLostError:
self._clear_active_connection_state()
except ConnectionLostError as error:
self._clear_active_connection_state(error)
return False
return True
2 changes: 2 additions & 0 deletions packages/stompman/stompman/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def __str__(self) -> str:
class ConnectionLostError(Error):
"""Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private."""

reason: Exception | str


@dataclass(frozen=True, kw_only=True, slots=True)
class ConnectionConfirmationTimeout:
Expand Down
38 changes: 28 additions & 10 deletions packages/stompman/test_stompman/test_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ class MockConnection(BaseMockConnection):


async def test_get_active_connection_state_lifespan_flaky_ok() -> None:
enter = mock.AsyncMock(side_effect=[ConnectionLostError, build_dataclass(EstablishedConnectionResult)])
enter = mock.AsyncMock(
side_effect=[build_dataclass(ConnectionLostError), build_dataclass(EstablishedConnectionResult)]
)
lifespan_factory = mock.Mock(return_value=mock.Mock(enter=enter))
manager = EnrichedConnectionManager(lifespan_factory=lifespan_factory, connection_class=BaseMockConnection)

Expand All @@ -154,7 +156,7 @@ async def test_get_active_connection_state_lifespan_flaky_ok() -> None:


async def test_get_active_connection_state_lifespan_flaky_fails() -> None:
enter = mock.AsyncMock(side_effect=ConnectionLostError)
enter = mock.AsyncMock(side_effect=build_dataclass(ConnectionLostError))
lifespan_factory = mock.Mock(return_value=mock.Mock(enter=enter))
manager = EnrichedConnectionManager(lifespan_factory=lifespan_factory, connection_class=BaseMockConnection)

Expand Down Expand Up @@ -209,16 +211,16 @@ async def test_get_active_connection_state_ok_concurrent() -> None:

async def test_connection_manager_context_connection_lost() -> None:
async with EnrichedConnectionManager(connection_class=BaseMockConnection) as manager:
manager._clear_active_connection_state()
manager._clear_active_connection_state()
manager._clear_active_connection_state(build_dataclass(ConnectionLostError))
manager._clear_active_connection_state(build_dataclass(ConnectionLostError))


async def test_connection_manager_context_lifespan_aexit_raises_connection_lost() -> None:
async with EnrichedConnectionManager(
lifespan_factory=mock.Mock(
return_value=mock.Mock(
enter=mock.AsyncMock(return_value=build_dataclass(EstablishedConnectionResult)),
exit=mock.AsyncMock(side_effect=[ConnectionLostError]),
exit=mock.AsyncMock(side_effect=[build_dataclass(ConnectionLostError)]),
)
),
connection_class=BaseMockConnection,
Expand Down Expand Up @@ -248,7 +250,13 @@ class MockConnection(BaseMockConnection):


async def test_write_heartbeat_reconnecting_raises() -> None:
write_heartbeat_mock = mock.Mock(side_effect=[ConnectionLostError, ConnectionLostError, ConnectionLostError])
write_heartbeat_mock = mock.Mock(
side_effect=[
build_dataclass(ConnectionLostError),
build_dataclass(ConnectionLostError),
build_dataclass(ConnectionLostError),
]
)

class MockConnection(BaseMockConnection):
write_heartbeat = write_heartbeat_mock
Expand All @@ -260,7 +268,13 @@ class MockConnection(BaseMockConnection):


async def test_write_frame_reconnecting_raises() -> None:
write_frame_mock = mock.AsyncMock(side_effect=[ConnectionLostError, ConnectionLostError, ConnectionLostError])
write_frame_mock = mock.AsyncMock(
side_effect=[
build_dataclass(ConnectionLostError),
build_dataclass(ConnectionLostError),
build_dataclass(ConnectionLostError),
]
)

class MockConnection(BaseMockConnection):
write_frame = write_frame_mock
Expand All @@ -271,7 +285,11 @@ class MockConnection(BaseMockConnection):
await manager.write_frame_reconnecting(build_dataclass(ConnectFrame))


SIDE_EFFECTS = [(None,), (ConnectionLostError(), None), (ConnectionLostError(), ConnectionLostError(), None)]
SIDE_EFFECTS = [
(None,),
(build_dataclass(ConnectionLostError), None),
(build_dataclass(ConnectionLostError), build_dataclass(ConnectionLostError), None),
]


@pytest.mark.parametrize("side_effect", SIDE_EFFECTS)
Expand Down Expand Up @@ -318,7 +336,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
attempt += 1
current_effect = side_effect[attempt]
if isinstance(current_effect, ConnectionLostError):
raise ConnectionLostError
raise current_effect
for frame in frames:
yield frame

Expand All @@ -339,7 +357,7 @@ async def test_maybe_write_frame_connection_already_lost() -> None:

async def test_maybe_write_frame_connection_now_lost() -> None:
class MockConnection(BaseMockConnection):
write_frame = mock.AsyncMock(side_effect=[ConnectionLostError])
write_frame = mock.AsyncMock(side_effect=[build_dataclass(ConnectionLostError)])

async with EnrichedConnectionManager(connection_class=MockConnection) as manager:
assert not await manager.maybe_write_frame(build_dataclass(ConnectFrame))
Expand Down
5 changes: 3 additions & 2 deletions packages/stompman/test_stompman/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SubscribeFrame,
UnsubscribeFrame,
)
from stompman.errors import ConnectionLostError

from test_stompman.conftest import (
CONNECT_FRAME,
Expand Down Expand Up @@ -52,7 +53,7 @@ async def test_client_subscriptions_lifespan_resubscribe(ack: AckMode, faker: fa
headers=sub_extra_headers,
on_suppressed_exception=noop_error_handler,
)
client._connection_manager._clear_active_connection_state()
client._connection_manager._clear_active_connection_state(build_dataclass(ConnectionLostError))
await client.send(message_body, destination=message_destination)
await subscription.unsubscribe()
await asyncio.sleep(0)
Expand Down Expand Up @@ -352,7 +353,7 @@ async def test_client_listen_raises_on_aexit(monkeypatch: pytest.MonkeyPatch, fa

async def close_connection_soon(client: stompman.Client) -> None:
await asyncio.sleep(0)
client._connection_manager._clear_active_connection_state()
client._connection_manager._clear_active_connection_state(build_dataclass(ConnectionLostError))

with pytest.raises(ExceptionGroup) as exc_info: # noqa: PT012
async with asyncio.TaskGroup() as task_group, EnrichedClient(connection_class=connection_class) as client:
Expand Down
4 changes: 3 additions & 1 deletion packages/stompman/test_stompman/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
CommitFrame,
SendFrame,
)
from stompman.errors import ConnectionLostError

from test_stompman.conftest import (
CONNECT_FRAME,
CONNECTED_FRAME,
EnrichedClient,
SomeError,
build_dataclass,
create_spying_connection,
enrich_expected_frames,
get_read_frames_with_lifespan,
Expand Down Expand Up @@ -85,7 +87,7 @@ async def test_commit_pending_transactions(monkeypatch: pytest.MonkeyPatch, fake
async with EnrichedClient(connection_class=connection_class) as client:
async with client.begin() as first_transaction:
await first_transaction.send(body, destination=destination)
client._connection_manager._clear_active_connection_state()
client._connection_manager._clear_active_connection_state(build_dataclass(ConnectionLostError))
async with client.begin() as second_transaction:
await second_transaction.send(body, destination=destination)
await asyncio.sleep(0)
Expand Down