Skip to content

Commit bccd3a5

Browse files
authored
Revert tests refactoring (#59)
1 parent cd1da4a commit bccd3a5

16 files changed

+479
-795
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ ignore = [
6767
"PLC2801",
6868
"PLR0913",
6969
]
70-
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG", "PLR6301"] }
70+
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] }
7171

7272
[tool.pytest.ini_options]
7373
addopts = "--cov -s -vv"

stompman/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from stompman.client import Client
22
from stompman.config import ConnectionParameters, Heartbeat
33
from stompman.errors import (
4-
AllServersUnavailable,
5-
AnyConnectionIssue,
64
ConnectionConfirmationTimeout,
7-
ConnectionLost,
85
ConnectionLostError,
96
Error,
107
FailedAllConnectAttemptsError,
@@ -41,9 +38,7 @@
4138
"AbortFrame",
4239
"AckFrame",
4340
"AckMode",
44-
"AllServersUnavailable",
4541
"AnyClientFrame",
46-
"AnyConnectionIssue",
4742
"AnyRealServerFrame",
4843
"AnyServerFrame",
4944
"BeginFrame",
@@ -52,7 +47,6 @@
5247
"ConnectFrame",
5348
"ConnectedFrame",
5449
"ConnectionConfirmationTimeout",
55-
"ConnectionLost",
5650
"ConnectionLostError",
5751
"ConnectionParameters",
5852
"DisconnectFrame",

stompman/client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __post_init__(self) -> None:
7676
async def __aenter__(self) -> Self:
7777
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
7878
self._heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
79-
await self._connection_manager.enter()
79+
await self._exit_stack.enter_async_context(self._connection_manager)
8080
self._listen_task = self._task_group.create_task(self._listen_to_frames())
8181
return self
8282

@@ -91,7 +91,6 @@ async def __aexit__(
9191
self._heartbeat_task.cancel()
9292
await asyncio.wait([self._listen_task, self._heartbeat_task])
9393
await self._exit_stack.aclose()
94-
await self._connection_manager.exit()
9594

9695
def _restart_heartbeat_task(self, interval: float) -> None:
9796
self._heartbeat_task.cancel()

stompman/connection_lifespan.py

Lines changed: 40 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import asyncio
2-
from collections.abc import AsyncIterable, Awaitable, Callable
3-
from dataclasses import dataclass, field
4-
from typing import Any, Protocol, TypeVar
2+
from collections.abc import Callable
3+
from contextlib import suppress
4+
from dataclasses import dataclass
5+
from typing import Protocol
56
from uuid import uuid4
67

78
from stompman.config import ConnectionParameters, Heartbeat
89
from stompman.connection import AbstractConnection
910
from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
1011
from stompman.frames import (
11-
AnyServerFrame,
1212
ConnectedFrame,
1313
ConnectFrame,
1414
DisconnectFrame,
@@ -21,56 +21,6 @@
2121
)
2222
from stompman.transaction import ActiveTransactions, commit_pending_transactions
2323

24-
FrameType = TypeVar("FrameType", bound=AnyServerFrame)
25-
WaitForFutureReturnType = TypeVar("WaitForFutureReturnType")
26-
27-
28-
async def wait_for_or_none(
29-
awaitable: Awaitable[WaitForFutureReturnType], timeout: float
30-
) -> WaitForFutureReturnType | None:
31-
try:
32-
return await asyncio.wait_for(awaitable, timeout=timeout)
33-
except TimeoutError:
34-
return None
35-
36-
37-
WaitForOrNone = Callable[[Awaitable[WaitForFutureReturnType], float], Awaitable[WaitForFutureReturnType | None]]
38-
39-
40-
async def take_frame_of_type(
41-
*,
42-
frame_type: type[FrameType],
43-
frames_iter: AsyncIterable[AnyServerFrame],
44-
timeout: int,
45-
wait_for_or_none: WaitForOrNone[FrameType],
46-
) -> FrameType | list[Any]:
47-
collected_frames = []
48-
49-
async def inner() -> FrameType:
50-
async for frame in frames_iter:
51-
if isinstance(frame, frame_type):
52-
return frame
53-
collected_frames.append(frame)
54-
msg = "unreachable"
55-
raise AssertionError(msg)
56-
57-
return await wait_for_or_none(inner(), timeout) or collected_frames
58-
59-
60-
def check_stomp_protocol_version(
61-
*, connected_frame: ConnectedFrame, supported_version: str
62-
) -> UnsupportedProtocolVersion | None:
63-
if connected_frame.headers["version"] == supported_version:
64-
return None
65-
return UnsupportedProtocolVersion(
66-
given_version=connected_frame.headers["version"], supported_version=supported_version
67-
)
68-
69-
70-
def calculate_heartbeat_interval(*, connected_frame: ConnectedFrame, client_heartbeat: Heartbeat) -> float:
71-
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
72-
return max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
73-
7424

7525
class AbstractConnectionLifespan(Protocol):
7626
async def enter(self) -> StompProtocolConnectionIssue | None: ...
@@ -88,7 +38,6 @@ class ConnectionLifespan(AbstractConnectionLifespan):
8838
active_subscriptions: ActiveSubscriptions
8939
active_transactions: ActiveTransactions
9040
set_heartbeat_interval: Callable[[float], None]
91-
_generate_receipt_id: Callable[[], str] = field(default=lambda: _make_receipt_id()) # noqa: PLW0108
9241

9342
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
9443
await self.connection.write_frame(
@@ -102,54 +51,61 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
10251
},
10352
)
10453
)
105-
connected_frame_or_collected_frames = await take_frame_of_type(
106-
frame_type=ConnectedFrame,
107-
frames_iter=self.connection.read_frames(),
108-
timeout=self.connection_confirmation_timeout,
109-
wait_for_or_none=wait_for_or_none,
110-
)
111-
if not isinstance(connected_frame_or_collected_frames, ConnectedFrame):
112-
return ConnectionConfirmationTimeout(
113-
timeout=self.connection_confirmation_timeout, frames=connected_frame_or_collected_frames
54+
collected_frames = []
55+
56+
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
57+
async for frame in self.connection.read_frames():
58+
if isinstance(frame, ConnectedFrame):
59+
return frame
60+
collected_frames.append(frame)
61+
msg = "unreachable" # pragma: no cover
62+
raise AssertionError(msg) # pragma: no cover
63+
64+
try:
65+
connected_frame = await asyncio.wait_for(
66+
take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout
11467
)
115-
connected_frame = connected_frame_or_collected_frames
68+
except TimeoutError:
69+
return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames)
11670

