Skip to content

Commit 9730d08

Browse files
authored
Refactor internals (#24)
1 parent 65735e1 commit 9730d08

File tree

11 files changed

+467
-466
lines changed

11 files changed

+467
-466
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ Or, to send messages in a transaction:
5050
```python
5151
async with client.enter_transaction() as transaction:
5252
for _ in range(10):
53-
await client.send(body=b"hi there!", destination="DLQ", transaction=transaction)
53+
await client.send(body=b"hi there!", destination="DLQ", transaction=transaction, headers={"persistent": "true"})
5454
await asyncio.sleep(0.1)
5555
```
5656

stompman/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
1-
from stompman.client import Client, Heartbeat
2-
from stompman.connection import AbstractConnection, Connection, ConnectionParameters
1+
from stompman.client import (
2+
AnyListeningEvent,
3+
Client,
4+
ConnectionParameters,
5+
ErrorEvent,
6+
Heartbeat,
7+
HeartbeatEvent,
8+
MessageEvent,
9+
)
10+
from stompman.connection import AbstractConnection, Connection
311
from stompman.errors import (
412
ConnectionConfirmationTimeoutError,
513
ConnectionLostError,
@@ -26,7 +34,6 @@
2634
SubscribeFrame,
2735
UnsubscribeFrame,
2836
)
29-
from stompman.listening_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent
3037

3138
__all__ = [
3239
"AbortFrame",

stompman/client.py

Lines changed: 173 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator, AsyncIterator
2+
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
33
from contextlib import AsyncExitStack, asynccontextmanager
44
from dataclasses import dataclass, field
55
from types import TracebackType
6-
from typing import NamedTuple, Self
6+
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
77
from uuid import uuid4
88

9-
from stompman.connection import AbstractConnection, Connection, ConnectionParameters
9+
from stompman.connection import AbstractConnection, Connection
1010
from stompman.errors import (
1111
ConnectionConfirmationTimeoutError,
1212
FailedAllConnectAttemptsError,
1313
UnsupportedProtocolVersionError,
1414
)
1515
from stompman.frames import (
1616
AbortFrame,
17+
AckFrame,
1718
BeginFrame,
1819
CommitFrame,
1920
ConnectedFrame,
@@ -22,14 +23,13 @@
2223
ErrorFrame,
2324
HeartbeatFrame,
2425
MessageFrame,
26+
NackFrame,
2527
ReceiptFrame,
2628
SendFrame,
2729
SendHeaders,
2830
SubscribeFrame,
2931
UnsubscribeFrame,
3032
)
31-
from stompman.listening_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent
32-
from stompman.protocol import PROTOCOL_VERSION
3333

3434

3535
class Heartbeat(NamedTuple):
@@ -45,6 +45,56 @@ def from_header(cls, header: str) -> Self:
4545
return cls(int(first), int(second))
4646

4747

48+
class MultiHostHostLike(TypedDict):
49+
username: str | None
50+
password: str | None
51+
host: str | None
52+
port: int | None
53+
54+
55+
@dataclass
56+
class ConnectionParameters:
57+
host: str
58+
port: int
59+
login: str
60+
passcode: str = field(repr=False)
61+
62+
@classmethod
63+
def from_pydantic_multihost_hosts(cls, hosts: list[MultiHostHostLike]) -> list[Self]:
64+
"""Create connection parameters from a list of `MultiHostUrl` objects.
65+
66+
.. code-block:: python
67+
import stompman.
68+
69+
ArtemisDsn = typing.Annotated[
70+
pydantic_core.MultiHostUrl,
71+
pydantic.UrlConstraints(
72+
host_required=True,
73+
allowed_schemes=["tcp"],
74+
),
75+
]
76+
77+
async with stompman.Client(
78+
servers=stompman.ConnectionParameters.from_pydantic_multihost_hosts(
79+
ArtemisDsn("tcp://lev:pass@host1:61616,lev:pass@host1:61617,lev:pass@host2:61616").hosts()
80+
),
81+
):
82+
...
83+
"""
84+
servers: list[Self] = []
85+
for host in hosts:
86+
if host["host"] is None:
87+
raise ValueError("host must be set")
88+
if host["port"] is None:
89+
raise ValueError("port must be set")
90+
if host["username"] is None:
91+
raise ValueError("username must be set")
92+
if host["password"] is None:
93+
raise ValueError("password must be set")
94+
servers.append(cls(host=host["host"], port=host["port"], login=host["username"], passcode=host["password"]))
95+
return servers
96+
97+
4898
@dataclass
4999
class Client:
50100
servers: list[ConnectionParameters]
@@ -57,11 +107,14 @@ class Client:
57107
read_max_chunk_size: int = 1024 * 1024
58108
connection_class: type[AbstractConnection] = Connection
59109

110+
PROTOCOL_VERSION: ClassVar = "1.2" # https://stomp.github.io/stomp-specification-1.2.html
111+
60112
_connection: AbstractConnection = field(init=False)
113+
_connection_parameters: ConnectionParameters = field(init=False)
61114
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
62115

63116
async def __aenter__(self) -> Self:
64-
self._connection = await self._connect_to_any_server()
117+
await self._connect_to_any_server()
65118
await self._exit_stack.enter_async_context(self._connection_lifespan())
66119
return self
67120

@@ -71,26 +124,25 @@ async def __aexit__(
71124
await self._exit_stack.aclose()
72125
await self._connection.close()
73126

74-
async def _connect_to_one_server(self, server: ConnectionParameters) -> AbstractConnection | None:
127+
async def _connect_to_one_server(
128+
self, server: ConnectionParameters
129+
) -> tuple[AbstractConnection, ConnectionParameters] | None:
75130
for attempt in range(self.connect_retry_attempts):
76-
connection = self.connection_class(
77-
connection_parameters=server,
78-
connect_timeout=self.connect_timeout,
79-
read_timeout=self.read_timeout,
80-
read_max_chunk_size=self.read_max_chunk_size,
81-
)
82-
if await connection.connect():
83-
return connection
131+
if connection := await self.connection_class.connect(
132+
host=server.host, port=server.port, timeout=self.connect_timeout
133+
):
134+
return connection, server
84135
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
85136
return None
86137

87-
async def _connect_to_any_server(self) -> AbstractConnection:
138+
async def _connect_to_any_server(self) -> None:
88139
for maybe_connection_future in asyncio.as_completed(
89140
[self._connect_to_one_server(server) for server in self.servers]
90141
):
91-
maybe_connection = await maybe_connection_future
92-
if maybe_connection:
93-
return maybe_connection
142+
maybe_result = await maybe_connection_future
143+
if maybe_result:
144+
self._connection, self._connection_parameters = maybe_result
145+
return
94146
raise FailedAllConnectAttemptsError(
95147
servers=self.servers,
96148
retry_attempts=self.connect_retry_attempts,
@@ -112,24 +164,27 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
112164
await self._connection.write_frame(
113165
ConnectFrame(
114166
headers={
115-
"accept-version": PROTOCOL_VERSION,
167+
"accept-version": self.PROTOCOL_VERSION,
116168
"heart-beat": self.heartbeat.to_header(),
117-
"host": self._connection.connection_parameters.host,
118-
"login": self._connection.connection_parameters.login,
119-
"passcode": self._connection.connection_parameters.passcode,
169+
"host": self._connection_parameters.host,
170+
"login": self._connection_parameters.login,
171+
"passcode": self._connection_parameters.passcode,
120172
},
121173
)
122174
)
123175
try:
124176
connected_frame = await asyncio.wait_for(
125-
self._connection.read_frame_of_type(ConnectedFrame), timeout=self.connection_confirmation_timeout
177+
self._connection.read_frame_of_type(
178+
ConnectedFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
179+
),
180+
timeout=self.connection_confirmation_timeout,
126181
)
127182
except TimeoutError as exception:
128183
raise ConnectionConfirmationTimeoutError(self.connection_confirmation_timeout) from exception
129184

130-
if connected_frame.headers["version"] != PROTOCOL_VERSION:
185+
if connected_frame.headers["version"] != self.PROTOCOL_VERSION:
131186
raise UnsupportedProtocolVersionError(
132-
given_version=connected_frame.headers["version"], supported_version=PROTOCOL_VERSION
187+
given_version=connected_frame.headers["version"], supported_version=self.PROTOCOL_VERSION
133188
)
134189

135190
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
@@ -150,30 +205,9 @@ async def send_heartbeats_forever() -> None:
150205
task.cancel()
151206

152207
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
153-
await self._connection.read_frame_of_type(ReceiptFrame)
154-
155-
@asynccontextmanager
156-
async def subscribe(self, destination: str) -> AsyncGenerator[None, None]:
157-
subscription_id = str(uuid4())
158-
await self._connection.write_frame(
159-
SubscribeFrame(headers={"id": subscription_id, "destination": destination, "ack": "client-individual"})
208+
await self._connection.read_frame_of_type(
209+
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
160210
)
161-
try:
162-
yield
163-
finally:
164-
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))
165-
166-
async def listen(self) -> AsyncIterator[AnyListeningEvent]:
167-
async for frame in self._connection.read_frames():
168-
match frame:
169-
case MessageFrame():
170-
yield MessageEvent(_client=self, _frame=frame)
171-
case ErrorFrame():
172-
yield ErrorEvent(_client=self, _frame=frame)
173-
case HeartbeatFrame():
174-
yield HeartbeatEvent(_client=self, _frame=frame)
175-
case ConnectedFrame() | ReceiptFrame():
176-
raise AssertionError("Should be unreachable! Report the issue.", frame)
177211

178212
@asynccontextmanager
179213
async def enter_transaction(self) -> AsyncGenerator[str, None]:
@@ -204,3 +238,93 @@ async def send( # noqa: PLR0913
204238
if transaction is not None:
205239
full_headers["transaction"] = transaction
206240
await self._connection.write_frame(SendFrame(headers=full_headers, body=body))
241+
242+
@asynccontextmanager
243+
async def subscribe(self, destination: str) -> AsyncGenerator[None, None]:
244+
subscription_id = str(uuid4())
245+
await self._connection.write_frame(
246+
SubscribeFrame(headers={"id": subscription_id, "destination": destination, "ack": "client-individual"})
247+
)
248+
try:
249+
yield
250+
finally:
251+
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))
252+
253+
async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
254+
async for frame in self._connection.read_frames(
255+
max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
256+
):
257+
match frame:
258+
case MessageFrame():
259+
yield MessageEvent(_client=self, _frame=frame)
260+
case ErrorFrame():
261+
yield ErrorEvent(_client=self, _frame=frame)
262+
case HeartbeatFrame():
263+
yield HeartbeatEvent(_client=self, _frame=frame)
264+
case ConnectedFrame() | ReceiptFrame():
265+
raise AssertionError("Should be unreachable! Report the issue.", frame)
266+
267+
268+
@dataclass
269+
class MessageEvent:
270+
body: bytes = field(init=False)
271+
_frame: MessageFrame
272+
_client: "Client" = field(repr=False)
273+
274+
def __post_init__(self) -> None:
275+
self.body = self._frame.body
276+
277+
async def ack(self) -> None:
278+
await self._client._connection.write_frame(
279+
AckFrame(
280+
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]},
281+
)
282+
)
283+
284+
async def nack(self) -> None:
285+
await self._client._connection.write_frame(
286+
NackFrame(
287+
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]}
288+
)
289+
)
290+
291+
async def with_auto_ack(
292+
self,
293+
awaitable: Awaitable[None],
294+
*,
295+
on_suppressed_exception: Callable[[Exception, Self], Any],
296+
supressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
297+
) -> None:
298+
called_nack = False
299+
try:
300+
await awaitable
301+
except supressed_exception_classes as exception:
302+
await self.nack()
303+
called_nack = True
304+
on_suppressed_exception(exception, self)
305+
finally:
306+
if not called_nack:
307+
await self.ack()
308+
309+
310+
@dataclass
311+
class ErrorEvent:
312+
message_header: str = field(init=False)
313+
"""Short description of the error."""
314+
body: bytes = field(init=False)
315+
"""Long description of the error."""
316+
_frame: ErrorFrame
317+
_client: "Client" = field(repr=False)
318+
319+
def __post_init__(self) -> None:
320+
self.message_header = self._frame.headers["message"]
321+
self.body = self._frame.body
322+
323+
324+
@dataclass
325+
class HeartbeatEvent:
326+
_frame: HeartbeatFrame
327+
_client: "Client" = field(repr=False)
328+
329+
330+
AnyListeningEvent = MessageEvent | ErrorEvent | HeartbeatEvent

0 commit comments

Comments
 (0)