1
1
import asyncio
2
- from collections .abc import AsyncIterable , Awaitable , Callable
3
- from dataclasses import dataclass , field
4
- from typing import Any , Protocol , TypeVar
2
+ from collections .abc import Callable
3
+ from contextlib import suppress
4
+ from dataclasses import dataclass
5
+ from typing import Protocol
5
6
from uuid import uuid4
6
7
7
8
from stompman .config import ConnectionParameters , Heartbeat
8
9
from stompman .connection import AbstractConnection
9
10
from stompman .errors import ConnectionConfirmationTimeout , StompProtocolConnectionIssue , UnsupportedProtocolVersion
10
11
from stompman .frames import (
11
- AnyServerFrame ,
12
12
ConnectedFrame ,
13
13
ConnectFrame ,
14
14
DisconnectFrame ,
21
21
)
22
22
from stompman .transaction import ActiveTransactions , commit_pending_transactions
23
23
24
- FrameType = TypeVar ("FrameType" , bound = AnyServerFrame )
25
- WaitForFutureReturnType = TypeVar ("WaitForFutureReturnType" )
26
-
27
-
28
- async def wait_for_or_none (
29
- awaitable : Awaitable [WaitForFutureReturnType ], timeout : float
30
- ) -> WaitForFutureReturnType | None :
31
- try :
32
- return await asyncio .wait_for (awaitable , timeout = timeout )
33
- except TimeoutError :
34
- return None
35
-
36
-
37
- WaitForOrNone = Callable [[Awaitable [WaitForFutureReturnType ], float ], Awaitable [WaitForFutureReturnType | None ]]
38
-
39
-
40
- async def take_frame_of_type (
41
- * ,
42
- frame_type : type [FrameType ],
43
- frames_iter : AsyncIterable [AnyServerFrame ],
44
- timeout : int ,
45
- wait_for_or_none : WaitForOrNone [FrameType ],
46
- ) -> FrameType | list [Any ]:
47
- collected_frames = []
48
-
49
- async def inner () -> FrameType :
50
- async for frame in frames_iter :
51
- if isinstance (frame , frame_type ):
52
- return frame
53
- collected_frames .append (frame )
54
- msg = "unreachable"
55
- raise AssertionError (msg )
56
-
57
- return await wait_for_or_none (inner (), timeout ) or collected_frames
58
-
59
-
60
- def check_stomp_protocol_version (
61
- * , connected_frame : ConnectedFrame , supported_version : str
62
- ) -> UnsupportedProtocolVersion | None :
63
- if connected_frame .headers ["version" ] == supported_version :
64
- return None
65
- return UnsupportedProtocolVersion (
66
- given_version = connected_frame .headers ["version" ], supported_version = supported_version
67
- )
68
-
69
-
70
- def calculate_heartbeat_interval (* , connected_frame : ConnectedFrame , client_heartbeat : Heartbeat ) -> float :
71
- server_heartbeat = Heartbeat .from_header (connected_frame .headers ["heart-beat" ])
72
- return max (client_heartbeat .will_send_interval_ms , server_heartbeat .want_to_receive_interval_ms ) / 1000
73
-
74
24
75
25
class AbstractConnectionLifespan (Protocol ):
76
26
async def enter (self ) -> StompProtocolConnectionIssue | None : ...
@@ -88,7 +38,6 @@ class ConnectionLifespan(AbstractConnectionLifespan):
88
38
active_subscriptions : ActiveSubscriptions
89
39
active_transactions : ActiveTransactions
90
40
set_heartbeat_interval : Callable [[float ], None ]
91
- _generate_receipt_id : Callable [[], str ] = field (default = lambda : _make_receipt_id ()) # noqa: PLW0108
92
41
93
42
async def _establish_connection (self ) -> StompProtocolConnectionIssue | None :
94
43
await self .connection .write_frame (
@@ -102,54 +51,61 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
102
51
},
103
52
)
104
53
)
105
- connected_frame_or_collected_frames = await take_frame_of_type (
106
- frame_type = ConnectedFrame ,
107
- frames_iter = self .connection .read_frames (),
108
- timeout = self .connection_confirmation_timeout ,
109
- wait_for_or_none = wait_for_or_none ,
110
- )
111
- if not isinstance (connected_frame_or_collected_frames , ConnectedFrame ):
112
- return ConnectionConfirmationTimeout (
113
- timeout = self .connection_confirmation_timeout , frames = connected_frame_or_collected_frames
54
+ collected_frames = []
55
+
56
+ async def take_connected_frame_and_collect_other_frames () -> ConnectedFrame :
57
+ async for frame in self .connection .read_frames ():
58
+ if isinstance (frame , ConnectedFrame ):
59
+ return frame
60
+ collected_frames .append (frame )
61
+ msg = "unreachable" # pragma: no cover
62
+ raise AssertionError (msg ) # pragma: no cover
63
+
64
+ try :
65
+ connected_frame = await asyncio .wait_for (
66
+ take_connected_frame_and_collect_other_frames (), timeout = self .connection_confirmation_timeout
114
67
)
115
- connected_frame = connected_frame_or_collected_frames
68
+ except TimeoutError :
69
+ return ConnectionConfirmationTimeout (timeout = self .connection_confirmation_timeout , frames = collected_frames )
116
70
117
- if unsupported_protocol_version_error := check_stomp_protocol_version (
118
- connected_frame = connected_frame , supported_version = self . protocol_version
119
- ):
120
- return unsupported_protocol_version_error
71
+ if connected_frame . headers [ "version" ] != self . protocol_version :
72
+ return UnsupportedProtocolVersion (
73
+ given_version = connected_frame . headers [ "version" ], supported_version = self . protocol_version
74
+ )
121
75
76
+ server_heartbeat = Heartbeat .from_header (connected_frame .headers ["heart-beat" ])
122
77
self .set_heartbeat_interval (
123
- calculate_heartbeat_interval ( connected_frame = connected_frame , client_heartbeat = self .client_heartbeat )
78
+ max ( self .client_heartbeat . will_send_interval_ms , server_heartbeat . want_to_receive_interval_ms ) / 1000
124
79
)
125
80
return None
126
81
127
82
async def enter (self ) -> StompProtocolConnectionIssue | None :
128
- if protocol_connection_issue := await self ._establish_connection ():
129
- return protocol_connection_issue
130
-
83
+ if connection_issue := await self ._establish_connection ():
84
+ return connection_issue
131
85
await resubscribe_to_active_subscriptions (
132
86
connection = self .connection , active_subscriptions = self .active_subscriptions
133
87
)
134
88
await commit_pending_transactions (connection = self .connection , active_transactions = self .active_transactions )
135
89
return None
136
90
91
+ async def _take_receipt_frame (self ) -> None :
92
+ async for frame in self .connection .read_frames ():
93
+ if isinstance (frame , ReceiptFrame ):
94
+ break
95
+
137
96
async def exit (self ) -> None :
138
97
await unsubscribe_from_all_active_subscriptions (active_subscriptions = self .active_subscriptions )
139
- await self .connection .write_frame (DisconnectFrame (headers = {"receipt" : self ._generate_receipt_id ()}))
140
- await take_frame_of_type (
141
- frame_type = ReceiptFrame ,
142
- frames_iter = self .connection .read_frames (),
143
- timeout = self .disconnect_confirmation_timeout ,
144
- wait_for_or_none = wait_for_or_none ,
145
- )
98
+ await self .connection .write_frame (DisconnectFrame (headers = {"receipt" : _make_receipt_id ()}))
99
+
100
+ with suppress (TimeoutError ):
101
+ await asyncio .wait_for (self ._take_receipt_frame (), timeout = self .disconnect_confirmation_timeout )
102
+
103
+
104
+ def _make_receipt_id () -> str :
105
+ return str (uuid4 ())
146
106
147
107
148
108
class ConnectionLifespanFactory (Protocol ):
149
109
def __call__ (
150
110
self , * , connection : AbstractConnection , connection_parameters : ConnectionParameters
151
111
) -> AbstractConnectionLifespan : ...
152
-
153
-
154
- def _make_receipt_id () -> str :
155
- return str (uuid4 ())
0 commit comments