Skip to content

Commit 369f100

Browse files
authored
Refactor client modules (#55)
1 parent b164548 commit 369f100

13 files changed

+617
-520
lines changed

stompman/__init__.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
1-
from stompman.client import Client, ConnectionLifespan, Subscription, Transaction
1+
from stompman.client import Client
22
from stompman.config import ConnectionParameters, Heartbeat
3-
from stompman.connection import AbstractConnection, Connection
4-
from stompman.connection_manager import (
5-
AbstractConnectionLifespan,
6-
ActiveConnectionState,
7-
ConnectionLifespanFactory,
8-
ConnectionManager,
9-
)
103
from stompman.errors import (
114
ConnectionConfirmationTimeout,
125
ConnectionLostError,
@@ -38,14 +31,13 @@
3831
UnsubscribeFrame,
3932
)
4033
from stompman.serde import FrameParser, dump_frame
34+
from stompman.subscription import Subscription
35+
from stompman.transaction import Transaction
4136

4237
__all__ = [
4338
"AbortFrame",
44-
"AbstractConnection",
45-
"AbstractConnectionLifespan",
4639
"AckFrame",
4740
"AckMode",
48-
"ActiveConnectionState",
4941
"AnyClientFrame",
5042
"AnyRealServerFrame",
5143
"AnyServerFrame",
@@ -54,12 +46,8 @@
5446
"CommitFrame",
5547
"ConnectFrame",
5648
"ConnectedFrame",
57-
"Connection",
5849
"ConnectionConfirmationTimeout",
59-
"ConnectionLifespan",
60-
"ConnectionLifespanFactory",
6150
"ConnectionLostError",
62-
"ConnectionManager",
6351
"ConnectionParameters",
6452
"DisconnectFrame",
6553
"Error",

stompman/client.py

Lines changed: 5 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,214 +1,26 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Callable, Coroutine
3-
from contextlib import AsyncExitStack, asynccontextmanager, suppress
3+
from contextlib import AsyncExitStack, asynccontextmanager
44
from dataclasses import dataclass, field
55
from functools import partial
66
from types import TracebackType
77
from typing import ClassVar, Self
8-
from uuid import uuid4
98

109
from stompman.config import ConnectionParameters, Heartbeat
1110
from stompman.connection import AbstractConnection, Connection
12-
from stompman.connection_manager import AbstractConnectionLifespan, ConnectionManager
13-
from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
11+
from stompman.connection_lifespan import ConnectionLifespan
12+
from stompman.connection_manager import ConnectionManager
1413
from stompman.frames import (
15-
AbortFrame,
16-
AckFrame,
1714
AckMode,
18-
BeginFrame,
19-
CommitFrame,
2015
ConnectedFrame,
21-
ConnectFrame,
22-
DisconnectFrame,
2316
ErrorFrame,
2417
HeartbeatFrame,
2518
MessageFrame,
26-
NackFrame,
2719
ReceiptFrame,
2820
SendFrame,
29-
SubscribeFrame,
30-
UnsubscribeFrame,
3121
)
32-
33-
34-
@dataclass(kw_only=True, slots=True)
35-
class ConnectionLifespan(AbstractConnectionLifespan):
36-
connection: AbstractConnection
37-
connection_parameters: ConnectionParameters
38-
protocol_version: str
39-
client_heartbeat: Heartbeat
40-
connection_confirmation_timeout: int
41-
disconnect_confirmation_timeout: int
42-
active_subscriptions: dict[str, "Subscription"]
43-
active_transactions: set["Transaction"]
44-
set_heartbeat_interval: Callable[[float], None]
45-
46-
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
47-
await self.connection.write_frame(
48-
ConnectFrame(
49-
headers={
50-
"accept-version": self.protocol_version,
51-
"heart-beat": self.client_heartbeat.to_header(),
52-
"host": self.connection_parameters.host,
53-
"login": self.connection_parameters.login,
54-
"passcode": self.connection_parameters.unescaped_passcode,
55-
},
56-
)
57-
)
58-
collected_frames = []
59-
60-
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
61-
async for frame in self.connection.read_frames():
62-
if isinstance(frame, ConnectedFrame):
63-
return frame
64-
collected_frames.append(frame)
65-
msg = "unreachable" # pragma: no cover
66-
raise AssertionError(msg) # pragma: no cover
67-
68-
try:
69-
connected_frame = await asyncio.wait_for(
70-
take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout
71-
)
72-
except TimeoutError:
73-
return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames)
74-
75-
if connected_frame.headers["version"] != self.protocol_version:
76-
return UnsupportedProtocolVersion(
77-
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
78-
)
79-
80-
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
81-
self.set_heartbeat_interval(
82-
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
83-
)
84-
return None
85-
86-
async def _resubscribe(self) -> None:
87-
for subscription in self.active_subscriptions.values():
88-
await self.connection.write_frame(
89-
SubscribeFrame(
90-
headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack}
91-
)
92-
)
93-
94-
async def _commit_pending_transactions(self) -> None:
95-
for transaction in self.active_transactions:
96-
for frame in transaction.sent_frames:
97-
await self.connection.write_frame(frame)
98-
await self.connection.write_frame(CommitFrame(headers={"transaction": transaction.id}))
99-
self.active_transactions.clear()
100-
101-
async def enter(self) -> StompProtocolConnectionIssue | None:
102-
if connection_issue := await self._establish_connection():
103-
return connection_issue
104-
await self._resubscribe()
105-
await self._commit_pending_transactions()
106-
return None
107-
108-
async def exit(self) -> None:
109-
for subscription in self.active_subscriptions.copy().values():
110-
await subscription.unsubscribe()
111-
112-
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
113-
114-
async def take_receipt_frame() -> None:
115-
async for frame in self.connection.read_frames():
116-
if isinstance(frame, ReceiptFrame):
117-
break
118-
119-
with suppress(TimeoutError):
120-
await asyncio.wait_for(take_receipt_frame(), timeout=self.disconnect_confirmation_timeout)
121-
122-
123-
def _make_receipt_id() -> str:
124-
return str(uuid4())
125-
126-
127-
@dataclass(kw_only=True, slots=True)
128-
class Subscription:
129-
id: str = field(default_factory=lambda: _make_subscription_id(), init=False) # noqa: PLW0108
130-
destination: str
131-
handler: Callable[[MessageFrame], Coroutine[None, None, None]]
132-
ack: AckMode
133-
on_suppressed_exception: Callable[[Exception, MessageFrame], None]
134-
supressed_exception_classes: tuple[type[Exception], ...]
135-
_connection_manager: ConnectionManager
136-
_active_subscriptions: dict[str, "Subscription"]
137-
138-
_should_handle_ack_nack: bool = field(init=False)
139-
140-
def __post_init__(self) -> None:
141-
self._should_handle_ack_nack = self.ack in {"client", "client-individual"}
142-
143-
async def _subscribe(self) -> None:
144-
await self._connection_manager.write_frame_reconnecting(
145-
SubscribeFrame(headers={"id": self.id, "destination": self.destination, "ack": self.ack})
146-
)
147-
self._active_subscriptions[self.id] = self
148-
149-
async def unsubscribe(self) -> None:
150-
del self._active_subscriptions[self.id]
151-
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))
152-
153-
async def _run_handler(self, *, frame: MessageFrame) -> None:
154-
try:
155-
await self.handler(frame)
156-
except self.supressed_exception_classes as exception:
157-
if self._should_handle_ack_nack and self.id in self._active_subscriptions:
158-
await self._connection_manager.maybe_write_frame(
159-
NackFrame(
160-
headers={"id": frame.headers["message-id"], "subscription": frame.headers["subscription"]}
161-
)
162-
)
163-
self.on_suppressed_exception(exception, frame)
164-
else:
165-
if self._should_handle_ack_nack and self.id in self._active_subscriptions:
166-
await self._connection_manager.maybe_write_frame(
167-
AckFrame(
168-
headers={"id": frame.headers["message-id"], "subscription": frame.headers["subscription"]},
169-
)
170-
)
171-
172-
173-
def _make_subscription_id() -> str:
174-
return str(uuid4())
175-
176-
177-
@dataclass(kw_only=True, slots=True, unsafe_hash=True)
178-
class Transaction:
179-
id: str = field(default_factory=lambda: _make_transaction_id(), init=False) # noqa: PLW0108
180-
_connection_manager: ConnectionManager = field(hash=False)
181-
_active_transactions: set["Transaction"] = field(hash=False)
182-
sent_frames: list[SendFrame] = field(default_factory=list, init=False, hash=False)
183-
184-
async def __aenter__(self) -> Self:
185-
await self._connection_manager.write_frame_reconnecting(BeginFrame(headers={"transaction": self.id}))
186-
self._active_transactions.add(self)
187-
return self
188-
189-
async def __aexit__(
190-
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
191-
) -> None:
192-
if exc_value:
193-
await self._connection_manager.maybe_write_frame(AbortFrame(headers={"transaction": self.id}))
194-
self._active_transactions.remove(self)
195-
else:
196-
commited = await self._connection_manager.maybe_write_frame(CommitFrame(headers={"transaction": self.id}))
197-
if commited:
198-
self._active_transactions.remove(self)
199-
200-
async def send(
201-
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
202-
) -> None:
203-
frame = SendFrame.build(
204-
body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers
205-
)
206-
self.sent_frames.append(frame)
207-
await self._connection_manager.write_frame_reconnecting(frame)
208-
209-
210-
def _make_transaction_id() -> str:
211-
return str(uuid4())
22+
from stompman.subscription import Subscription
23+
from stompman.transaction import Transaction
21224

