Skip to content

Commit cd1da4a

Browse files
authored
Refactor tests (second part) (#58)
1 parent 715ead4 commit cd1da4a

File tree

7 files changed

+292
-246
lines changed

7 files changed

+292
-246
lines changed

stompman/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __post_init__(self) -> None:
7676
async def __aenter__(self) -> Self:
7777
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
7878
self._heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
79-
await self._exit_stack.enter_async_context(self._connection_manager)
79+
await self._connection_manager.enter()
8080
self._listen_task = self._task_group.create_task(self._listen_to_frames())
8181
return self
8282

@@ -91,6 +91,7 @@ async def __aexit__(
9191
self._heartbeat_task.cancel()
9292
await asyncio.wait([self._listen_task, self._heartbeat_task])
9393
await self._exit_stack.aclose()
94+
await self._connection_manager.exit()
9495

9596
def _restart_heartbeat_task(self, interval: float) -> None:
9697
self._heartbeat_task.cancel()

stompman/connection_lifespan.py

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import asyncio
2-
from collections.abc import AsyncIterable, Callable
3-
from contextlib import suppress
4-
from dataclasses import dataclass
5-
from typing import Protocol
2+
from collections.abc import AsyncIterable, Awaitable, Callable
3+
from dataclasses import dataclass, field
4+
from typing import Any, Protocol, TypeVar
65
from uuid import uuid4
76

87
from stompman.config import ConnectionParameters, Heartbeat
@@ -22,26 +21,40 @@
2221
)
2322
from stompman.transaction import ActiveTransactions, commit_pending_transactions
2423

24+
FrameType = TypeVar("FrameType", bound=AnyServerFrame)
25+
WaitForFutureReturnType = TypeVar("WaitForFutureReturnType")
2526

26-
async def take_connected_frame(
27-
*, frames_iter: AsyncIterable[AnyServerFrame], connection_confirmation_timeout: int
28-
) -> ConnectedFrame | ConnectionConfirmationTimeout:
27+
28+
async def wait_for_or_none(
29+
awaitable: Awaitable[WaitForFutureReturnType], timeout: float
30+
) -> WaitForFutureReturnType | None:
31+
try:
32+
return await asyncio.wait_for(awaitable, timeout=timeout)
33+
except TimeoutError:
34+
return None
35+
36+
37+
WaitForOrNone = Callable[[Awaitable[WaitForFutureReturnType], float], Awaitable[WaitForFutureReturnType | None]]
38+
39+
40+
async def take_frame_of_type(
41+
*,
42+
frame_type: type[FrameType],
43+
frames_iter: AsyncIterable[AnyServerFrame],
44+
timeout: int,
45+
wait_for_or_none: WaitForOrNone[FrameType],
46+
) -> FrameType | list[Any]:
2947
collected_frames = []
3048

31-
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
49+
async def inner() -> FrameType:
3250
async for frame in frames_iter:
33-
if isinstance(frame, ConnectedFrame):
51+
if isinstance(frame, frame_type):
3452
return frame
3553
collected_frames.append(frame)
3654
msg = "unreachable"
3755
raise AssertionError(msg)
3856

39-
try:
40-
return await asyncio.wait_for(
41-
take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout
42-
)
43-
except TimeoutError:
44-
return ConnectionConfirmationTimeout(timeout=connection_confirmation_timeout, frames=collected_frames)
57+
return await wait_for_or_none(inner(), timeout) or collected_frames
4558

4659

4760
def check_stomp_protocol_version(
@@ -59,18 +72,6 @@ def calculate_heartbeat_interval(*, connected_frame: ConnectedFrame, client_hear
5972
return max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
6073

6174

62-
async def wait_for_receipt_frame(
63-
*, frames_iter: AsyncIterable[AnyServerFrame], disconnect_confirmation_timeout: int
64-
) -> None:
65-
async def inner() -> None:
66-
async for frame in frames_iter:
67-
if isinstance(frame, ReceiptFrame):
68-
break
69-
70-
with suppress(TimeoutError):
71-
await asyncio.wait_for(inner(), timeout=disconnect_confirmation_timeout)
72-
73-
7475
class AbstractConnectionLifespan(Protocol):
7576
async def enter(self) -> StompProtocolConnectionIssue | None: ...
7677
async def exit(self) -> None: ...
@@ -87,6 +88,7 @@ class ConnectionLifespan(AbstractConnectionLifespan):
8788
active_subscriptions: ActiveSubscriptions
8889
active_transactions: ActiveTransactions
8990
set_heartbeat_interval: Callable[[float], None]
91+
_generate_receipt_id: Callable[[], str] = field(default=lambda: _make_receipt_id()) # noqa: PLW0108
9092

9193
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
9294
await self.connection.write_frame(
@@ -100,13 +102,17 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
100102
},
101103
)
102104
)
103-
connected_frame_or_error = await take_connected_frame(
105+
connected_frame_or_collected_frames = await take_frame_of_type(
106+
frame_type=ConnectedFrame,
104107
frames_iter=self.connection.read_frames(),
105-
connection_confirmation_timeout=self.connection_confirmation_timeout,
108+
timeout=self.connection_confirmation_timeout,
109+
wait_for_or_none=wait_for_or_none,
106110
)
107-
if isinstance(connected_frame_or_error, ConnectionConfirmationTimeout):
108-
return connected_frame_or_error
109-
connected_frame = connected_frame_or_error
111+
if not isinstance(connected_frame_or_collected_frames, ConnectedFrame):
112+
return ConnectionConfirmationTimeout(
113+
timeout=self.connection_confirmation_timeout, frames=connected_frame_or_collected_frames
114+
)
115+
connected_frame = connected_frame_or_collected_frames
110116

111117
if unsupported_protocol_version_error := check_stomp_protocol_version(
112118
connected_frame=connected_frame, supported_version=self.protocol_version
@@ -130,10 +136,12 @@ async def enter(self) -> StompProtocolConnectionIssue | None:
130136

131137
async def exit(self) -> None:
132138
await unsubscribe_from_all_active_subscriptions(active_subscriptions=self.active_subscriptions)
133-
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
134-
await wait_for_receipt_frame(
139+
await self.connection.write_frame(DisconnectFrame(headers={"receipt": self._generate_receipt_id()}))
140+
await take_frame_of_type(
141+
frame_type=ReceiptFrame,
135142
frames_iter=self.connection.read_frames(),
136-
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
143+
timeout=self.disconnect_confirmation_timeout,
144+
wait_for_or_none=wait_for_or_none,
137145
)
138146

139147

stompman/connection_manager.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Awaitable, Callable
33
from dataclasses import dataclass, field
4-
from types import TracebackType
5-
from typing import TYPE_CHECKING, Self
4+
from typing import TYPE_CHECKING
65

76
from stompman.config import ConnectionParameters
87
from stompman.connection import AbstractConnection
@@ -26,11 +25,15 @@ class ActiveConnectionState:
2625
lifespan: "AbstractConnectionLifespan"
2726

2827

28+
Sleep = Callable[[float], Awaitable[None]]
29+
30+
2931
async def attempt_to_connect(
3032
*,
3133
connect: Callable[[], Awaitable[ActiveConnectionState | AnyConnectionIssue]],
3234
connect_retry_interval: int,
3335
connect_retry_attempts: int,
36+
sleep: Sleep,
3437
) -> ActiveConnectionState:
3538
connection_issues = []
3639

@@ -40,7 +43,7 @@ async def attempt_to_connect(
4043
return connection_result
4144

4245
connection_issues.append(connection_result)
43-
await asyncio.sleep(connect_retry_interval * (attempt + 1))
46+
await sleep(connect_retry_interval * (attempt + 1))
4447

4548
raise FailedAllConnectAttemptsError(retry_attempts=connect_retry_attempts, issues=connection_issues)
4649

@@ -83,13 +86,10 @@ class ConnectionManager:
8386
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
8487
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
8588

86-
async def __aenter__(self) -> Self:
89+
async def enter(self) -> None:
8790
self._active_connection_state = await self._get_active_connection_state()
88-
return self
8991

90-
async def __aexit__(
91-
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
92-
) -> None:
92+
async def exit(self) -> None:
9393
if not self._active_connection_state:
9494
return
9595
try:
@@ -131,6 +131,7 @@ async def _get_active_connection_state(self) -> ActiveConnectionState:
131131
connect=self._connect_to_any_server,
132132
connect_retry_interval=self.connect_retry_interval,
133133
connect_retry_attempts=self.connect_retry_attempts,
134+
sleep=asyncio.sleep,
134135
)
135136
return self._active_connection_state
136137

tests/conftest.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator
2+
from collections.abc import AsyncGenerator, AsyncIterable, Iterable
33
from dataclasses import dataclass, field
44
from typing import Any, Self, TypeVar
55

@@ -8,21 +8,13 @@
88

99
import stompman
1010
from stompman.connection import AbstractConnection
11-
from stompman.connection_lifespan import AbstractConnectionLifespan
12-
from stompman.connection_manager import ConnectionManager
1311

1412

1513
@pytest.fixture
1614
def anyio_backend() -> object:
1715
return "asyncio"
1816

1917

20-
@pytest.fixture
21-
def mock_sleep(monkeypatch: pytest.MonkeyPatch) -> None: # noqa: PT004
22-
original_sleep = asyncio.sleep
23-
monkeypatch.setattr("asyncio.sleep", lambda _: original_sleep(0))
24-
25-
2618
async def noop_message_handler(frame: stompman.MessageFrame) -> None: ...
2719

2820

@@ -52,29 +44,6 @@ class EnrichedClient(stompman.Client):
5244
)
5345

5446

55-
@dataclass(frozen=True, kw_only=True, slots=True)
56-
class NoopLifespan(AbstractConnectionLifespan):
57-
connection: AbstractConnection
58-
connection_parameters: stompman.ConnectionParameters
59-
60-
async def enter(self) -> stompman.StompProtocolConnectionIssue | None: ...
61-
async def exit(self) -> None: ...
62-
63-
64-
@dataclass(kw_only=True, slots=True)
65-
class EnrichedConnectionManager(ConnectionManager):
66-
servers: list[stompman.ConnectionParameters] = field(
67-
default_factory=lambda: [stompman.ConnectionParameters("localhost", 12345, "login", "passcode")]
68-
)
69-
lifespan_factory: stompman.connection_lifespan.ConnectionLifespanFactory = field(default=NoopLifespan)
70-
connect_retry_attempts: int = 3
71-
connect_retry_interval: int = 1
72-
connect_timeout: int = 3
73-
read_timeout: int = 4
74-
read_max_chunk_size: int = 5
75-
write_retry_attempts: int = 3
76-
77-
7847
DataclassType = TypeVar("DataclassType")
7948

8049

@@ -140,3 +109,12 @@ def enrich_expected_frames(
140109
stompman.DisconnectFrame(headers={"receipt": "receipt-id-1"}),
141110
stompman.ReceiptFrame(headers={"receipt-id": "receipt-id-1"}),
142111
]
112+
113+
114+
IterableItemT = TypeVar("IterableItemT")
115+
116+
117+
async def make_async_iter(iterable: Iterable[IterableItemT]) -> AsyncIterable[IterableItemT]:
118+
for item in iterable:
119+
yield item
120+
await asyncio.sleep(0)

tests/test_client.py

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,83 @@
11
import asyncio
2+
from collections.abc import AsyncIterable, Iterable
23
from unittest import mock
34

45
import pytest
56

6-
from tests.conftest import EnrichedClient, create_spying_connection, get_read_frames_with_lifespan
7+
from stompman import (
8+
AnyServerFrame,
9+
Client,
10+
ConnectedFrame,
11+
ConnectFrame,
12+
ConnectionParameters,
13+
HeartbeatFrame,
14+
ReceiptFrame,
15+
)
16+
from tests.conftest import EnrichedClient, build_dataclass, make_async_iter
717

818
pytestmark = pytest.mark.anyio
919

1020

11-
async def test_client_heartbeats_ok(monkeypatch: pytest.MonkeyPatch) -> None:
21+
async def test_client_heartbeats_and_lifespan_ok(monkeypatch: pytest.MonkeyPatch) -> None:
22+
sleep_calls = []
23+
real_sleep = asyncio.sleep
24+
1225
async def mock_sleep(delay: float) -> None:
1326
await real_sleep(0)
1427
sleep_calls.append(delay)
1528

16-
sleep_calls: list[float] = []
17-
real_sleep = asyncio.sleep
1829
monkeypatch.setattr("asyncio.sleep", mock_sleep)
1930

20-
connection_class, _ = create_spying_connection(*get_read_frames_with_lifespan([]))
21-
connection_class.write_heartbeat = (write_heartbeat_mock := mock.Mock()) # type: ignore[method-assign]
31+
written_and_read_frames = []
2232

23-
async with EnrichedClient(connection_class=connection_class):
33+
async def mock_read_frames(iterable: Iterable[AnyServerFrame]) -> AsyncIterable[AnyServerFrame]:
34+
async for frame in make_async_iter(iterable):
35+
written_and_read_frames.append(frame)
36+
yield frame
37+
await asyncio.Future()
38+
39+
connected_frame = build_dataclass(
40+
ConnectedFrame, headers={"version": EnrichedClient.PROTOCOL_VERSION, "heart-beat": "1000,1000"}
41+
)
42+
receipt_frame = build_dataclass(ReceiptFrame)
43+
44+
on_heartbeat = mock.Mock()
45+
write_heartbeat = mock.Mock()
46+
connection = mock.AsyncMock(
47+
read_frames=mock.Mock(
48+
side_effect=[
49+
mock_read_frames([connected_frame]),
50+
mock_read_frames([HeartbeatFrame(), HeartbeatFrame()]),
51+
mock_read_frames([receipt_frame]),
52+
]
53+
),
54+
write_frame=mock.AsyncMock(side_effect=written_and_read_frames.append),
55+
write_heartbeat=write_heartbeat,
56+
)
57+
connection_class = mock.Mock(connect=mock.AsyncMock(return_value=connection))
58+
connection_parameters = build_dataclass(ConnectionParameters)
59+
async with Client(servers=[connection_parameters], connection_class=connection_class, on_heartbeat=on_heartbeat):
2460
await real_sleep(0)
2561
await real_sleep(0)
2662
await real_sleep(0)
2763

28-
assert sleep_calls == [0, 1, 1]
29-
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]
64+
assert written_and_read_frames == [
65+
ConnectFrame(
66+
headers={
67+
"accept-version": "1.2",
68+
"heart-beat": "1000,1000",
69+
"host": connection_parameters.host,
70+
"login": connection_parameters.login,
71+
"passcode": connection_parameters.unescaped_passcode,
72+
}
73+
),
74+
connected_frame,
75+
HeartbeatFrame(),
76+
HeartbeatFrame(),
77+
written_and_read_frames[-2],
78+
receipt_frame,
79+
]
80+
81+
assert sleep_calls == [0, 1, 0, 1]
82+
assert write_heartbeat.mock_calls == [mock.call(), mock.call(), mock.call()]
83+
assert on_heartbeat.mock_calls == [mock.call(), mock.call()]

0 commit comments

Comments
 (0)