Skip to content

Commit c5023db

Browse files
authored
Refactor lifespan and errors (#53)
1 parent 5578cae commit c5023db

File tree

10 files changed

+243
-187
lines changed

10 files changed

+243
-187
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ stompman takes care of cleaning up resources automatically. When you leave the c
107107

108108
- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds. If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. In normal situation it doesn't need to be handled: tune retry and timeout parameters in `stompman.Client()` to your needs.
109109

110-
- When connection is lost, stompman will handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.RepeatedConnectionLostError` will be raised if connection succeeds but operation (like sending a frame) leads to connection getting lost.
110+
- When connection is lost, stompman will handle it automatically. `stompman.FailedAllConnectAttemptsError` will be raised if all connection attempts fail. `stompman.RepeatedConnectionLostError` or `stompman.ConnectionLostDuringOperationError` will be raised if connection succeeds but operation (like sending a frame) leads to connection getting lost.
111111

112112
### ...and caveats
113113

stompman/__init__.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1-
from stompman.client import Client, Subscription, Transaction
1+
from stompman.client import Client, ConnectionLifespan, Subscription, Transaction
22
from stompman.config import ConnectionParameters, Heartbeat
33
from stompman.connection import AbstractConnection, Connection
4-
from stompman.connection_manager import ActiveConnectionState, ConnectionManager
4+
from stompman.connection_manager import (
5+
AbstractConnectionLifespan,
6+
ActiveConnectionState,
7+
ConnectionLifespanFactory,
8+
ConnectionManager,
9+
)
510
from stompman.errors import (
6-
ConnectionConfirmationTimeoutError,
11+
ConnectionAttemptsFailedError,
12+
ConnectionConfirmationTimeout,
13+
ConnectionLostDuringOperationError,
714
ConnectionLostError,
815
Error,
916
FailedAllConnectAttemptsError,
10-
RepeatedConnectionLostError,
11-
UnsupportedProtocolVersionError,
17+
StompProtocolConnectionIssue,
18+
UnsupportedProtocolVersion,
1219
)
1320
from stompman.frames import (
1421
AbortFrame,
@@ -31,10 +38,12 @@
3138
SubscribeFrame,
3239
UnsubscribeFrame,
3340
)
41+
from stompman.serde import FrameParser, dump_frame
3442

3543
__all__ = [
3644
"AbortFrame",
3745
"AbstractConnection",
46+
"AbstractConnectionLifespan",
3847
"AckFrame",
3948
"AckMode",
4049
"ActiveConnectionState",
@@ -47,24 +56,30 @@
4756
"ConnectFrame",
4857
"ConnectedFrame",
4958
"Connection",
50-
"ConnectionConfirmationTimeoutError",
59+
"ConnectionAttemptsFailedError",
60+
"ConnectionConfirmationTimeout",
61+
"ConnectionLifespan",
62+
"ConnectionLifespanFactory",
63+
"ConnectionLostDuringOperationError",
5164
"ConnectionLostError",
5265
"ConnectionManager",
5366
"ConnectionParameters",
5467
"DisconnectFrame",
5568
"Error",
5669
"ErrorFrame",
5770
"FailedAllConnectAttemptsError",
71+
"FrameParser",
5872
"Heartbeat",
5973
"HeartbeatFrame",
6074
"MessageFrame",
6175
"NackFrame",
6276
"ReceiptFrame",
63-
"RepeatedConnectionLostError",
6477
"SendFrame",
78+
"StompProtocolConnectionIssue",
6579
"SubscribeFrame",
6680
"Subscription",
6781
"Transaction",
6882
"UnsubscribeFrame",
69-
"UnsupportedProtocolVersionError",
83+
"UnsupportedProtocolVersion",
84+
"dump_frame",
7085
]

stompman/client.py

Lines changed: 95 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
2+
from collections.abc import AsyncGenerator, Callable, Coroutine
33
from contextlib import AsyncExitStack, asynccontextmanager, suppress
44
from dataclasses import dataclass, field
5+
from functools import partial
56
from types import TracebackType
67
from typing import ClassVar, Self
78
from uuid import uuid4
89

910
from stompman.config import ConnectionParameters, Heartbeat
1011
from stompman.connection import AbstractConnection, Connection
11-
from stompman.connection_manager import ConnectionManager
12-
from stompman.errors import ConnectionConfirmationTimeoutError, UnsupportedProtocolVersionError
12+
from stompman.connection_manager import AbstractConnectionLifespan, ConnectionManager
13+
from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
1314
from stompman.frames import (
1415
AbortFrame,
1516
AckFrame,
@@ -30,66 +31,93 @@
3031
)
3132

3233

33-
@asynccontextmanager
34-
async def connection_lifespan(
35-
*,
36-
connection: AbstractConnection,
37-
connection_parameters: ConnectionParameters,
38-
protocol_version: str,
39-
client_heartbeat: Heartbeat,
40-
connection_confirmation_timeout: int,
41-
disconnect_confirmation_timeout: int,
42-
) -> AsyncIterator[float]:
43-
await connection.write_frame(
44-
ConnectFrame(
45-
headers={
46-
"accept-version": protocol_version,
47-
"heart-beat": client_heartbeat.to_header(),
48-
"host": connection_parameters.host,
49-
"login": connection_parameters.login,
50-
"passcode": connection_parameters.unescaped_passcode,
51-
},
52-
)
53-
)
54-
collected_frames = []
55-
56-
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
57-
async for frame in connection.read_frames():
58-
if isinstance(frame, ConnectedFrame):
59-
return frame
60-
collected_frames.append(frame)
61-
msg = "unreachable" # pragma: no cover
62-
raise AssertionError(msg) # pragma: no cover
63-
64-
try:
65-
connected_frame = await asyncio.wait_for(
66-
take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout
34+
@dataclass(kw_only=True, slots=True)
35+
class ConnectionLifespan(AbstractConnectionLifespan):
36+
connection: AbstractConnection
37+
connection_parameters: ConnectionParameters
38+
protocol_version: str
39+
client_heartbeat: Heartbeat
40+
connection_confirmation_timeout: int
41+
disconnect_confirmation_timeout: int
42+
active_subscriptions: dict[str, "Subscription"]
43+
active_transactions: set["Transaction"]
44+
set_heartbeat_interval: Callable[[float], None]
45+
46+
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
47+
await self.connection.write_frame(
48+
ConnectFrame(
49+
headers={
50+
"accept-version": self.protocol_version,
51+
"heart-beat": self.client_heartbeat.to_header(),
52+
"host": self.connection_parameters.host,
53+
"login": self.connection_parameters.login,
54+
"passcode": self.connection_parameters.unescaped_passcode,
55+
},
56+
)
6757
)
68-
except TimeoutError as exception:
69-
raise ConnectionConfirmationTimeoutError(
70-
timeout=connection_confirmation_timeout, frames=collected_frames
71-
) from exception
72-
73-
if connected_frame.headers["version"] != protocol_version:
74-
raise UnsupportedProtocolVersionError(
75-
given_version=connected_frame.headers["version"], supported_version=protocol_version
58+
collected_frames = []
59+
60+
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
61+
async for frame in self.connection.read_frames():
62+
if isinstance(frame, ConnectedFrame):
63+
return frame
64+
collected_frames.append(frame)
65+
msg = "unreachable" # pragma: no cover
66+
raise AssertionError(msg) # pragma: no cover
67+
68+
try:
69+
connected_frame = await asyncio.wait_for(
70+
take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout
71+
)
72+
except TimeoutError:
73+
return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames)
74+
75+
if connected_frame.headers["version"] != self.protocol_version:
76+
return UnsupportedProtocolVersion(
77+
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
78+
)
79+
80+
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
81+
self.set_heartbeat_interval(
82+
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
7683
)
84+
return None
7785

78-
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
79-
heartbeat_interval = (
80-
max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
81-
)
82-
yield heartbeat_interval
86+
async def _resubscribe(self) -> None:
87+
for subscription in self.active_subscriptions.values():
88+
await self.connection.write_frame(
89+
SubscribeFrame(
90+
headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack}
91+
)
92+
)
8393

84-
await connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
94+
async def _commit_pending_transactions(self) -> None:
95+
for transaction in self.active_transactions:
96+
for frame in transaction.sent_frames:
97+
await self.connection.write_frame(frame)
98+
await self.connection.write_frame(CommitFrame(headers={"transaction": transaction.id}))
99+
self.active_transactions.clear()
85100

86-
async def take_receipt_frame() -> None:
87-
async for frame in connection.read_frames():
88-
if isinstance(frame, ReceiptFrame):
89-
break
101+
async def enter(self) -> StompProtocolConnectionIssue | None:
102+
if connection_issue := await self._establish_connection():
103+
return connection_issue
104+
await self._resubscribe()
105+
await self._commit_pending_transactions()
106+
return None
90107

91-
with suppress(TimeoutError):
92-
await asyncio.wait_for(take_receipt_frame(), timeout=disconnect_confirmation_timeout)
108+
async def exit(self) -> None:
109+
for subscription in self.active_subscriptions.copy().values():
110+
await subscription.unsubscribe()
111+
112+
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
113+
114+
async def take_receipt_frame() -> None:
115+
async for frame in self.connection.read_frames():
116+
if isinstance(frame, ReceiptFrame):
117+
break
118+
119+
with suppress(TimeoutError):
120+
await asyncio.wait_for(take_receipt_frame(), timeout=self.disconnect_confirmation_timeout)
93121

94122

95123
def _make_receipt_id() -> str:
@@ -146,21 +174,6 @@ def _make_subscription_id() -> str:
146174
return str(uuid4())
147175

148176

149-
@asynccontextmanager
150-
async def subscriptions_lifespan(
151-
*, connection: AbstractConnection, active_subscriptions: dict[str, Subscription]
152-
) -> AsyncIterator[None]:
153-
for subscription in active_subscriptions.values():
154-
await connection.write_frame(
155-
SubscribeFrame(
156-
headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack}
157-
)
158-
)
159-
yield
160-
for subscription in active_subscriptions.copy().values():
161-
await subscription.unsubscribe()
162-
163-
164177
@dataclass(kw_only=True, slots=True, unsafe_hash=True)
165178
class Transaction:
166179
id: str = field(default_factory=lambda: _make_transaction_id(), init=False) # noqa: PLW0108
@@ -198,14 +211,6 @@ def _make_transaction_id() -> str:
198211
return str(uuid4())
199212

200213

201-
async def commit_pending_transactions(*, connection: AbstractConnection, active_transactions: set[Transaction]) -> None:
202-
for transaction in active_transactions:
203-
for frame in transaction.sent_frames:
204-
await connection.write_frame(frame)
205-
await connection.write_frame(CommitFrame(headers={"transaction": transaction.id}))
206-
active_transactions.clear()
207-
208-
209214
@dataclass(kw_only=True, slots=True)
210215
class Client:
211216
PROTOCOL_VERSION: ClassVar = "1.2" # https://stomp.github.io/stomp-specification-1.2.html
@@ -236,7 +241,16 @@ class Client:
236241
def __post_init__(self) -> None:
237242
self._connection_manager = ConnectionManager(
238243
servers=self.servers,
239-
lifespan=self._lifespan,
244+
lifespan_factory=partial(
245+
ConnectionLifespan,
246+
protocol_version=self.PROTOCOL_VERSION,
247+
client_heartbeat=self.heartbeat,
248+
connection_confirmation_timeout=self.connection_confirmation_timeout,
249+
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
250+
active_subscriptions=self._active_subscriptions,
251+
active_transactions=self._active_transactions,
252+
set_heartbeat_interval=self._restart_heartbeat_task,
253+
),
240254
connection_class=self.connection_class,
241255
connect_retry_attempts=self.connect_retry_attempts,
242256
connect_retry_interval=self.connect_retry_interval,
@@ -264,23 +278,6 @@ async def __aexit__(
264278
await asyncio.wait([self._listen_task, self._heartbeat_task])
265279
await self._exit_stack.aclose()
266280

267-
@asynccontextmanager
268-
async def _lifespan(
269-
self, connection: AbstractConnection, connection_parameters: ConnectionParameters
270-
) -> AsyncGenerator[None, None]:
271-
async with connection_lifespan(
272-
connection=connection,
273-
connection_parameters=connection_parameters,
274-
protocol_version=self.PROTOCOL_VERSION,
275-
client_heartbeat=self.heartbeat,
276-
connection_confirmation_timeout=self.connection_confirmation_timeout,
277-
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
278-
) as heartbeat_interval:
279-
self._restart_heartbeat_task(heartbeat_interval)
280-
async with subscriptions_lifespan(connection=connection, active_subscriptions=self._active_subscriptions):
281-
await commit_pending_transactions(connection=connection, active_transactions=self._active_transactions)
282-
yield
283-
284281
def _restart_heartbeat_task(self, interval: float) -> None:
285282
self._heartbeat_task.cancel()
286283
self._heartbeat_task = self._task_group.create_task(self._send_heartbeats_forever(interval))

0 commit comments

Comments
 (0)