21325

21426
@dataclass(kw_only=True, slots=True)

stompman/connection_lifespan.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
from contextlib import suppress
4+
from dataclasses import dataclass
5+
from typing import Protocol
6+
from uuid import uuid4
7+
8+
from stompman.config import ConnectionParameters, Heartbeat
9+
from stompman.connection import AbstractConnection
10+
from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
11+
from stompman.frames import (
12+
ConnectedFrame,
13+
ConnectFrame,
14+
DisconnectFrame,
15+
ReceiptFrame,
16+
)
17+
from stompman.subscription import (
18+
ActiveSubscriptions,
19+
resubscribe_to_active_subscriptions,
20+
unsubscribe_from_all_active_subscriptions,
21+
)
22+
from stompman.transaction import ActiveTransactions, commit_pending_transactions
23+
24+
25+
class AbstractConnectionLifespan(Protocol):
26+
async def enter(self) -> StompProtocolConnectionIssue | None: ...
27+
async def exit(self) -> None: ...
28+
29+
30+
@dataclass(kw_only=True, slots=True)
31+
class ConnectionLifespan(AbstractConnectionLifespan):
32+
connection: AbstractConnection
33+
connection_parameters: ConnectionParameters
34+
protocol_version: str
35+
client_heartbeat: Heartbeat
36+
connection_confirmation_timeout: int
37+
disconnect_confirmation_timeout: int
38+
active_subscriptions: ActiveSubscriptions
39+
active_transactions: ActiveTransactions
40+
set_heartbeat_interval: Callable[[float], None]
41+
42+
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
43+
await self.connection.write_frame(
44+
ConnectFrame(
45+
headers={
46+
"accept-version": self.protocol_version,
47+
"heart-beat": self.client_heartbeat.to_header(),
48+
"host": self.connection_parameters.host,
49+
"login": self.connection_parameters.login,
50+
"passcode": self.connection_parameters.unescaped_passcode,
51+
},
52+
)
53+
)
54+
collected_frames = []
55+
56+
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
57+
async for frame in self.connection.read_frames():
58+
if isinstance(frame, ConnectedFrame):
59+
return frame
60+
collected_frames.append(frame)
61+
msg = "unreachable" # pragma: no cover
62+
raise AssertionError(msg) # pragma: no cover
63+
64+
try:
65+
connected_frame = await asyncio.wait_for(
66+
take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout
67+
)
68+
except TimeoutError:
69+
return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames)
70+
71+
if connected_frame.headers["version"] != self.protocol_version:
72+
return UnsupportedProtocolVersion(
73+
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
74+
)
75+
76+
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
77+
self.set_heartbeat_interval(
78+
max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000
79+
)
80+
return None
81+
82+
async def enter(self) -> StompProtocolConnectionIssue | None:
83+
if connection_issue := await self._establish_connection():
84+
return connection_issue
85+
await resubscribe_to_active_subscriptions(
86+
connection=self.connection, active_subscriptions=self.active_subscriptions
87+
)
88+
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
89+
return None
90+
91+
async def _take_receipt_frame(self) -> None:
92+
async for frame in self.connection.read_frames():
93+
if isinstance(frame, ReceiptFrame):
94+
break
95+
96+
async def exit(self) -> None:
97+
await unsubscribe_from_all_active_subscriptions(active_subscriptions=self.active_subscriptions)
98+
await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
99+
100+
with suppress(TimeoutError):
101+
await asyncio.wait_for(self._take_receipt_frame(), timeout=self.disconnect_confirmation_timeout)
102+
103+
104+
def _make_receipt_id() -> str:
105+
return str(uuid4())
106+
107+
108+
class ConnectionLifespanFactory(Protocol):
109+
def __call__(
110+
self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
111+
) -> AbstractConnectionLifespan: ...

0 commit comments

Comments
 (0)