1
1
import asyncio
2
+ import time
2
3
from collections .abc import AsyncGenerator
3
4
from dataclasses import dataclass , field
4
5
from ssl import SSLContext
5
6
from types import TracebackType
6
7
from typing import TYPE_CHECKING , Literal , Self
7
8
8
- from stompman .config import ConnectionParameters
9
+ from stompman .config import ConnectionParameters , Heartbeat
9
10
from stompman .connection import AbstractConnection
10
11
from stompman .errors import (
11
12
AllServersUnavailable ,
25
26
class ActiveConnectionState :
26
27
connection : AbstractConnection
27
28
lifespan : "AbstractConnectionLifespan"
29
+ server_heartbeat : Heartbeat
30
+
31
+ def is_alive (self ) -> bool :
32
+ if not (last_read_time := self .connection .last_read_time ):
33
+ return True
34
+ return (self .server_heartbeat .will_send_interval_ms / 1000 ) > (time .time () - last_read_time )
28
35
29
36
30
37
@dataclass (kw_only = True , slots = True )
@@ -39,17 +46,29 @@ class ConnectionManager:
39
46
read_timeout : int
40
47
read_max_chunk_size : int
41
48
write_retry_attempts : int
49
+ check_server_alive_interval_factor : int
42
50
43
51
_active_connection_state : ActiveConnectionState | None = field (default = None , init = False )
44
- _reconnect_lock : asyncio .Lock = field (default_factory = asyncio .Lock )
52
+ _reconnect_lock : asyncio .Lock = field (init = False , default_factory = asyncio .Lock )
53
+ _task_group : asyncio .TaskGroup = field (init = False , default_factory = asyncio .TaskGroup )
54
+ _send_heartbeat_task : asyncio .Task [None ] = field (init = False , repr = False )
55
+ _check_server_heartbeat_task : asyncio .Task [None ] = field (init = False , repr = False )
45
56
46
57
async def __aenter__ (self ) -> Self :
58
+ await self ._task_group .__aenter__ ()
59
+ self ._send_heartbeat_task = self ._task_group .create_task (asyncio .sleep (0 ))
60
+ self ._check_server_heartbeat_task = self ._task_group .create_task (asyncio .sleep (0 ))
47
61
self ._active_connection_state = await self ._get_active_connection_state ()
48
62
return self
49
63
50
64
async def __aexit__ (
51
65
self , exc_type : type [BaseException ] | None , exc_value : BaseException | None , traceback : TracebackType | None
52
66
) -> None :
67
+ self ._send_heartbeat_task .cancel ()
68
+ self ._check_server_heartbeat_task .cancel ()
69
+ await asyncio .wait ([self ._send_heartbeat_task , self ._check_server_heartbeat_task ])
70
+ await self ._task_group .__aexit__ (exc_type , exc_value , traceback )
71
+
53
72
if not self ._active_connection_state :
54
73
return
55
74
try :
@@ -58,7 +77,34 @@ async def __aexit__(
58
77
return
59
78
await self ._active_connection_state .connection .close ()
60
79
61
- async def _create_connection_to_one_server (self , server : ConnectionParameters ) -> ActiveConnectionState | None :
80
+ def _restart_heartbeat_tasks (self , server_heartbeat : Heartbeat ) -> None :
81
+ self ._send_heartbeat_task .cancel ()
82
+ self ._check_server_heartbeat_task .cancel ()
83
+ self ._send_heartbeat_task = self ._task_group .create_task (
84
+ self ._send_heartbeats_forever (server_heartbeat .want_to_receive_interval_ms )
85
+ )
86
+ self ._check_server_heartbeat_task = self ._task_group .create_task (
87
+ self ._check_server_heartbeat_forever (server_heartbeat .will_send_interval_ms )
88
+ )
89
+
90
+ async def _send_heartbeats_forever (self , send_heartbeat_interval_ms : int ) -> None :
91
+ send_heartbeat_interval_seconds = send_heartbeat_interval_ms / 1000
92
+ while True :
93
+ await self .write_heartbeat_reconnecting ()
94
+ await asyncio .sleep (send_heartbeat_interval_seconds )
95
+
96
+ async def _check_server_heartbeat_forever (self , receive_heartbeat_interval_ms : int ) -> None :
97
+ receive_heartbeat_interval_seconds = receive_heartbeat_interval_ms / 1000
98
+ while True :
99
+ await asyncio .sleep (receive_heartbeat_interval_seconds * self .check_server_alive_interval_factor )
100
+ if not self ._active_connection_state :
101
+ continue
102
+ if not self ._active_connection_state .is_alive ():
103
+ self ._active_connection_state = None
104
+
105
+ async def _create_connection_to_one_server (
106
+ self , server : ConnectionParameters
107
+ ) -> tuple [AbstractConnection , ConnectionParameters ] | None :
62
108
if connection := await self .connection_class .connect (
63
109
host = server .host ,
64
110
port = server .port ,
@@ -67,31 +113,41 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
67
113
read_timeout = self .read_timeout ,
68
114
ssl = self .ssl ,
69
115
):
70
- return ActiveConnectionState (
71
- connection = connection ,
72
- lifespan = self .lifespan_factory (connection = connection , connection_parameters = server ),
73
- )
116
+ return (connection , server )
74
117
return None
75
118
76
- async def _create_connection_to_any_server (self ) -> ActiveConnectionState | None :
119
+ async def _create_connection_to_any_server (self ) -> tuple [ AbstractConnection , ConnectionParameters ] | None :
77
120
for maybe_connection_future in asyncio .as_completed (
78
121
[self ._create_connection_to_one_server (server ) for server in self .servers ]
79
122
):
80
- if connection_state := await maybe_connection_future :
81
- return connection_state
123
+ if connection_and_server := await maybe_connection_future :
124
+ return connection_and_server
82
125
return None
83
126
84
127
async def _connect_to_any_server (self ) -> ActiveConnectionState | AnyConnectionIssue :
85
- if not (active_connection_state := await self ._create_connection_to_any_server ()):
128
+ from stompman .connection_lifespan import EstablishedConnectionResult # noqa: PLC0415
129
+
130
+ if not (connection_and_server := await self ._create_connection_to_any_server ()):
86
131
return AllServersUnavailable (servers = self .servers , timeout = self .connect_timeout )
132
+ connection , connection_parameters = connection_and_server
133
+ lifespan = self .lifespan_factory (
134
+ connection = connection ,
135
+ connection_parameters = connection_parameters ,
136
+ set_heartbeat_interval = self ._restart_heartbeat_tasks ,
137
+ )
87
138
88
139
try :
89
- if connection_issue := await active_connection_state .lifespan .enter ():
90
- return connection_issue
140
+ connection_result = await lifespan .enter ()
91
141
except ConnectionLostError :
92
142
return ConnectionLost ()
93
143
94
- return active_connection_state
144
+ return (
145
+ ActiveConnectionState (
146
+ connection = connection , lifespan = lifespan , server_heartbeat = connection_result .server_heartbeat
147
+ )
148
+ if isinstance (connection_result , EstablishedConnectionResult )
149
+ else connection_result
150
+ )
95
151
96
152
async def _get_active_connection_state (self ) -> ActiveConnectionState :
97
153
if self ._active_connection_state :
0 commit comments