|
1 | 1 | import asyncio
|
2 |
| -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine |
| 2 | +from collections.abc import AsyncGenerator, Callable, Coroutine |
3 | 3 | from contextlib import AsyncExitStack, asynccontextmanager, suppress
|
4 | 4 | from dataclasses import dataclass, field
|
| 5 | +from functools import partial |
5 | 6 | from types import TracebackType
|
6 | 7 | from typing import ClassVar, Self
|
7 | 8 | from uuid import uuid4
|
8 | 9 |
|
9 | 10 | from stompman.config import ConnectionParameters, Heartbeat
|
10 | 11 | from stompman.connection import AbstractConnection, Connection
|
11 |
| -from stompman.connection_manager import ConnectionManager |
12 |
| -from stompman.errors import ConnectionConfirmationTimeoutError, UnsupportedProtocolVersionError |
| 12 | +from stompman.connection_manager import AbstractConnectionLifespan, ConnectionManager |
| 13 | +from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion |
13 | 14 | from stompman.frames import (
|
14 | 15 | AbortFrame,
|
15 | 16 | AckFrame,
|
|
30 | 31 | )
|
31 | 32 |
|
32 | 33 |
|
33 |
| -@asynccontextmanager |
34 |
| -async def connection_lifespan( |
35 |
| - *, |
36 |
| - connection: AbstractConnection, |
37 |
| - connection_parameters: ConnectionParameters, |
38 |
| - protocol_version: str, |
39 |
| - client_heartbeat: Heartbeat, |
40 |
| - connection_confirmation_timeout: int, |
41 |
| - disconnect_confirmation_timeout: int, |
42 |
| -) -> AsyncIterator[float]: |
43 |
| - await connection.write_frame( |
44 |
| - ConnectFrame( |
45 |
| - headers={ |
46 |
| - "accept-version": protocol_version, |
47 |
| - "heart-beat": client_heartbeat.to_header(), |
48 |
| - "host": connection_parameters.host, |
49 |
| - "login": connection_parameters.login, |
50 |
| - "passcode": connection_parameters.unescaped_passcode, |
51 |
| - }, |
52 |
| - ) |
53 |
| - ) |
54 |
| - collected_frames = [] |
55 |
| - |
56 |
| - async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame: |
57 |
| - async for frame in connection.read_frames(): |
58 |
| - if isinstance(frame, ConnectedFrame): |
59 |
| - return frame |
60 |
| - collected_frames.append(frame) |
61 |
| - msg = "unreachable" # pragma: no cover |
62 |
| - raise AssertionError(msg) # pragma: no cover |
63 |
| - |
64 |
| - try: |
65 |
| - connected_frame = await asyncio.wait_for( |
66 |
| - take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout |
| 34 | +@dataclass(kw_only=True, slots=True) |
| 35 | +class ConnectionLifespan(AbstractConnectionLifespan): |
| 36 | + connection: AbstractConnection |
| 37 | + connection_parameters: ConnectionParameters |
| 38 | + protocol_version: str |
| 39 | + client_heartbeat: Heartbeat |
| 40 | + connection_confirmation_timeout: int |
| 41 | + disconnect_confirmation_timeout: int |
| 42 | + active_subscriptions: dict[str, "Subscription"] |
| 43 | + active_transactions: set["Transaction"] |
| 44 | + set_heartbeat_interval: Callable[[float], None] |
| 45 | + |
| 46 | + async def _establish_connection(self) -> StompProtocolConnectionIssue | None: |
| 47 | + await self.connection.write_frame( |
| 48 | + ConnectFrame( |
| 49 | + headers={ |
| 50 | + "accept-version": self.protocol_version, |
| 51 | + "heart-beat": self.client_heartbeat.to_header(), |
| 52 | + "host": self.connection_parameters.host, |
| 53 | + "login": self.connection_parameters.login, |
| 54 | + "passcode": self.connection_parameters.unescaped_passcode, |
| 55 | + }, |
| 56 | + ) |
67 | 57 | )
|
68 |
| - except TimeoutError as exception: |
69 |
| - raise ConnectionConfirmationTimeoutError( |
70 |
| - timeout=connection_confirmation_timeout, frames=collected_frames |
71 |
| - ) from exception |
72 |
| - |
73 |
| - if connected_frame.headers["version"] != protocol_version: |
74 |
| - raise UnsupportedProtocolVersionError( |
75 |
| - given_version=connected_frame.headers["version"], supported_version=protocol_version |
| 58 | + collected_frames = [] |
| 59 | + |
| 60 | + async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame: |
| 61 | + async for frame in self.connection.read_frames(): |
| 62 | + if isinstance(frame, ConnectedFrame): |
| 63 | + return frame |
| 64 | + collected_frames.append(frame) |
| 65 | + msg = "unreachable" # pragma: no cover |
| 66 | + raise AssertionError(msg) # pragma: no cover |
| 67 | + |
| 68 | + try: |
| 69 | + connected_frame = await asyncio.wait_for( |
| 70 | + take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout |
| 71 | + ) |
| 72 | + except TimeoutError: |
| 73 | + return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames) |
| 74 | + |
| 75 | + if connected_frame.headers["version"] != self.protocol_version: |
| 76 | + return UnsupportedProtocolVersion( |
| 77 | + given_version=connected_frame.headers["version"], supported_version=self.protocol_version |
| 78 | + ) |
| 79 | + |
| 80 | + server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"]) |
| 81 | + self.set_heartbeat_interval( |
| 82 | + max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000 |
76 | 83 | )
|
| 84 | + return None |
77 | 85 |
|
78 |
| - server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"]) |
79 |
| - heartbeat_interval = ( |
80 |
| - max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000 |
81 |
| - ) |
82 |
| - yield heartbeat_interval |
| 86 | + async def _resubscribe(self) -> None: |
| 87 | + for subscription in self.active_subscriptions.values(): |
| 88 | + await self.connection.write_frame( |
| 89 | + SubscribeFrame( |
| 90 | + headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack} |
| 91 | + ) |
| 92 | + ) |
83 | 93 |
|
84 |
| - await connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()})) |
| 94 | + async def _commit_pending_transactions(self) -> None: |
| 95 | + for transaction in self.active_transactions: |
| 96 | + for frame in transaction.sent_frames: |
| 97 | + await self.connection.write_frame(frame) |
| 98 | + await self.connection.write_frame(CommitFrame(headers={"transaction": transaction.id})) |
| 99 | + self.active_transactions.clear() |
85 | 100 |
|
86 |
| - async def take_receipt_frame() -> None: |
87 |
| - async for frame in connection.read_frames(): |
88 |
| - if isinstance(frame, ReceiptFrame): |
89 |
| - break |
| 101 | + async def enter(self) -> StompProtocolConnectionIssue | None: |
| 102 | + if connection_issue := await self._establish_connection(): |
| 103 | + return connection_issue |
| 104 | + await self._resubscribe() |
| 105 | + await self._commit_pending_transactions() |
| 106 | + return None |
90 | 107 |
|
91 |
| - with suppress(TimeoutError): |
92 |
| - await asyncio.wait_for(take_receipt_frame(), timeout=disconnect_confirmation_timeout) |
| 108 | + async def exit(self) -> None: |
| 109 | + for subscription in self.active_subscriptions.copy().values(): |
| 110 | + await subscription.unsubscribe() |
| 111 | + |
| 112 | + await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()})) |
| 113 | + |
| 114 | + async def take_receipt_frame() -> None: |
| 115 | + async for frame in self.connection.read_frames(): |
| 116 | + if isinstance(frame, ReceiptFrame): |
| 117 | + break |
| 118 | + |
| 119 | + with suppress(TimeoutError): |
| 120 | + await asyncio.wait_for(take_receipt_frame(), timeout=self.disconnect_confirmation_timeout) |
93 | 121 |
|
94 | 122 |
|
95 | 123 | def _make_receipt_id() -> str:
|
@@ -146,21 +174,6 @@ def _make_subscription_id() -> str:
|
146 | 174 | return str(uuid4())
|
147 | 175 |
|
148 | 176 |
|
149 |
| -@asynccontextmanager |
150 |
| -async def subscriptions_lifespan( |
151 |
| - *, connection: AbstractConnection, active_subscriptions: dict[str, Subscription] |
152 |
| -) -> AsyncIterator[None]: |
153 |
| - for subscription in active_subscriptions.values(): |
154 |
| - await connection.write_frame( |
155 |
| - SubscribeFrame( |
156 |
| - headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack} |
157 |
| - ) |
158 |
| - ) |
159 |
| - yield |
160 |
| - for subscription in active_subscriptions.copy().values(): |
161 |
| - await subscription.unsubscribe() |
162 |
| - |
163 |
| - |
164 | 177 | @dataclass(kw_only=True, slots=True, unsafe_hash=True)
|
165 | 178 | class Transaction:
|
166 | 179 | id: str = field(default_factory=lambda: _make_transaction_id(), init=False) # noqa: PLW0108
|
@@ -198,14 +211,6 @@ def _make_transaction_id() -> str:
|
198 | 211 | return str(uuid4())
|
199 | 212 |
|
200 | 213 |
|
201 |
| -async def commit_pending_transactions(*, connection: AbstractConnection, active_transactions: set[Transaction]) -> None: |
202 |
| - for transaction in active_transactions: |
203 |
| - for frame in transaction.sent_frames: |
204 |
| - await connection.write_frame(frame) |
205 |
| - await connection.write_frame(CommitFrame(headers={"transaction": transaction.id})) |
206 |
| - active_transactions.clear() |
207 |
| - |
208 |
| - |
209 | 214 | @dataclass(kw_only=True, slots=True)
|
210 | 215 | class Client:
|
211 | 216 | PROTOCOL_VERSION: ClassVar = "1.2" # https://stomp.github.io/stomp-specification-1.2.html
|
@@ -236,7 +241,16 @@ class Client:
|
236 | 241 | def __post_init__(self) -> None:
|
237 | 242 | self._connection_manager = ConnectionManager(
|
238 | 243 | servers=self.servers,
|
239 |
| - lifespan=self._lifespan, |
| 244 | + lifespan_factory=partial( |
| 245 | + ConnectionLifespan, |
| 246 | + protocol_version=self.PROTOCOL_VERSION, |
| 247 | + client_heartbeat=self.heartbeat, |
| 248 | + connection_confirmation_timeout=self.connection_confirmation_timeout, |
| 249 | + disconnect_confirmation_timeout=self.disconnect_confirmation_timeout, |
| 250 | + active_subscriptions=self._active_subscriptions, |
| 251 | + active_transactions=self._active_transactions, |
| 252 | + set_heartbeat_interval=self._restart_heartbeat_task, |
| 253 | + ), |
240 | 254 | connection_class=self.connection_class,
|
241 | 255 | connect_retry_attempts=self.connect_retry_attempts,
|
242 | 256 | connect_retry_interval=self.connect_retry_interval,
|
@@ -264,23 +278,6 @@ async def __aexit__(
|
264 | 278 | await asyncio.wait([self._listen_task, self._heartbeat_task])
|
265 | 279 | await self._exit_stack.aclose()
|
266 | 280 |
|
267 |
| - @asynccontextmanager |
268 |
| - async def _lifespan( |
269 |
| - self, connection: AbstractConnection, connection_parameters: ConnectionParameters |
270 |
| - ) -> AsyncGenerator[None, None]: |
271 |
| - async with connection_lifespan( |
272 |
| - connection=connection, |
273 |
| - connection_parameters=connection_parameters, |
274 |
| - protocol_version=self.PROTOCOL_VERSION, |
275 |
| - client_heartbeat=self.heartbeat, |
276 |
| - connection_confirmation_timeout=self.connection_confirmation_timeout, |
277 |
| - disconnect_confirmation_timeout=self.disconnect_confirmation_timeout, |
278 |
| - ) as heartbeat_interval: |
279 |
| - self._restart_heartbeat_task(heartbeat_interval) |
280 |
| - async with subscriptions_lifespan(connection=connection, active_subscriptions=self._active_subscriptions): |
281 |
| - await commit_pending_transactions(connection=connection, active_transactions=self._active_transactions) |
282 |
| - yield |
283 |
| - |
284 | 281 | def _restart_heartbeat_task(self, interval: float) -> None:
|
285 | 282 | self._heartbeat_task.cancel()
|
286 | 283 | self._heartbeat_task = self._task_group.create_task(self._send_heartbeats_forever(interval))
|
|
0 commit comments