|
1 | 1 | import asyncio
|
2 |
| -from collections.abc import Callable |
| 2 | +from collections.abc import AsyncIterable, Callable |
3 | 3 | from contextlib import suppress
|
4 | 4 | from dataclasses import dataclass
|
5 | 5 | from typing import Protocol
|
|
9 | 9 | from stompman.connection import AbstractConnection
|
10 | 10 | from stompman.errors import ConnectionConfirmationTimeout, StompProtocolConnectionIssue, UnsupportedProtocolVersion
|
11 | 11 | from stompman.frames import (
|
| 12 | + AnyServerFrame, |
12 | 13 | ConnectedFrame,
|
13 | 14 | ConnectFrame,
|
14 | 15 | DisconnectFrame,
|
|
22 | 23 | from stompman.transaction import ActiveTransactions, commit_pending_transactions
|
23 | 24 |
|
24 | 25 |
|
| 26 | +async def take_connected_frame( |
| 27 | + *, frames_iter: AsyncIterable[AnyServerFrame], connection_confirmation_timeout: int |
| 28 | +) -> ConnectedFrame | ConnectionConfirmationTimeout: |
| 29 | + collected_frames = [] |
| 30 | + |
| 31 | + async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame: |
| 32 | + async for frame in frames_iter: |
| 33 | + if isinstance(frame, ConnectedFrame): |
| 34 | + return frame |
| 35 | + collected_frames.append(frame) |
| 36 | + msg = "unreachable" |
| 37 | + raise AssertionError(msg) |
| 38 | + |
| 39 | + try: |
| 40 | + return await asyncio.wait_for( |
| 41 | + take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout |
| 42 | + ) |
| 43 | + except TimeoutError: |
| 44 | + return ConnectionConfirmationTimeout(timeout=connection_confirmation_timeout, frames=collected_frames) |
| 45 | + |
| 46 | + |
| 47 | +def check_stomp_protocol_version( |
| 48 | + *, connected_frame: ConnectedFrame, supported_version: str |
| 49 | +) -> UnsupportedProtocolVersion | None: |
| 50 | + if connected_frame.headers["version"] == supported_version: |
| 51 | + return None |
| 52 | + return UnsupportedProtocolVersion( |
| 53 | + given_version=connected_frame.headers["version"], supported_version=supported_version |
| 54 | + ) |
| 55 | + |
| 56 | + |
| 57 | +def calculate_heartbeat_interval(*, connected_frame: ConnectedFrame, client_heartbeat: Heartbeat) -> float: |
| 58 | + server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"]) |
| 59 | + return max(client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000 |
| 60 | + |
| 61 | + |
| 62 | +async def wait_for_receipt_frame( |
| 63 | + *, frames_iter: AsyncIterable[AnyServerFrame], disconnect_confirmation_timeout: int |
| 64 | +) -> None: |
| 65 | + async def inner() -> None: |
| 66 | + async for frame in frames_iter: |
| 67 | + if isinstance(frame, ReceiptFrame): |
| 68 | + break |
| 69 | + |
| 70 | + with suppress(TimeoutError): |
| 71 | + await asyncio.wait_for(inner(), timeout=disconnect_confirmation_timeout) |
| 72 | + |
| 73 | + |
25 | 74 | class AbstractConnectionLifespan(Protocol):
|
26 | 75 | async def enter(self) -> StompProtocolConnectionIssue | None: ...
|
27 | 76 | async def exit(self) -> None: ...
|
@@ -51,61 +100,48 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
|
51 | 100 | },
|
52 | 101 | )
|
53 | 102 | )
|
54 |
| - collected_frames = [] |
55 |
| - |
56 |
| - async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame: |
57 |
| - async for frame in self.connection.read_frames(): |
58 |
| - if isinstance(frame, ConnectedFrame): |
59 |
| - return frame |
60 |
| - collected_frames.append(frame) |
61 |
| - msg = "unreachable" # pragma: no cover |
62 |
| - raise AssertionError(msg) # pragma: no cover |
63 |
| - |
64 |
| - try: |
65 |
| - connected_frame = await asyncio.wait_for( |
66 |
| - take_connected_frame_and_collect_other_frames(), timeout=self.connection_confirmation_timeout |
67 |
| - ) |
68 |
| - except TimeoutError: |
69 |
| - return ConnectionConfirmationTimeout(timeout=self.connection_confirmation_timeout, frames=collected_frames) |
| 103 | + connected_frame_or_error = await take_connected_frame( |
| 104 | + frames_iter=self.connection.read_frames(), |
| 105 | + connection_confirmation_timeout=self.connection_confirmation_timeout, |
| 106 | + ) |
| 107 | + if isinstance(connected_frame_or_error, ConnectionConfirmationTimeout): |
| 108 | + return connected_frame_or_error |
| 109 | + connected_frame = connected_frame_or_error |
70 | 110 |
|
71 |
| - if connected_frame.headers["version"] != self.protocol_version: |
72 |
| - return UnsupportedProtocolVersion( |
73 |
| - given_version=connected_frame.headers["version"], supported_version=self.protocol_version |
74 |
| - ) |
| 111 | + if unsupported_protocol_version_error := check_stomp_protocol_version( |
| 112 | + connected_frame=connected_frame, supported_version=self.protocol_version |
| 113 | + ): |
| 114 | + return unsupported_protocol_version_error |
75 | 115 |
|
76 |
| - server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"]) |
77 | 116 | self.set_heartbeat_interval(
|
78 |
| - max(self.client_heartbeat.will_send_interval_ms, server_heartbeat.want_to_receive_interval_ms) / 1000 |
| 117 | + calculate_heartbeat_interval(connected_frame=connected_frame, client_heartbeat=self.client_heartbeat) |
79 | 118 | )
|
80 | 119 | return None
|
81 | 120 |
|
82 | 121 | async def enter(self) -> StompProtocolConnectionIssue | None:
|
83 |
| - if connection_issue := await self._establish_connection(): |
84 |
| - return connection_issue |
| 122 | + if protocol_connection_issue := await self._establish_connection(): |
| 123 | + return protocol_connection_issue |
| 124 | + |
85 | 125 | await resubscribe_to_active_subscriptions(
|
86 | 126 | connection=self.connection, active_subscriptions=self.active_subscriptions
|
87 | 127 | )
|
88 | 128 | await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
|
89 | 129 | return None
|
90 | 130 |
|
91 |
| - async def _take_receipt_frame(self) -> None: |
92 |
| - async for frame in self.connection.read_frames(): |
93 |
| - if isinstance(frame, ReceiptFrame): |
94 |
| - break |
95 |
| - |
96 | 131 | async def exit(self) -> None:
|
97 | 132 | await unsubscribe_from_all_active_subscriptions(active_subscriptions=self.active_subscriptions)
|
98 | 133 | await self.connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
|
99 |
| - |
100 |
| - with suppress(TimeoutError): |
101 |
| - await asyncio.wait_for(self._take_receipt_frame(), timeout=self.disconnect_confirmation_timeout) |
102 |
| - |
103 |
| - |
104 |
| -def _make_receipt_id() -> str: |
105 |
| - return str(uuid4()) |
| 134 | + await wait_for_receipt_frame( |
| 135 | + frames_iter=self.connection.read_frames(), |
| 136 | + disconnect_confirmation_timeout=self.disconnect_confirmation_timeout, |
| 137 | + ) |
106 | 138 |
|
107 | 139 |
|
108 | 140 | class ConnectionLifespanFactory(Protocol):
|
109 | 141 | def __call__(
|
110 | 142 | self, *, connection: AbstractConnection, connection_parameters: ConnectionParameters
|
111 | 143 | ) -> AbstractConnectionLifespan: ...
|
| 144 | + |
| 145 | + |
| 146 | +def _make_receipt_id() -> str: |
| 147 | + return str(uuid4()) |
0 commit comments