Skip to content

Commit 715ead4

Browse files
authored
Refactor tests (#57)
1 parent d908f3b commit 715ead4

15 files changed

+703
-433
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ ignore = [
6767
"PLC2801",
6868
"PLR0913",
6969
]
70-
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] }
70+
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG", "PLR6301"] }
7171

7272
[tool.pytest.ini_options]
7373
addopts = "--cov -s -vv"

stompman/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from stompman.client import Client
22
from stompman.config import ConnectionParameters, Heartbeat
33
from stompman.errors import (
4+
AllServersUnavailable,
5+
AnyConnectionIssue,
46
ConnectionConfirmationTimeout,
7+
ConnectionLost,
58
ConnectionLostError,
69
Error,
710
FailedAllConnectAttemptsError,
@@ -38,7 +41,9 @@
3841
"AbortFrame",
3942
"AckFrame",
4043
"AckMode",
44+
"AllServersUnavailable",
4145
"AnyClientFrame",
46+
"AnyConnectionIssue",
4247
"AnyRealServerFrame",
4348
"AnyServerFrame",
4449
"BeginFrame",
@@ -47,6 +52,7 @@
4752
"ConnectFrame",
4853
"ConnectedFrame",
4954
"ConnectionConfirmationTimeout",
55+
"ConnectionLost",
5056
"ConnectionLostError",
5157
"ConnectionParameters",
5258
"DisconnectFrame",

stompman/connection_lifespan.py

Lines changed: 73 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from collections.abc import Callable
2+
from collections.abc import AsyncIterable, Callable
33
from contextlib import suppress
44
from dataclasses import dataclass
55
from typing import Protocol
@@ -9,6 +9,7 @@
99
from stompman.connection import AbstractConnection
1010
from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
1111
from stompman.frames import (
12+
AnyServerFrame,
1213
ConnectedFrame,
1314
ConnectFrame,
1415
DisconnectFrame,
@@ -22,6 +23,54 @@
2223
from stompman.transaction import ActiveTransactions, commit_pending_transactions
2324

2425

26+
async def take_connected_frame(
27+
*, frames_iter: AsyncIterable[AnyServerFrame], connection_confirmation_timeout: int
28+
) -> ConnectedFrame | ConnectionConfirmationTimeout:
29+
collected_frames = []
30+
31+
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
32+
async for frame in frames_iter:
33+
if isinstance(frame, ConnectedFrame):
34+
return frame
35+
collected_frames.append(frame)
36+
msg = "unreachable"
37+
raise AssertionError(msg)
38+
39+
try:
40+
return await asyncio.wait_for(
41+
take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout
42+
)
43+
except TimeoutError:
44+
return ConnectionConfirmationTimeout(timeout=connection_confirmation_timeout, frames=collected_frames)
45+
46+
47+
def check_stomp_protocol_version(
48+
*, connected_frame: ConnectedFrame, supported_version: str
49+
) -> UnsupportedProtocolVersion | None:
50+
if connected_frame.headers["version"] == supported_version:
51+
return None
52+
return UnsupportedProtocolVersion(
53+
given_version=connected_frame.headers["version"], supported_version=supported_version
54+
)
55+
56+
57+
def calculate_heartbeat_interval(*, connected_frame: ConnectedFrame, client_heartbeat: Heartbeat) -> float:
58+
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
59+
return max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
60+
61+
62+
async def wait_for_receipt_frame(
63+
*, frames_iter: AsyncIterable[AnyServerFrame], disconnect_confirmation_timeout: int
64+
) -> None:
65+
async def inner() -> None:
66+
async for frame in frames_iter:
67+
if isinstance(frame, ReceiptFrame):
68+
break
69+
70+
with suppress(TimeoutError):
71+
await asyncio.wait_for(inner(), timeout=disconnect_confirmation_timeout)
72+
73+
2574
class AbstractConnectionLifespan(Protocol):
2675
async def enter(self) -> StompProtocolConnectionIssue | None: ...
2776
async def exit(self) -> None: ...
@@ -51,61 +100,48 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
51100
},
52101
)
53102
)
54-
collected_frames = []
55-
56-
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
57-
async for frame in self.connection.read_frames():
58-
if isinstance(frame, ConnectedFrame):
59-
return frame
60-
collected_frames.append(frame)
61-
msg = "unreachable" # pragma: no cover
62-
raise AssertionError(msg) # pragma: no cover
63-
64-
try:
65-
connected_frame = await asyncio.wait_for(
66-
take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout
67-
)
68-
except TimeoutError:
69-
return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames)
103+
connected_frame_or_error = await take_connected_frame(
104+
frames_iter=self.connection.read_frames(),
105+
connection_confirmation_timeout=self.connection_confirmation_timeout,
106+
)
107+
if isinstance(connected_frame_or_error, ConnectionConfirmationTimeout):
108+
return connected_frame_or_error
109+
connected_frame = connected_frame_or_error
70110

