1
1
import asyncio
2
- from collections .abc import AsyncIterable , Callable
3
- from contextlib import suppress
4
- from dataclasses import dataclass
5
- from typing import Protocol
2
+ from collections .abc import AsyncIterable , Awaitable , Callable
3
+ from dataclasses import dataclass , field
4
+ from typing import Any , Protocol , TypeVar
6
5
from uuid import uuid4
7
6
8
7
from stompman .config import ConnectionParameters , Heartbeat
22
21
)
23
22
from stompman .transaction import ActiveTransactions , commit_pending_transactions
24
23
24
+ FrameType = TypeVar ("FrameType" , bound = AnyServerFrame )
25
+ WaitForFutureReturnType = TypeVar ("WaitForFutureReturnType" )
25
26
26
- async def take_connected_frame (
27
- * , frames_iter : AsyncIterable [AnyServerFrame ], connection_confirmation_timeout : int
28
- ) -> ConnectedFrame | ConnectionConfirmationTimeout :
27
+
28
+ async def wait_for_or_none (
29
+ awaitable : Awaitable [WaitForFutureReturnType ], timeout : float
30
+ ) -> WaitForFutureReturnType | None :
31
+ try :
32
+ return await asyncio .wait_for (awaitable , timeout = timeout )
33
+ except TimeoutError :
34
+ return None
35
+
36
+
37
+ WaitForOrNone = Callable [[Awaitable [WaitForFutureReturnType ], float ], Awaitable [WaitForFutureReturnType | None ]]
38
+
39
+
40
+ async def take_frame_of_type (
41
+ * ,
42
+ frame_type : type [FrameType ],
43
+ frames_iter : AsyncIterable [AnyServerFrame ],
44
+ timeout : int ,
45
+ wait_for_or_none : WaitForOrNone [FrameType ],
46
+ ) -> FrameType | list [Any ]:
29
47
collected_frames = []
30
48
31
- async def take_connected_frame_and_collect_other_frames () -> ConnectedFrame :
49
+ async def inner () -> FrameType :
32
50
async for frame in frames_iter :
33
- if isinstance (frame , ConnectedFrame ):
51
+ if isinstance (frame , frame_type ):
34
52
return frame
35
53
collected_frames .append (frame )
36
54
msg = "unreachable"
37
55
raise AssertionError (msg )
38
56
39
- try :
40
- return await asyncio .wait_for (
41
- take_connected_frame_and_collect_other_frames (), timeout = connection_confirmation_timeout
42
- )
43
- except TimeoutError :
44
- return ConnectionConfirmationTimeout (timeout = connection_confirmation_timeout , frames = collected_frames )
57
+ return await wait_for_or_none (inner (), timeout ) or collected_frames
45
58
46
59
47
60
def check_stomp_protocol_version (
@@ -59,18 +72,6 @@ def calculate_heartbeat_interval(*, connected_frame: ConnectedFrame, client_hear
59
72
return max (client_heartbeat .will_send_interval_ms , server_heartbeat .want_to_receive_interval_ms ) / 1000
60
73
61
74
62
- async def wait_for_receipt_frame (
63
- * , frames_iter : AsyncIterable [AnyServerFrame ], disconnect_confirmation_timeout : int
64
- ) -> None :
65
- async def inner () -> None :
66
- async for frame in frames_iter :
67
- if isinstance (frame , ReceiptFrame ):
68
- break
69
-
70
- with suppress (TimeoutError ):
71
- await asyncio .wait_for (inner (), timeout = disconnect_confirmation_timeout )
72
-
73
-
74
75
class AbstractConnectionLifespan (Protocol ):
75
76
async def enter (self ) -> StompProtocolConnectionIssue | None : ...
76
77
async def exit (self ) -> None : ...
@@ -87,6 +88,7 @@ class ConnectionLifespan(AbstractConnectionLifespan):
87
88
active_subscriptions : ActiveSubscriptions
88
89
active_transactions : ActiveTransactions
89
90
set_heartbeat_interval : Callable [[float ], None ]
91
+ _generate_receipt_id : Callable [[], str ] = field (default = lambda : _make_receipt_id ()) # noqa: PLW0108
90
92
91
93
async def _establish_connection (self ) -> StompProtocolConnectionIssue | None :
92
94
await self .connection .write_frame (
@@ -100,13 +102,17 @@ async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
100
102
},
101
103
)
102
104
)
103
- connected_frame_or_error = await take_connected_frame (
105
+ connected_frame_or_collected_frames = await take_frame_of_type (
106
+ frame_type = ConnectedFrame ,
104
107
frames_iter = self .connection .read_frames (),
105
- connection_confirmation_timeout = self .connection_confirmation_timeout ,
108
+ timeout = self .connection_confirmation_timeout ,
109
+ wait_for_or_none = wait_for_or_none ,
106
110
)
107
- if isinstance (connected_frame_or_error , ConnectionConfirmationTimeout ):
108
- return connected_frame_or_error
109
- connected_frame = connected_frame_or_error
111
+ if not isinstance (connected_frame_or_collected_frames , ConnectedFrame ):
112
+ return ConnectionConfirmationTimeout (
113
+ timeout = self .connection_confirmation_timeout , frames = connected_frame_or_collected_frames
114
+ )
115
+ connected_frame = connected_frame_or_collected_frames
110
116
111
117
if unsupported_protocol_version_error := check_stomp_protocol_version (
112
118
connected_frame = connected_frame , supported_version = self .protocol_version
@@ -130,10 +136,12 @@ async def enter(self) -> StompProtocolConnectionIssue | None:
130
136
131
137
async def exit (self ) -> None :
132
138
await unsubscribe_from_all_active_subscriptions (active_subscriptions = self .active_subscriptions )
133
- await self .connection .write_frame (DisconnectFrame (headers = {"receipt" : _make_receipt_id ()}))
134
- await wait_for_receipt_frame (
139
+ await self .connection .write_frame (DisconnectFrame (headers = {"receipt" : self ._generate_receipt_id ()}))
140
+ await take_frame_of_type (
141
+ frame_type = ReceiptFrame ,
135
142
frames_iter = self .connection .read_frames (),
136
- disconnect_confirmation_timeout = self .disconnect_confirmation_timeout ,
143
+ timeout = self .disconnect_confirmation_timeout ,
144
+ wait_for_or_none = wait_for_or_none ,
137
145
)
138
146
139
147
0 commit comments