117-
if unsupported_protocol_version_error := check_stomp_protocol_version(
118-
connected_frame=connected_frame, supported_version=self.protocol_version
119-
):
120-
return unsupported_protocol_version_error
71+
if connected_frame.headers["version"] != self.protocol_version:
72+
return UnsupportedProtocolVersion(
73+
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
74+
)
12175

76+
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
12277
self.set_heartbeat_interval(
123-
calculate_heartbeat_interval(connected_frame=connected_frame, client_heartbeat=self.client_heartbeat)
78+
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
12479
)
12580
return None
12681

12782
async def enter(self) -> StompProtocolConnectionIssue | None:
128-
if protocol_connection_issue := await self._establish_connection():
129-
return protocol_connection_issue
130-
83+
if connection_issue := await self._establish_connection():
84+
return connection_issue
13185
await resubscribe_to_active_subscriptions(
13286
connection=self.connection, active_subscriptions=self.active_subscriptions
13387
)
13488
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
13589
return None
13690

91+
async def _take_receipt_frame(self) -> None:
92+
async for frame in self.connection.read_frames():
93+
if isinstance(frame, ReceiptFrame):
94+
break
95+
13796
async def exit(self) -> None:
13897
await unsubscribe_from_all_active_subscriptions(active_subscriptions=self.active_subscriptions)
139-
await self.connection.write_frame(DisconnectFrame(headers={"receipt": self._generate_receipt_id()}))
140-
await take_frame_of_type(
141-
frame_type=ReceiptFrame,
142-
frames_iter=self.connection.read_frames(),
143-
timeout=self.disconnect_confirmation_timeout,
144-
wait_for_or_none=wait_for_or_none,
145-
)
98+
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
99+
100+
with suppress(TimeoutError):
101+
await asyncio.wait_for(self._take_receipt_frame(), timeout=self.disconnect_confirmation_timeout)
102+
103+
104+
def _make_receipt_id() -> str:
105+
return str(uuid4())
146106

147107

148108
class ConnectionLifespanFactory(Protocol):
149109
def __call__(
150110
self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
151111
) -> AbstractConnectionLifespan: ...
152-
153-
154-
def _make_receipt_id() -> str:
155-
return str(uuid4())

stompman/connection_manager.py

Lines changed: 40 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator, Awaitable, Callable
2+
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, field
4-
from typing import TYPE_CHECKING
4+
from types import TracebackType
5+
from typing import TYPE_CHECKING, Self
56

67
from stompman.config import ConnectionParameters
78
from stompman.connection import AbstractConnection
@@ -25,52 +26,6 @@ class ActiveConnectionState:
2526
lifespan: "AbstractConnectionLifespan"
2627

