Skip to content

Commit 59e3843

Browse files
committed
Enhance connection handling by introducing EstablishedConnectionResult for better heartbeat management and connection lifecycle tracking
1 parent c1518e0 commit 59e3843

File tree

2 files changed

+39
-25
lines changed

2 files changed

+39
-25
lines changed

packages/stompman/stompman/connection_lifespan.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
from stompman.transaction import ActiveTransactions, commit_pending_transactions
2323

2424

25+
@dataclass(frozen=True, kw_only=True, slots=True)
26+
class EstablishedConnectionResult:
27+
server_heartbeat: Heartbeat
28+
29+
2530
class AbstractConnectionLifespan(Protocol):
26-
async def enter(self) -> StompProtocolConnectionIssue | None: ...
31+
async def enter(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue: ...
2732
async def exit(self) -> None: ...
2833

2934

@@ -39,7 +44,7 @@ class ConnectionLifespan(AbstractConnectionLifespan):
3944
active_transactions: ActiveTransactions
4045
set_heartbeat_interval: Callable[[Heartbeat], Any]
4146

42-
async def _establish_connection(self) -> StompProtocolConnectionIssue | None:
47+
async def _establish_connection(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue:
4348
await self.connection.write_frame(
4449
ConnectFrame(
4550
headers={
@@ -73,17 +78,18 @@ async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
7378
given_version=connected_frame.headers["version"], supported_version=self.protocol_version
7479
)
7580

76-
self.set_heartbeat_interval(Heartbeat.from_header(connected_frame.headers["heart-beat"]))
77-
return None
81+
server_heartbeat = Heartbeat.from_header(connected_frame.headers["heart-beat"])
82+
self.set_heartbeat_interval(server_heartbeat)
83+
return EstablishedConnectionResult(server_heartbeat=server_heartbeat)
7884

79-
async def enter(self) -> StompProtocolConnectionIssue | None:
80-
if connection_issue := await self._establish_connection():
81-
return connection_issue
82-
await resubscribe_to_active_subscriptions(
83-
connection=self.connection, active_subscriptions=self.active_subscriptions
84-
)
85-
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
86-
return None
85+
async def enter(self) -> EstablishedConnectionResult | StompProtocolConnectionIssue:
86+
connection_result = await self._establish_connection()
87+
if isinstance(connection_result, EstablishedConnectionResult):
88+
await resubscribe_to_active_subscriptions(
89+
connection=self.connection, active_subscriptions=self.active_subscriptions
90+
)
91+
await commit_pending_transactions(connection=self.connection, active_transactions=self.active_transactions)
92+
return connection_result
8793

8894
async def _take_receipt_frame(self) -> None:
8995
async for frame in self.connection.read_frames():

packages/stompman/stompman/connection_manager.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from types import TracebackType
77
from typing import TYPE_CHECKING, Literal, Self
88

9-
from stompman.config import ConnectionParameters
9+
from stompman.config import ConnectionParameters, Heartbeat
1010
from stompman.connection import AbstractConnection
11+
from stompman.connection_lifespan import EstablishedConnectionResult
1112
from stompman.errors import (
1213
AllServersUnavailable,
1314
AnyConnectionIssue,
@@ -26,6 +27,7 @@
2627
class ActiveConnectionState:
2728
connection: AbstractConnection
2829
lifespan: "AbstractConnectionLifespan"
30+
server_heartbeat: Heartbeat
2931

3032

3133
@dataclass(kw_only=True, slots=True)
@@ -60,7 +62,9 @@ async def __aexit__(
6062
return
6163
await self._active_connection_state.connection.close()
6264

63-
async def _create_connection_to_one_server(self, server: ConnectionParameters) -> ActiveConnectionState | None:
65+
async def _create_connection_to_one_server(
66+
self, server: ConnectionParameters
67+
) -> tuple[AbstractConnection, ConnectionParameters] | None:
6468
if connection := await self.connection_class.connect(
6569
host=server.host,
6670
port=server.port,
@@ -69,31 +73,35 @@ async def _create_connection_to_one_server(self, server: ConnectionParameters) -
6973
read_timeout=self.read_timeout,
7074
ssl=self.ssl,
7175
):
72-
return ActiveConnectionState(
73-
connection=connection,
74-
lifespan=self.lifespan_factory(connection=connection, connection_parameters=server),
75-
)
76+
return (connection, server)
7677
return None
7778

78-
async def _create_connection_to_any_server(self) -> ActiveConnectionState | None:
79+
async def _create_connection_to_any_server(self) -> tuple[AbstractConnection, ConnectionParameters] | None:
7980
for maybe_connection_future in asyncio.as_completed(
8081
[self._create_connection_to_one_server(server) for server in self.servers]
8182
):
82-
if connection_state := await maybe_connection_future:
83-
return connection_state
83+
if connection_and_server := await maybe_connection_future:
84+
return connection_and_server
8485
return None
8586

8687
async def _connect_to_any_server(self) -> ActiveConnectionState | AnyConnectionIssue:
87-
if not (active_connection_state := await self._create_connection_to_any_server()):
88+
if not (connection_and_server := await self._create_connection_to_any_server()):
8889
return AllServersUnavailable(servers=self.servers, timeout=self.connect_timeout)
90+
connection, connection_parameters = connection_and_server
91+
lifespan = self.lifespan_factory(connection=connection, connection_parameters=connection_parameters)
8992

9093
try:
91-
if connection_issue := await active_connection_state.lifespan.enter():
92-
return connection_issue
94+
connection_result = await lifespan.enter()
9395
except ConnectionLostError:
9496
return ConnectionLost()
9597

96-
return active_connection_state
98+
return (
99+
ActiveConnectionState(
100+
connection=connection, lifespan=lifespan, server_heartbeat=connection_result.server_heartbeat
101+
)
102+
if isinstance(connection_result, EstablishedConnectionResult)
103+
else connection_result
104+
)
97105

98106
async def _get_active_connection_state(self) -> ActiveConnectionState:
99107
if self._active_connection_state:

0 commit comments

Comments
 (0)