Skip to content

Commit 3db311a

Browse files
authored
Add disconnect confirmation timeout (#46)
1 parent c8d4286 commit 3db311a

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

stompman/client.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Coroutine
3-
from contextlib import AsyncExitStack, asynccontextmanager
3+
from contextlib import AsyncExitStack, asynccontextmanager, suppress
44
from dataclasses import dataclass, field
55
from types import TracebackType
66
from typing import ClassVar, Self
@@ -31,13 +31,14 @@
3131

3232

3333
@asynccontextmanager
34-
async def connection_lifespan(
34+
async def connection_lifespan( # noqa: PLR0913
3535
*,
3636
connection: AbstractConnection,
3737
connection_parameters: ConnectionParameters,
3838
protocol_version: str,
3939
client_heartbeat: Heartbeat,
4040
connection_confirmation_timeout: int,
41+
disconnect_confirmation_timeout: int,
4142
) -> AsyncIterator[float]:
4243
await connection.write_frame(
4344
ConnectFrame(
@@ -52,7 +53,7 @@ async def connection_lifespan(
5253
)
5354
collected_frames = []
5455

55-
async def take_connected_frame() -> ConnectedFrame:
56+
async def take_connected_frame_and_collect_other_frames() -> ConnectedFrame:
5657
async for frame in connection.read_frames():
5758
if isinstance(frame, ConnectedFrame):
5859
return frame
@@ -61,7 +62,9 @@ async def take_connected_frame() -> ConnectedFrame:
6162
raise AssertionError(msg) # pragma: no cover
6263

6364
try:
64-
connected_frame = await asyncio.wait_for(take_connected_frame(), timeout=connection_confirmation_timeout)
65+
connected_frame = await asyncio.wait_for(
66+
take_connected_frame_and_collect_other_frames(), timeout=connection_confirmation_timeout
67+
)
6568
except TimeoutError as exception:
6669
raise ConnectionConfirmationTimeoutError(
6770
timeout=connection_confirmation_timeout, frames=collected_frames
@@ -79,9 +82,14 @@ async def take_connected_frame() -> ConnectedFrame:
7982
yield heartbeat_interval
8083

8184
await connection.write_frame(DisconnectFrame(headers={"receipt": _make_receipt_id()}))
82-
async for frame in connection.read_frames():
83-
if isinstance(frame, ReceiptFrame):
84-
break
85+
86+
async def take_receipt_frame() -> None:
87+
async for frame in connection.read_frames():
88+
if isinstance(frame, ReceiptFrame):
89+
break
90+
91+
with suppress(TimeoutError):
92+
await asyncio.wait_for(take_receipt_frame(), timeout=disconnect_confirmation_timeout)
8593

8694

8795
def _make_receipt_id() -> str:
@@ -211,9 +219,11 @@ class Client:
211219
connect_retry_attempts: int = 3
212220
connect_retry_interval: int = 1
213221
connect_timeout: int = 2
214-
connection_confirmation_timeout: int = 2
215222
read_timeout: int = 2
216223
read_max_chunk_size: int = 1024 * 1024
224+
connection_confirmation_timeout: int = 2
225+
disconnect_confirmation_timeout: int = 2
226+
217227
connection_class: type[AbstractConnection] = Connection
218228

219229
_connection_manager: ConnectionManager = field(init=False)
@@ -263,6 +273,7 @@ async def _lifespan(
263273
protocol_version=self.PROTOCOL_VERSION,
264274
client_heartbeat=self.heartbeat,
265275
connection_confirmation_timeout=self.connection_confirmation_timeout,
276+
disconnect_confirmation_timeout=self.disconnect_confirmation_timeout,
266277
) as heartbeat_interval:
267278
self._restart_heartbeat_task(heartbeat_interval)
268279
async with subscriptions_lifespan(connection=connection, active_subscriptions=self._active_subscriptions):

tests/test_client.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,6 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
6868
return BaseCollectingConnection, collected_frames
6969

7070

71-
def get_read_frames_with_lifespan(*read_frames: list[AnyServerFrame]) -> list[list[AnyServerFrame]]:
72-
return [
73-
[ConnectedFrame(headers={"version": Client.PROTOCOL_VERSION, "heart-beat": "1,1"})],
74-
*read_frames,
75-
[ReceiptFrame(headers={"receipt-id": "receipt-id-1"})],
76-
]
77-
78-
7971
CONNECT_FRAME = ConnectFrame(
8072
headers={
8173
"accept-version": Client.PROTOCOL_VERSION,
@@ -88,6 +80,14 @@ def get_read_frames_with_lifespan(*read_frames: list[AnyServerFrame]) -> list[li
8880
CONNECTED_FRAME = ConnectedFrame(headers={"version": Client.PROTOCOL_VERSION, "heart-beat": "1,1"})
8981

9082

83+
def get_read_frames_with_lifespan(*read_frames: list[AnyServerFrame]) -> list[list[AnyServerFrame]]:
84+
return [
85+
[CONNECTED_FRAME],
86+
*read_frames,
87+
[ReceiptFrame(headers={"receipt-id": "receipt-id-1"})],
88+
]
89+
90+
9191
def enrich_expected_frames(*expected_frames: AnyClientFrame | AnyServerFrame) -> list[AnyClientFrame | AnyServerFrame]:
9292
return [
9393
CONNECT_FRAME,
@@ -172,6 +172,30 @@ async def test_client_connection_lifespan_unsupported_protocol_version() -> None
172172
)
173173

174174

175+
async def test_client_connection_lifespan_disconnect_not_confirmed(monkeypatch: pytest.MonkeyPatch) -> None:
176+
wait_for_calls = []
177+
178+
async def mock_wait_for(future: Coroutine[Any, Any, Any], timeout: float) -> object:
179+
wait_for_calls.append(timeout)
180+
task = asyncio.create_task(future)
181+
await asyncio.sleep(0)
182+
return await original_wait_for(task, 0)
183+
184+
original_wait_for = asyncio.wait_for
185+
monkeypatch.setattr("asyncio.wait_for", mock_wait_for)
186+
disconnect_confirmation_timeout = FAKER.pyint()
187+
read_frames_yields = get_read_frames_with_lifespan([])
188+
read_frames_yields[-1].clear()
189+
connection_class, _ = create_spying_connection(*read_frames_yields)
190+
191+
async with EnrichedClient(
192+
connection_class=connection_class, disconnect_confirmation_timeout=disconnect_confirmation_timeout
193+
):
194+
pass
195+
196+
assert wait_for_calls[-1] == disconnect_confirmation_timeout
197+
198+
175199
async def test_client_heartbeats_ok(monkeypatch: pytest.MonkeyPatch) -> None:
176200
async def mock_sleep(delay: float) -> None:
177201
await real_sleep(0)

0 commit comments

Comments
 (0)