1
1
import asyncio
2
- from collections .abc import AsyncGenerator , AsyncIterator
2
+ from collections .abc import AsyncGenerator , AsyncIterator , Awaitable , Callable
3
3
from contextlib import AsyncExitStack , asynccontextmanager
4
4
from dataclasses import dataclass , field
5
5
from types import TracebackType
6
- from typing import NamedTuple , Self
6
+ from typing import Any , ClassVar , NamedTuple , Self , TypedDict
7
7
from uuid import uuid4
8
8
9
- from stompman .connection import AbstractConnection , Connection , ConnectionParameters
9
+ from stompman .connection import AbstractConnection , Connection
10
10
from stompman .errors import (
11
11
ConnectionConfirmationTimeoutError ,
12
12
FailedAllConnectAttemptsError ,
13
13
UnsupportedProtocolVersionError ,
14
14
)
15
15
from stompman .frames import (
16
16
AbortFrame ,
17
+ AckFrame ,
17
18
BeginFrame ,
18
19
CommitFrame ,
19
20
ConnectedFrame ,
22
23
ErrorFrame ,
23
24
HeartbeatFrame ,
24
25
MessageFrame ,
26
+ NackFrame ,
25
27
ReceiptFrame ,
26
28
SendFrame ,
27
29
SendHeaders ,
28
30
SubscribeFrame ,
29
31
UnsubscribeFrame ,
30
32
)
31
- from stompman .listening_events import AnyListeningEvent , ErrorEvent , HeartbeatEvent , MessageEvent
32
- from stompman .protocol import PROTOCOL_VERSION
33
33
34
34
35
35
class Heartbeat (NamedTuple ):
@@ -45,6 +45,56 @@ def from_header(cls, header: str) -> Self:
45
45
return cls (int (first ), int (second ))
46
46
47
47
48
+ class MultiHostHostLike (TypedDict ):
49
+ username : str | None
50
+ password : str | None
51
+ host : str | None
52
+ port : int | None
53
+
54
+
55
+ @dataclass
56
+ class ConnectionParameters :
57
+ host : str
58
+ port : int
59
+ login : str
60
+ passcode : str = field (repr = False )
61
+
62
+ @classmethod
63
+ def from_pydantic_multihost_hosts (cls , hosts : list [MultiHostHostLike ]) -> list [Self ]:
64
+ """Create connection parameters from a list of `MultiHostUrl` objects.
65
+
66
+ .. code-block:: python
67
+ import stompman.
68
+
69
+ ArtemisDsn = typing.Annotated[
70
+ pydantic_core.MultiHostUrl,
71
+ pydantic.UrlConstraints(
72
+ host_required=True,
73
+ allowed_schemes=["tcp"],
74
+ ),
75
+ ]
76
+
77
+ async with stompman.Client(
78
+ servers=stompman.ConnectionParameters.from_pydantic_multihost_hosts(
79
+ ArtemisDsn("tcp://lev:pass@host1:61616,lev:pass@host1:61617,lev:pass@host2:61616").hosts()
80
+ ),
81
+ ):
82
+ ...
83
+ """
84
+ servers : list [Self ] = []
85
+ for host in hosts :
86
+ if host ["host" ] is None :
87
+ raise ValueError ("host must be set" )
88
+ if host ["port" ] is None :
89
+ raise ValueError ("port must be set" )
90
+ if host ["username" ] is None :
91
+ raise ValueError ("username must be set" )
92
+ if host ["password" ] is None :
93
+ raise ValueError ("password must be set" )
94
+ servers .append (cls (host = host ["host" ], port = host ["port" ], login = host ["username" ], passcode = host ["password" ]))
95
+ return servers
96
+
97
+
48
98
@dataclass
49
99
class Client :
50
100
servers : list [ConnectionParameters ]
@@ -57,11 +107,14 @@ class Client:
57
107
read_max_chunk_size : int = 1024 * 1024
58
108
connection_class : type [AbstractConnection ] = Connection
59
109
110
+ PROTOCOL_VERSION : ClassVar = "1.2" # https://stomp.github.io/stomp-specification-1.2.html
111
+
60
112
_connection : AbstractConnection = field (init = False )
113
+ _connection_parameters : ConnectionParameters = field (init = False )
61
114
_exit_stack : AsyncExitStack = field (default_factory = AsyncExitStack , init = False )
62
115
63
116
async def __aenter__ (self ) -> Self :
64
- self . _connection = await self ._connect_to_any_server ()
117
+ await self ._connect_to_any_server ()
65
118
await self ._exit_stack .enter_async_context (self ._connection_lifespan ())
66
119
return self
67
120
@@ -71,26 +124,25 @@ async def __aexit__(
71
124
await self ._exit_stack .aclose ()
72
125
await self ._connection .close ()
73
126
74
- async def _connect_to_one_server (self , server : ConnectionParameters ) -> AbstractConnection | None :
127
+ async def _connect_to_one_server (
128
+ self , server : ConnectionParameters
129
+ ) -> tuple [AbstractConnection , ConnectionParameters ] | None :
75
130
for attempt in range (self .connect_retry_attempts ):
76
- connection = self .connection_class (
77
- connection_parameters = server ,
78
- connect_timeout = self .connect_timeout ,
79
- read_timeout = self .read_timeout ,
80
- read_max_chunk_size = self .read_max_chunk_size ,
81
- )
82
- if await connection .connect ():
83
- return connection
131
+ if connection := await self .connection_class .connect (
132
+ host = server .host , port = server .port , timeout = self .connect_timeout
133
+ ):
134
+ return connection , server
84
135
await asyncio .sleep (self .connect_retry_interval * (attempt + 1 ))
85
136
return None
86
137
87
- async def _connect_to_any_server (self ) -> AbstractConnection :
138
+ async def _connect_to_any_server (self ) -> None :
88
139
for maybe_connection_future in asyncio .as_completed (
89
140
[self ._connect_to_one_server (server ) for server in self .servers ]
90
141
):
91
- maybe_connection = await maybe_connection_future
92
- if maybe_connection :
93
- return maybe_connection
142
+ maybe_result = await maybe_connection_future
143
+ if maybe_result :
144
+ self ._connection , self ._connection_parameters = maybe_result
145
+ return
94
146
raise FailedAllConnectAttemptsError (
95
147
servers = self .servers ,
96
148
retry_attempts = self .connect_retry_attempts ,
@@ -112,24 +164,27 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
112
164
await self ._connection .write_frame (
113
165
ConnectFrame (
114
166
headers = {
115
- "accept-version" : PROTOCOL_VERSION ,
167
+ "accept-version" : self . PROTOCOL_VERSION ,
116
168
"heart-beat" : self .heartbeat .to_header (),
117
- "host" : self ._connection . connection_parameters .host ,
118
- "login" : self ._connection . connection_parameters .login ,
119
- "passcode" : self ._connection . connection_parameters .passcode ,
169
+ "host" : self ._connection_parameters .host ,
170
+ "login" : self ._connection_parameters .login ,
171
+ "passcode" : self ._connection_parameters .passcode ,
120
172
},
121
173
)
122
174
)
123
175
try :
124
176
connected_frame = await asyncio .wait_for (
125
- self ._connection .read_frame_of_type (ConnectedFrame ), timeout = self .connection_confirmation_timeout
177
+ self ._connection .read_frame_of_type (
178
+ ConnectedFrame , max_chunk_size = self .read_max_chunk_size , timeout = self .read_timeout
179
+ ),
180
+ timeout = self .connection_confirmation_timeout ,
126
181
)
127
182
except TimeoutError as exception :
128
183
raise ConnectionConfirmationTimeoutError (self .connection_confirmation_timeout ) from exception
129
184
130
- if connected_frame .headers ["version" ] != PROTOCOL_VERSION :
185
+ if connected_frame .headers ["version" ] != self . PROTOCOL_VERSION :
131
186
raise UnsupportedProtocolVersionError (
132
- given_version = connected_frame .headers ["version" ], supported_version = PROTOCOL_VERSION
187
+ given_version = connected_frame .headers ["version" ], supported_version = self . PROTOCOL_VERSION
133
188
)
134
189
135
190
server_heartbeat = Heartbeat .from_header (connected_frame .headers ["heart-beat" ])
@@ -150,30 +205,9 @@ async def send_heartbeats_forever() -> None:
150
205
task .cancel ()
151
206
152
207
await self ._connection .write_frame (DisconnectFrame (headers = {"receipt" : str (uuid4 ())}))
153
- await self ._connection .read_frame_of_type (ReceiptFrame )
154
-
155
- @asynccontextmanager
156
- async def subscribe (self , destination : str ) -> AsyncGenerator [None , None ]:
157
- subscription_id = str (uuid4 ())
158
- await self ._connection .write_frame (
159
- SubscribeFrame (headers = {"id" : subscription_id , "destination" : destination , "ack" : "client-individual" })
208
+ await self ._connection .read_frame_of_type (
209
+ ReceiptFrame , max_chunk_size = self .read_max_chunk_size , timeout = self .read_timeout
160
210
)
161
- try :
162
- yield
163
- finally :
164
- await self ._connection .write_frame (UnsubscribeFrame (headers = {"id" : subscription_id }))
165
-
166
- async def listen (self ) -> AsyncIterator [AnyListeningEvent ]:
167
- async for frame in self ._connection .read_frames ():
168
- match frame :
169
- case MessageFrame ():
170
- yield MessageEvent (_client = self , _frame = frame )
171
- case ErrorFrame ():
172
- yield ErrorEvent (_client = self , _frame = frame )
173
- case HeartbeatFrame ():
174
- yield HeartbeatEvent (_client = self , _frame = frame )
175
- case ConnectedFrame () | ReceiptFrame ():
176
- raise AssertionError ("Should be unreachable! Report the issue." , frame )
177
211
178
212
@asynccontextmanager
179
213
async def enter_transaction (self ) -> AsyncGenerator [str , None ]:
@@ -204,3 +238,93 @@ async def send( # noqa: PLR0913
204
238
if transaction is not None :
205
239
full_headers ["transaction" ] = transaction
206
240
await self ._connection .write_frame (SendFrame (headers = full_headers , body = body ))
241
+
242
+ @asynccontextmanager
243
+ async def subscribe (self , destination : str ) -> AsyncGenerator [None , None ]:
244
+ subscription_id = str (uuid4 ())
245
+ await self ._connection .write_frame (
246
+ SubscribeFrame (headers = {"id" : subscription_id , "destination" : destination , "ack" : "client-individual" })
247
+ )
248
+ try :
249
+ yield
250
+ finally :
251
+ await self ._connection .write_frame (UnsubscribeFrame (headers = {"id" : subscription_id }))
252
+
253
+ async def listen (self ) -> AsyncIterator ["AnyListeningEvent" ]:
254
+ async for frame in self ._connection .read_frames (
255
+ max_chunk_size = self .read_max_chunk_size , timeout = self .read_timeout
256
+ ):
257
+ match frame :
258
+ case MessageFrame ():
259
+ yield MessageEvent (_client = self , _frame = frame )
260
+ case ErrorFrame ():
261
+ yield ErrorEvent (_client = self , _frame = frame )
262
+ case HeartbeatFrame ():
263
+ yield HeartbeatEvent (_client = self , _frame = frame )
264
+ case ConnectedFrame () | ReceiptFrame ():
265
+ raise AssertionError ("Should be unreachable! Report the issue." , frame )
266
+
267
+
268
+ @dataclass
269
+ class MessageEvent :
270
+ body : bytes = field (init = False )
271
+ _frame : MessageFrame
272
+ _client : "Client" = field (repr = False )
273
+
274
+ def __post_init__ (self ) -> None :
275
+ self .body = self ._frame .body
276
+
277
+ async def ack (self ) -> None :
278
+ await self ._client ._connection .write_frame (
279
+ AckFrame (
280
+ headers = {"id" : self ._frame .headers ["message-id" ], "subscription" : self ._frame .headers ["subscription" ]},
281
+ )
282
+ )
283
+
284
+ async def nack (self ) -> None :
285
+ await self ._client ._connection .write_frame (
286
+ NackFrame (
287
+ headers = {"id" : self ._frame .headers ["message-id" ], "subscription" : self ._frame .headers ["subscription" ]}
288
+ )
289
+ )
290
+
291
+ async def with_auto_ack (
292
+ self ,
293
+ awaitable : Awaitable [None ],
294
+ * ,
295
+ on_suppressed_exception : Callable [[Exception , Self ], Any ],
296
+ supressed_exception_classes : tuple [type [Exception ], ...] = (Exception ,),
297
+ ) -> None :
298
+ called_nack = False
299
+ try :
300
+ await awaitable
301
+ except supressed_exception_classes as exception :
302
+ await self .nack ()
303
+ called_nack = True
304
+ on_suppressed_exception (exception , self )
305
+ finally :
306
+ if not called_nack :
307
+ await self .ack ()
308
+
309
+
310
+ @dataclass
311
+ class ErrorEvent :
312
+ message_header : str = field (init = False )
313
+ """Short description of the error."""
314
+ body : bytes = field (init = False )
315
+ """Long description of the error."""
316
+ _frame : ErrorFrame
317
+ _client : "Client" = field (repr = False )
318
+
319
+ def __post_init__ (self ) -> None :
320
+ self .message_header = self ._frame .headers ["message" ]
321
+ self .body = self ._frame .body
322
+
323
+
324
+ @dataclass
325
+ class HeartbeatEvent :
326
+ _frame : HeartbeatFrame
327
+ _client : "Client" = field (repr = False )
328
+
329
+
330
+ AnyListeningEvent = MessageEvent | ErrorEvent | HeartbeatEvent
0 commit comments