71-
if connected_frame.headers["version"] != self.protocol_version:
72-
return UnsupportedProtocolVersion(
73-
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
74-
)
111+
if unsupported_protocol_version_error := check_stomp_protocol_version(
112+
connected_frame=connected_frame, supported_version=self.protocol_version
113+
):
114+
return unsupported_protocol_version_error
75115

76-
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
77116
self.set_heartbeat_interval(
78-
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
117+
calculate_heartbeat_interval(connected_frame=connected_frame, client_heartbeat=self.client_heartbeat)
79118
)
80119
return None
81120

82121
async def enter(self) -> StompProtocolConnectionIssue | None:
83-
if connection_issue := await self._establish_connection():
84-
return connection_issue
122+
if protocol_connection_issue := await self._establish_connection():
123+
return protocol_connection_issue
124+
85125
await resubscribe_to_active_subscriptions(
86126
connection=self.connection, active_subscriptions=self.active_subscriptions
87127
)
88128
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
89129
return None
90130

91-
async def _take_receipt_frame(self) -> None:
92-
async for frame in self.connection.read_frames():
93-
if isinstance(frame, ReceiptFrame):
94-
break
95-
96131
async def exit(self) -> None:
97132
await unsubscribe_from_all_active_subscriptions(active_subscriptions=self.active_subscriptions)
98133
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
99-
100-
with suppress(TimeoutError):
101-
await asyncio.wait_for(self._take_receipt_frame(), timeout=self.disconnect_confirmation_timeout)
102-
103-
104-
def _make_receipt_id() -> str:
105-
return str(uuid4())
134+
await wait_for_receipt_frame(
135+
frames_iter=self.connection.read_frames(),
136+
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
137+
)
106138

107139

108140
class ConnectionLifespanFactory(Protocol):
109141
def __call__(
110142
self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
111143
) -> AbstractConnectionLifespan: ...
144+
145+
146+
def _make_receipt_id() -> str:
147+
return str(uuid4())

stompman/connection_manager.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator
2+
from collections.abc import AsyncGenerator, Awaitable, Callable
33
from dataclasses import dataclass, field
44
from types import TracebackType
55
from typing import TYPE_CHECKING, Self
@@ -26,6 +26,48 @@ class ActiveConnectionState:
2626
lifespan: "AbstractConnectionLifespan"
2727

2828

29+
async def attempt_to_connect(
30+
*,
31+
connect: Callable[[], Awaitable[ActiveConnectionState | AnyConnectionIssue]],
32+
connect_retry_interval: int,
33+
connect_retry_attempts: int,
34+
) -> ActiveConnectionState:
35+
connection_issues = []
36+
37+
for attempt in range(connect_retry_attempts):
38+
connection_result = await connect()
39+
if isinstance(connection_result, ActiveConnectionState):
40+
return connection_result
41+
42+
connection_issues.append(connection_result)
43+
await asyncio.sleep(connect_retry_interval * (attempt + 1))
44+
45+
raise FailedAllConnectAttemptsError(retry_attempts=connect_retry_attempts, issues=connection_issues)
46+
47+
48+
async def connect_to_first_server(
49+
connect_awaitables: list[Awaitable[ActiveConnectionState | None]],
50+
) -> ActiveConnectionState | None:
51+
for maybe_connection_future in asyncio.as_completed(connect_awaitables):
52+
if connection_state := await maybe_connection_future:
53+
return connection_state
54+
return None
55+
56+
57+
async def make_healthy_connection(
58+
*, active_connection_state: ActiveConnectionState | None, servers: list[ConnectionParameters], connect_timeout: int
59+
) -> ActiveConnectionState | AnyConnectionIssue:
60+
if not active_connection_state:
61+
return AllServersUnavailable(servers=servers, timeout=connect_timeout)
62+
63+
try:
64+
connection_issue = await active_connection_state.lifespan.enter()
65+
except ConnectionLostError:
66+
return ConnectionLost()
67+
68+
return active_connection_state if connection_issue is None else connection_issue
69+
70+
2971
@dataclass(kw_only=True, slots=True)
3072
class ConnectionManager:
3173
servers: list[ConnectionParameters]
@@ -70,47 +112,27 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
70112
)
71113
return None
72114

