Skip to content

Commit b164548

Browse files
authored
Fix weird retries (#54)
1 parent c5023db commit b164548

File tree

10 files changed

+102
-94
lines changed

10 files changed

+102
-94
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ async with stompman.Client(
3737
connection_confirmation_timeout=2,
3838
disconnect_confirmation_timeout=2,
3939
read_timeout=2,
40+
write_retry_attempts=3,
4041
) as client:
4142
...
4243
```
@@ -107,7 +108,7 @@ stompman takes care of cleaning up resources automatically. When you leave the c
107108

108109
- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds. If all servers fail to connect, a `stompman.FailedAllConnectAttemptsError` will be raised. In a normal situation it doesn't need to be handled: tune the retry and timeout parameters in `stompman.Client()` to your needs.
109110

110-
- When connection is lost, stompman will handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.RepeatedConnectionLostError` or `stompman.ConnectionLostDuringOperationError` will be raised if connection succeeds but operation (like sending a frame) leads to connection getting lost.
111+
- When the connection is lost, stompman will attempt to handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.FailedAllWriteAttemptsError` will be raised if the connection succeeds but sending a frame or heartbeat leads to losing the connection.
111112

112113
### ...and caveats
113114

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ repository = "https://github.com/vrslev/stompman"
2222
[tool.uv]
2323
dev-dependencies = [
2424
"anyio~=4.4.0",
25-
"mypy~=1.10.0",
25+
"mypy~=1.11.1",
2626
"pytest-cov~=5.0.0",
2727
"pytest~=8.2.2",
2828
"ruff~=0.4.9",

stompman/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
ConnectionManager,
99
)
1010
from stompman.errors import (
11-
ConnectionAttemptsFailedError,
1211
ConnectionConfirmationTimeout,
13-
ConnectionLostDuringOperationError,
1412
ConnectionLostError,
1513
Error,
1614
FailedAllConnectAttemptsError,
15+
FailedAllWriteAttemptsError,
1716
StompProtocolConnectionIssue,
1817
UnsupportedProtocolVersion,
1918
)
@@ -56,18 +55,17 @@
5655
"ConnectFrame",
5756
"ConnectedFrame",
5857
"Connection",
59-
"ConnectionAttemptsFailedError",
6058
"ConnectionConfirmationTimeout",
6159
"ConnectionLifespan",
6260
"ConnectionLifespanFactory",
63-
"ConnectionLostDuringOperationError",
6461
"ConnectionLostError",
6562
"ConnectionManager",
6663
"ConnectionParameters",
6764
"DisconnectFrame",
6865
"Error",
6966
"ErrorFrame",
7067
"FailedAllConnectAttemptsError",
68+
"FailedAllWriteAttemptsError",
7169
"FrameParser",
7270
"Heartbeat",
7371
"HeartbeatFrame",

stompman/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class Client:
225225
connect_timeout: int = 2
226226
read_timeout: int = 2
227227
read_max_chunk_size: int = 1024 * 1024
228+
write_retry_attempts: int = 3
228229
connection_confirmation_timeout: int = 2
229230
disconnect_confirmation_timeout: int = 2
230231

@@ -257,6 +258,7 @@ def __post_init__(self) -> None:
257258
connect_timeout=self.connect_timeout,
258259
read_timeout=self.read_timeout,
259260
read_max_chunk_size=self.read_max_chunk_size,
261+
write_retry_attempts=self.write_retry_attempts,
260262
)
261263

262264
async def __aenter__(self) -> Self:

stompman/connection_manager.py

Lines changed: 51 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from stompman.config import ConnectionParameters
88
from stompman.connection import AbstractConnection
99
from stompman.errors import (
10-
ConnectionAttemptsFailedError,
10+
AllServersUnavailable,
11+
AnyConnectionIssue,
1112
ConnectionLost,
12-
ConnectionLostDuringOperationError,
1313
ConnectionLostError,
1414
FailedAllConnectAttemptsError,
15+
FailedAllWriteAttemptsError,
1516
StompProtocolConnectionIssue,
1617
)
1718
from stompman.frames import AnyClientFrame, AnyServerFrame
@@ -44,6 +45,7 @@ class ConnectionManager:
4445
connect_timeout: int
4546
read_timeout: int
4647
read_max_chunk_size: int
48+
write_retry_attempts: int
4749

4850
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
4951
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
@@ -63,88 +65,84 @@ async def __aexit__(
6365
return
6466
await self._active_connection_state.connection.close()
6567

66-
async def _connect_to_one_server(
67-
self, server: ConnectionParameters
68-
) -> tuple[AbstractConnection, ConnectionParameters] | None:
69-
for attempt in range(self.connect_retry_attempts):
70-
if connection := await self.connection_class.connect(
71-
host=server.host,
72-
port=server.port,
73-
timeout=self.connect_timeout,
74-
read_max_chunk_size=self.read_max_chunk_size,
75-
read_timeout=self.read_timeout,
76-
):
77-
return connection, server
78-
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
68+
async def _create_connection_to_one_server(self, server: ConnectionParameters) -> ActiveConnectionState | None:
69+
if connection := await self.connection_class.connect(
70+
host=server.host,
71+
port=server.port,
72+
timeout=self.connect_timeout,
73+
read_max_chunk_size=self.read_max_chunk_size,
74+
read_timeout=self.read_timeout,
75+
):
76+
return ActiveConnectionState(
77+
connection=connection,
78+
lifespan=self.lifespan_factory(connection=connection, connection_parameters=server),
79+
)
7980
return None
8081

81-
async def _connect_to_any_server(self) -> tuple[AbstractConnection, ConnectionParameters]:
82+
async def _create_connection_to_any_server(self) -> ActiveConnectionState | None:
8283
for maybe_connection_future in asyncio.as_completed(
83-
[self._connect_to_one_server(server) for server in self.servers]
84+
[self._create_connection_to_one_server(server) for server in self.servers]
8485
):
85-
if maybe_result := await maybe_connection_future:
86-
return maybe_result
87-
raise FailedAllConnectAttemptsError(
88-
servers=self.servers,
89-
retry_attempts=self.connect_retry_attempts,
90-
retry_interval=self.connect_retry_interval,
91-
timeout=self.connect_timeout,
92-
)
86+
if connection_state := await maybe_connection_future:
87+
return connection_state
88+
return None
89+
90+
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
91+
if not (active_connection_state := await self._create_connection_to_any_server()):
92+
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
93+
94+
try:
95+
if connection_issue := await active_connection_state.lifespan.enter():
96+
return connection_issue
97+
except ConnectionLostError:
98+
return ConnectionLost()
99+
100+
return active_connection_state
93101

94102
async def _get_active_connection_state(self) -> ActiveConnectionState:
95-
connection_issues: list[StompProtocolConnectionIssue | ConnectionLost] = []
103+
if self._active_connection_state:
104+
return self._active_connection_state
96105

97-
for _ in range(self.connect_retry_attempts):
106+
connection_issues: list[AnyConnectionIssue] = []
107+
108+
async with self._reconnect_lock:
98109
if self._active_connection_state:
99110
return self._active_connection_state
100111

101-
async with self._reconnect_lock:
102-
if self._active_connection_state:
103-
return self._active_connection_state
104-
105-
connection, connection_parameters = await self._connect_to_any_server()
106-
self._active_connection_state = ActiveConnectionState(
107-
connection=connection,
108-
lifespan=self.lifespan_factory(connection=connection, connection_parameters=connection_parameters),
109-
)
110-
111-
try:
112-
lifespan_connection_issue: (
113-
StompProtocolConnectionIssue | ConnectionLost | None
114-
) = await self._active_connection_state.lifespan.enter()
115-
except ConnectionLostError:
116-
lifespan_connection_issue = ConnectionLost()
117-
else:
118-
if lifespan_connection_issue is None:
119-
return self._active_connection_state
112+
for attempt in range(self.connect_retry_attempts):
113+
connection_result = await self._connect_to_any_server()
120114

121-
self._clear_active_connection_state()
122-
connection_issues.append(lifespan_connection_issue)
115+
if isinstance(connection_result, ActiveConnectionState):
116+
self._active_connection_state = connection_result
117+
return connection_result
118+
119+
connection_issues.append(connection_result)
120+
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
123121

124-
raise ConnectionAttemptsFailedError(retry_attempts=self.connect_retry_attempts, issues=connection_issues)
122+
raise FailedAllConnectAttemptsError(retry_attempts=self.connect_retry_attempts, issues=connection_issues)
125123

126124
def _clear_active_connection_state(self) -> None:
127125
self._active_connection_state = None
128126

129127
async def write_heartbeat_reconnecting(self) -> None:
130-
for _ in range(self.connect_retry_attempts):
128+
for _ in range(self.write_retry_attempts):
131129
connection_state = await self._get_active_connection_state()
132130
try:
133131
return connection_state.connection.write_heartbeat()
134132
except ConnectionLostError:
135133
self._clear_active_connection_state()
136134

137-
raise ConnectionLostDuringOperationError(retry_attempts=self.connect_retry_attempts)
135+
raise FailedAllWriteAttemptsError(retry_attempts=self.write_retry_attempts)
138136

139137
async def write_frame_reconnecting(self, frame: AnyClientFrame) -> None:
140-
for _ in range(self.connect_retry_attempts):
138+
for _ in range(self.write_retry_attempts):
141139
connection_state = await self._get_active_connection_state()
142140
try:
143141
return await connection_state.connection.write_frame(frame)
144142
except ConnectionLostError:
145143
self._clear_active_connection_state()
146144

147-
raise ConnectionLostDuringOperationError(retry_attempts=self.connect_retry_attempts)
145+
raise FailedAllWriteAttemptsError(retry_attempts=self.write_retry_attempts)
148146

149147
async def read_frames_reconnecting(self) -> AsyncGenerator[AnyServerFrame, None]:
150148
while True:

stompman/errors.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ class ConnectionLostError(Error):
1818
"""Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private."""
1919

2020

21-
@dataclass(kw_only=True)
22-
class FailedAllConnectAttemptsError(Error):
23-
servers: list["ConnectionParameters"]
24-
retry_attempts: int
25-
retry_interval: int
26-
timeout: int
27-
28-
2921
@dataclass(frozen=True, kw_only=True, slots=True)
3022
class ConnectionConfirmationTimeout:
3123
timeout: int
@@ -42,15 +34,22 @@ class UnsupportedProtocolVersion:
4234
class ConnectionLost: ...
4335

4436

37+
@dataclass(frozen=True, kw_only=True, slots=True)
38+
class AllServersUnavailable:
39+
servers: list["ConnectionParameters"]
40+
timeout: int
41+
42+
4543
StompProtocolConnectionIssue = ConnectionConfirmationTimeout | UnsupportedProtocolVersion
44+
AnyConnectionIssue = StompProtocolConnectionIssue | ConnectionLost | AllServersUnavailable
4645

4746

4847
@dataclass(kw_only=True)
49-
class ConnectionAttemptsFailedError(Error):
48+
class FailedAllConnectAttemptsError(Error):
5049
retry_attempts: int
51-
issues: list[StompProtocolConnectionIssue | ConnectionLost]
50+
issues: list[AnyConnectionIssue]
5251

5352

5453
@dataclass(kw_only=True)
55-
class ConnectionLostDuringOperationError(Error):
54+
class FailedAllWriteAttemptsError(Error):
5655
retry_attempts: int

tests/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, field
44
from typing import Any, Self, TypeVar
5-
from unittest import mock
65

76
import pytest
87
from polyfactory.factories.dataclass_factory import DataclassFactory
@@ -22,7 +21,8 @@ def anyio_backend(request: pytest.FixtureRequest) -> object:
2221

2322
@pytest.fixture()
2423
def mock_sleep(monkeypatch: pytest.MonkeyPatch) -> None: # noqa: PT004
25-
monkeypatch.setattr("asyncio.sleep", mock.AsyncMock())
24+
original_sleep = asyncio.sleep
25+
monkeypatch.setattr("asyncio.sleep", lambda _: original_sleep(0))
2626

2727

2828
async def noop_message_handler(frame: stompman.MessageFrame) -> None: ...
@@ -74,6 +74,7 @@ class EnrichedConnectionManager(stompman.ConnectionManager):
7474
connect_timeout: int = 3
7575
read_timeout: int = 4
7676
read_max_chunk_size: int = 5
77+
write_retry_attempts: int = 3
7778

7879

7980
DataclassType = TypeVar("DataclassType")

tests/test_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
CommitFrame,
2222
ConnectedFrame,
2323
ConnectFrame,
24-
ConnectionAttemptsFailedError,
2524
ConnectionConfirmationTimeout,
2625
ConnectionParameters,
2726
DisconnectFrame,
@@ -133,6 +132,7 @@ async def test_client_connection_lifespan_ok(monkeypatch: pytest.MonkeyPatch) ->
133132
assert collected_frames == [connect_frame, connected_frame, disconnect_frame, receipt_frame]
134133

135134

135+
@pytest.mark.usefixtures("mock_sleep")
136136
async def test_client_connection_lifespan_connection_not_confirmed(monkeypatch: pytest.MonkeyPatch) -> None:
137137
async def mock_wait_for(future: Coroutine[Any, Any, Any], timeout: float) -> object:
138138
assert timeout == connection_confirmation_timeout
@@ -151,29 +151,30 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
151151
yield error_frame
152152
await asyncio.sleep(0)
153153

154-
with pytest.raises(ConnectionAttemptsFailedError) as exc_info:
154+
with pytest.raises(FailedAllConnectAttemptsError) as exc_info:
155155
await EnrichedClient(
156156
connection_class=MockConnection, connection_confirmation_timeout=connection_confirmation_timeout
157157
).__aenter__()
158158

159-
assert exc_info.value == ConnectionAttemptsFailedError(
159+
assert exc_info.value == FailedAllConnectAttemptsError(
160160
retry_attempts=3,
161161
issues=[ConnectionConfirmationTimeout(timeout=connection_confirmation_timeout, frames=[error_frame])] * 3,
162162
)
163163

164164

165+
@pytest.mark.usefixtures("mock_sleep")
165166
async def test_client_connection_lifespan_unsupported_protocol_version() -> None:
166167
given_version = FAKER.pystr()
167168

168-
with pytest.raises(ConnectionAttemptsFailedError) as exc_info:
169+
with pytest.raises(FailedAllConnectAttemptsError) as exc_info:
169170
await EnrichedClient(
170171
connection_class=create_spying_connection(
171172
[build_dataclass(ConnectedFrame, headers={"version": given_version})]
172173
)[0],
173174
connect_retry_attempts=1,
174175
).__aenter__()
175176

176-
assert exc_info.value == ConnectionAttemptsFailedError(
177+
assert exc_info.value == FailedAllConnectAttemptsError(
177178
retry_attempts=1,
178179
issues=[UnsupportedProtocolVersion(given_version=given_version, supported_version=Client.PROTOCOL_VERSION)],
179180
)

0 commit comments

Comments
 (0)