@@ -49,15 +49,26 @@ class ConnectionManager:
49
49
check_server_alive_interval_factor : int
50
50
51
51
_active_connection_state : ActiveConnectionState | None = field (default = None , init = False )
52
- _reconnect_lock : asyncio .Lock = field (default_factory = asyncio .Lock )
52
+ _reconnect_lock : asyncio .Lock = field (init = False , default_factory = asyncio .Lock )
53
+ _task_group : asyncio .TaskGroup = field (init = False , default_factory = asyncio .TaskGroup )
54
+ _send_heartbeat_task : asyncio .Task [None ] = field (init = False , repr = False )
55
+ _check_server_heartbeat_task : asyncio .Task [None ] = field (init = False , repr = False )
53
56
54
57
async def __aenter__ (self ) -> Self :
58
+ await self ._task_group .__aenter__ ()
59
+ self ._send_heartbeat_task = self ._task_group .create_task (asyncio .sleep (0 ))
60
+ self ._check_server_heartbeat_task = self ._task_group .create_task (asyncio .sleep (0 ))
55
61
self ._active_connection_state = await self ._get_active_connection_state ()
56
62
return self
57
63
58
64
async def __aexit__ (
59
65
self , exc_type : type [BaseException ] | None , exc_value : BaseException | None , traceback : TracebackType | None
60
66
) -> None :
67
+ self ._send_heartbeat_task .cancel ()
68
+ self ._check_server_heartbeat_task .cancel ()
69
+ await asyncio .wait ([self ._send_heartbeat_task , self ._check_server_heartbeat_task ])
70
+ await self ._task_group .__aexit__ (exc_type , exc_value , traceback )
71
+
61
72
if not self ._active_connection_state :
62
73
return
63
74
try :
@@ -66,6 +77,31 @@ async def __aexit__(
66
77
return
67
78
await self ._active_connection_state .connection .close ()
68
79
80
+ def _restart_heartbeat_tasks (self , server_heartbeat : Heartbeat ) -> None :
81
+ self ._send_heartbeat_task .cancel ()
82
+ self ._check_server_heartbeat_task .cancel ()
83
+ self ._send_heartbeat_task = self ._task_group .create_task (
84
+ self ._send_heartbeats_forever (server_heartbeat .want_to_receive_interval_ms )
85
+ )
86
+ self ._check_server_heartbeat_task = self ._task_group .create_task (
87
+ self ._check_server_heartbeat_forever (server_heartbeat .will_send_interval_ms )
88
+ )
89
+
90
+ async def _send_heartbeats_forever (self , send_heartbeat_interval_ms : int ) -> None :
91
+ send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
92
+ while True :
93
+ await self .write_heartbeat_reconnecting ()
94
+ await asyncio .sleep (send_heartbeat_interval_seconds )
95
+
96
+ async def _check_server_heartbeat_forever (self , receive_heartbeat_interval_ms : int ) -> None :
97
+ receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
98
+ while True :
99
+ await asyncio .sleep (receive_heartbeat_interval_seconds * self .check_server_alive_interval_factor )
100
+ if not self ._active_connection_state :
101
+ continue
102
+ if not self ._active_connection_state .is_alive ():
103
+ self ._active_connection_state = None
104
+
69
105
async def _create_connection_to_one_server (
70
106
self , server : ConnectionParameters
71
107
) -> tuple [AbstractConnection , ConnectionParameters ] | None :
@@ -94,7 +130,11 @@ async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionI
94
130
if not (connection_and_server := await self ._create_connection_to_any_server ()):
95
131
return AllServersUnavailable (servers = self .servers , timeout = self .connect_timeout )
96
132
connection , connection_parameters = connection_and_server
97
- lifespan = self .lifespan_factory (connection = connection , connection_parameters = connection_parameters )
133
+ lifespan = self .lifespan_factory (
134
+ connection = connection ,
135
+ connection_parameters = connection_parameters ,
136
+ set_heartbeat_interval = self ._restart_heartbeat_tasks ,
137
+ )
98
138
99
139
try :
100
140
connection_result = await lifespan .enter ()
0 commit comments