1
1
import asyncio
2
2
from collections .abc import AsyncGenerator , AsyncIterator , Callable , Coroutine
3
- from contextlib import AsyncExitStack , asynccontextmanager
3
+ from contextlib import AsyncExitStack , asynccontextmanager , suppress
4
4
from dataclasses import dataclass , field
5
5
from types import TracebackType
6
6
from typing import ClassVar , Self
31
31
32
32
33
33
@asynccontextmanager
34
- async def connection_lifespan (
34
+ async def connection_lifespan ( # noqa: PLR0913
35
35
* ,
36
36
connection : AbstractConnection ,
37
37
connection_parameters : ConnectionParameters ,
38
38
protocol_version : str ,
39
39
client_heartbeat : Heartbeat ,
40
40
connection_confirmation_timeout : int ,
41
+ disconnect_confirmation_timeout : int ,
41
42
) -> AsyncIterator [float ]:
42
43
await connection .write_frame (
43
44
ConnectFrame (
@@ -52,7 +53,7 @@ async def connection_lifespan(
52
53
)
53
54
collected_frames = []
54
55
55
- async def take_connected_frame () -> ConnectedFrame :
56
+ async def take_connected_frame_and_collect_other_frames () -> ConnectedFrame :
56
57
async for frame in connection .read_frames ():
57
58
if isinstance (frame , ConnectedFrame ):
58
59
return frame
@@ -61,7 +62,9 @@ async def take_connected_frame() -> ConnectedFrame:
61
62
raise AssertionError (msg ) # pragma: no cover
62
63
63
64
try :
64
- connected_frame = await asyncio .wait_for (take_connected_frame (), timeout = connection_confirmation_timeout )
65
+ connected_frame = await asyncio .wait_for (
66
+ take_connected_frame_and_collect_other_frames (), timeout = connection_confirmation_timeout
67
+ )
65
68
except TimeoutError as exception :
66
69
raise ConnectionConfirmationTimeoutError (
67
70
timeout = connection_confirmation_timeout , frames = collected_frames
@@ -79,9 +82,14 @@ async def take_connected_frame() -> ConnectedFrame:
79
82
yield heartbeat_interval
80
83
81
84
await connection .write_frame (DisconnectFrame (headers = {"receipt" : _make_receipt_id ()}))
82
- async for frame in connection .read_frames ():
83
- if isinstance (frame , ReceiptFrame ):
84
- break
85
+
86
+ async def take_receipt_frame () -> None :
87
+ async for frame in connection .read_frames ():
88
+ if isinstance (frame , ReceiptFrame ):
89
+ break
90
+
91
+ with suppress (TimeoutError ):
92
+ await asyncio .wait_for (take_receipt_frame (), timeout = disconnect_confirmation_timeout )
85
93
86
94
87
95
def _make_receipt_id () -> str :
@@ -211,9 +219,11 @@ class Client:
211
219
connect_retry_attempts : int = 3
212
220
connect_retry_interval : int = 1
213
221
connect_timeout : int = 2
214
- connection_confirmation_timeout : int = 2
215
222
read_timeout : int = 2
216
223
read_max_chunk_size : int = 1024 * 1024
224
+ connection_confirmation_timeout : int = 2
225
+ disconnect_confirmation_timeout : int = 2
226
+
217
227
connection_class : type [AbstractConnection ] = Connection
218
228
219
229
_connection_manager : ConnectionManager = field (init = False )
@@ -263,6 +273,7 @@ async def _lifespan(
263
273
protocol_version = self .PROTOCOL_VERSION ,
264
274
client_heartbeat = self .heartbeat ,
265
275
connection_confirmation_timeout = self .connection_confirmation_timeout ,
276
+ disconnect_confirmation_timeout = self .disconnect_confirmation_timeout ,
266
277
) as heartbeat_interval :
267
278
self ._restart_heartbeat_task (heartbeat_interval )
268
279
async with subscriptions_lifespan (connection = connection , active_subscriptions = self ._active_subscriptions ):
0 commit comments