2728

28-
Sleep = Callable[[float], Awaitable[None]]
29-
30-
31-
async def attempt_to_connect(
32-
*,
33-
connect: Callable[[], Awaitable[ActiveConnectionState | AnyConnectionIssue]],
34-
connect_retry_interval: int,
35-
connect_retry_attempts: int,
36-
sleep: Sleep,
37-
) -> ActiveConnectionState:
38-
connection_issues = []
39-
40-
for attempt in range(connect_retry_attempts):
41-
connection_result = await connect()
42-
if isinstance(connection_result, ActiveConnectionState):
43-
return connection_result
44-
45-
connection_issues.append(connection_result)
46-
await sleep(connect_retry_interval * (attempt + 1))
47-
48-
raise FailedAllConnectAttemptsError(retry_attempts=connect_retry_attempts, issues=connection_issues)
49-
50-
51-
async def connect_to_first_server(
52-
connect_awaitables: list[Awaitable[ActiveConnectionState | None]],
53-
) -> ActiveConnectionState | None:
54-
for maybe_connection_future in asyncio.as_completed(connect_awaitables):
55-
if connection_state := await maybe_connection_future:
56-
return connection_state
57-
return None
58-
59-
60-
async def make_healthy_connection(
61-
*, active_connection_state: ActiveConnectionState | None, servers: list[ConnectionParameters], connect_timeout: int
62-
) -> ActiveConnectionState | AnyConnectionIssue:
63-
if not active_connection_state:
64-
return AllServersUnavailable(servers=servers, timeout=connect_timeout)
65-
66-
try:
67-
connection_issue = await active_connection_state.lifespan.enter()
68-
except ConnectionLostError:
69-
return ConnectionLost()
70-
71-
return active_connection_state if connection_issue is None else connection_issue
72-
73-
7429
@dataclass(kw_only=True, slots=True)
7530
class ConnectionManager:
7631
servers: list[ConnectionParameters]
@@ -86,10 +41,13 @@ class ConnectionManager:
8641
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
8742
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
8843

89-
async def enter(self) -> None:
44+
async def __aenter__(self) -> Self:
9045
self._active_connection_state = await self._get_active_connection_state()
46+
return self
9147

92-
async def exit(self) -> None:
48+
async def __aexit__(
49+
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
50+
) -> None:
9351
if not self._active_connection_state:
9452
return
9553
try:
@@ -112,28 +70,47 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
11270
)
11371
return None
11472

115-
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
116-
active_connection_state = await connect_to_first_server(
73+
async def _create_connection_to_any_server(self) -> ActiveConnectionState | None:
74+
for maybe_connection_future in asyncio.as_completed(
11775
[self._create_connection_to_one_server(server) for server in self.servers]
118-
)
119-
return await make_healthy_connection(
120-
active_connection_state=active_connection_state, servers=self.servers, connect_timeout=self.connect_timeout
121-
)
76+
):
77+
if connection_state := await maybe_connection_future:
78+
return connection_state
79+
return None
80+
81+
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
82+
if not (active_connection_state := await self._create_connection_to_any_server()):
83+
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
84+
85+
try:
86+
if connection_issue := await active_connection_state.lifespan.enter():
87+
return connection_issue
88+
except ConnectionLostError:
89+
return ConnectionLost()
90+
91+
return active_connection_state
12292

12393
async def _get_active_connection_state(self) -> ActiveConnectionState:
12494
if self._active_connection_state:
12595
return self._active_connection_state
12696

97+
connection_issues: list[AnyConnectionIssue] = []
98+
12799
async with self._reconnect_lock:
128100
if self._active_connection_state:
129101
return self._active_connection_state
130-
self._active_connection_state = await attempt_to_connect(
131-
connect=self._connect_to_any_server,
132-
connect_retry_interval=self.connect_retry_interval,
133-
connect_retry_attempts=self.connect_retry_attempts,
134-
sleep=asyncio.sleep,
135-
)
136-
return self._active_connection_state
102+
103+
for attempt in range(self.connect_retry_attempts):
104+
connection_result = await self._connect_to_any_server()
105+
106+
if isinstance(connection_result, ActiveConnectionState):
107+
self._active_connection_state = connection_result
108+
return connection_result
109+
110+
connection_issues.append(connection_result)
111+
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
112+
113+
raise FailedAllConnectAttemptsError(retry_attempts=self.connect_retry_attempts, issues=connection_issues)
137114

138115
def _clear_active_connection_state(self) -> None:
139116
self._active_connection_state = None

0 commit comments

Comments
 (0)