73-
async def _create_connection_to_any_server(self) -> ActiveConnectionState | None:
74-
for maybe_connection_future in asyncio.as_completed(
75-
[self._create_connection_to_one_server(server) for server in self.servers]
76-
):
77-
if connection_state := await maybe_connection_future:
78-
return connection_state
79-
return None
80-
81115
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
82-
if not (active_connection_state := await self._create_connection_to_any_server()):
83-
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
84-
85-
try:
86-
if connection_issue := await active_connection_state.lifespan.enter():
87-
return connection_issue
88-
except ConnectionLostError:
89-
return ConnectionLost()
90-
91-
return active_connection_state
116+
active_connection_state = await connect_to_first_server(
117+
[self._create_connection_to_one_server(server) for server in self.servers]
118+
)
119+
return await make_healthy_connection(
120+
active_connection_state=active_connection_state, servers=self.servers, connect_timeout=self.connect_timeout
121+
)
92122

93123
async def _get_active_connection_state(self) -> ActiveConnectionState:
94124
if self._active_connection_state:
95125
return self._active_connection_state
96126

97-
connection_issues: list[AnyConnectionIssue] = []
98-
99127
async with self._reconnect_lock:
100128
if self._active_connection_state:
101129
return self._active_connection_state
102-
103-
for attempt in range(self.connect_retry_attempts):
104-
connection_result = await self._connect_to_any_server()
105-
106-
if isinstance(connection_result, ActiveConnectionState):
107-
self._active_connection_state = connection_result
108-
return connection_result
109-
110-
connection_issues.append(connection_result)
111-
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
112-
113-
raise FailedAllConnectAttemptsError(retry_attempts=self.connect_retry_attempts, issues=connection_issues)
130+
self._active_connection_state = await attempt_to_connect(
131+
connect=self._connect_to_any_server,
132+
connect_retry_interval=self.connect_retry_interval,
133+
connect_retry_attempts=self.connect_retry_attempts,
134+
)
135+
return self._active_connection_state
114136

115137
def _clear_active_connection_state(self) -> None:
116138
self._active_connection_state = None

tests/conftest.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,9 @@
1212
from stompman.connection_manager import ConnectionManager
1313

1414

15-
@pytest.fixture(
16-
params=[
17-
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
18-
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
19-
],
20-
)
21-
def anyio_backend(request: pytest.FixtureRequest) -> object:
22-
return request.param
15+
@pytest.fixture
16+
def anyio_backend() -> object:
17+
return "asyncio"
2318

2419

2520
@pytest.fixture
@@ -127,11 +122,6 @@ async def read_frames() -> AsyncGenerator[stompman.AnyServerFrame, None]:
127122
CONNECTED_FRAME = stompman.ConnectedFrame(headers={"version": stompman.Client.PROTOCOL_VERSION, "heart-beat": "1,1"})
128123

129124

130-
@pytest.fixture(autouse=True)
131-
def _mock_receipt_id(monkeypatch: pytest.MonkeyPatch) -> None:
132-
monkeypatch.setattr(stompman.connection_lifespan, "_make_receipt_id", lambda: "receipt-id-1")
133-
134-
135125
def get_read_frames_with_lifespan(*read_frames: list[stompman.AnyServerFrame]) -> list[list[stompman.AnyServerFrame]]:
136126
return [
137127
[CONNECTED_FRAME],

tests/integration.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,23 @@
2121
parse_header,
2222
)
2323

24-
pytestmark = pytest.mark.anyio
25-
2624
CONNECTION_PARAMETERS: Final = stompman.ConnectionParameters(
2725
host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode=":=123"
2826
)
2927
DESTINATION: Final = "DLQ"
3028

29+
pytestmark = pytest.mark.anyio
30+
31+
32+
@pytest.fixture(
33+
params=[
34+
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
35+
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
36+
],
37+
)
38+
def anyio_backend(request: pytest.FixtureRequest) -> object:
39+
return request.param
40+
3141

3242
@asynccontextmanager
3343
async def create_client() -> AsyncGenerator[stompman.Client, None]:

tests/test_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import asyncio
2+
from unittest import mock
3+
4+
import pytest
5+
6+
from tests.conftest import EnrichedClient, create_spying_connection, get_read_frames_with_lifespan
7+
8+
pytestmark = pytest.mark.anyio
9+
10+
11+
async def test_client_heartbeats_ok(monkeypatch: pytest.MonkeyPatch) -> None:
12+
async def mock_sleep(delay: float) -> None:
13+
await real_sleep(0)
14+
sleep_calls.append(delay)
15+
16+
sleep_calls: list[float] = []
17+
real_sleep = asyncio.sleep
18+
monkeypatch.setattr("asyncio.sleep", mock_sleep)
19+
20+
connection_class, _ = create_spying_connection(*get_read_frames_with_lifespan([]))
21+
connection_class.write_heartbeat = (write_heartbeat_mock := mock.Mock()) # type: ignore[method-assign]
22+
23+
async with EnrichedClient(connection_class=connection_class):
24+
await real_sleep(0)
25+
await real_sleep(0)
26+
await real_sleep(0)
27+
28+
assert sleep_calls == [0, 1, 1]
29+
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]

0 commit comments

Comments
 (0)