Skip to content

Commit 191e155

Browse files
authored
Fix hanging forever in Client.__aexit__() (#50)
1 parent 4a57e0c commit 191e155

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

Justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ test-integration *args:
2020
docker compose run --build --rm app .venv/bin/pytest tests/integration.py --no-cov {{args}}
2121

2222
run-artemis:
23-
docker compose up
23+
docker compose run --service-ports artemis
2424

2525
run-consumer:
2626
ARTEMIS_HOST=0.0.0.0 uv run -q --frozen python testing/consumer.py

stompman/client.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ class Client:
231231
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
232232
_heartbeat_task: asyncio.Task[None] = field(init=False)
233233
_listen_task: asyncio.Task[None] = field(init=False)
234+
_task_group: asyncio.TaskGroup = field(init=False)
234235

235236
def __post_init__(self) -> None:
236237
self._connection_manager = ConnectionManager(
@@ -245,9 +246,10 @@ def __post_init__(self) -> None:
245246
)
246247

247248
async def __aenter__(self) -> Self:
248-
self._heartbeat_task = asyncio.create_task(asyncio.sleep(0))
249+
self._task_group = await self._exit_stack.enter_async_context(asyncio.TaskGroup())
250+
self._heartbeat_task = self._task_group.create_task(asyncio.sleep(0))
249251
await self._exit_stack.enter_async_context(self._connection_manager)
250-
self._listen_task = asyncio.create_task(self._listen_to_frames())
252+
self._listen_task = self._task_group.create_task(self._listen_to_frames())
251253
return self
252254

253255
async def __aexit__(
@@ -281,7 +283,7 @@ async def _lifespan(
281283

282284
def _restart_heartbeat_task(self, interval: float) -> None:
283285
self._heartbeat_task.cancel()
284-
self._heartbeat_task = asyncio.create_task(self._send_heartbeats_forever(interval))
286+
self._heartbeat_task = self._task_group.create_task(self._send_heartbeats_forever(interval))
285287

286288
async def _send_heartbeats_forever(self, interval: float) -> None:
287289
while True:

tests/test_client.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, Callable, Coroutine
33
from contextlib import suppress
4+
from functools import partial
45
from typing import TYPE_CHECKING, Any, get_args
56
from unittest import mock
67

@@ -12,6 +13,7 @@
1213
AbortFrame,
1314
AbstractConnection,
1415
AckFrame,
16+
AckMode,
1517
AnyClientFrame,
1618
AnyServerFrame,
1719
BeginFrame,
@@ -23,6 +25,7 @@
2325
ConnectionParameters,
2426
DisconnectFrame,
2527
ErrorFrame,
28+
FailedAllConnectAttemptsError,
2629
HeartbeatFrame,
2730
MessageFrame,
2831
NackFrame,
@@ -32,7 +35,6 @@
3235
UnsubscribeFrame,
3336
UnsupportedProtocolVersionError,
3437
)
35-
from stompman.frames import AckMode
3638
from tests.conftest import (
3739
BaseMockConnection,
3840
EnrichedClient,
@@ -62,6 +64,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
6264
for frame in next(read_frames_iterator):
6365
collected_frames.append(frame)
6466
yield frame
67+
await asyncio.Future()
6568

6669
read_frames_iterator = iter(read_frames_yields)
6770
collected_frames: list[AnyClientFrame | AnyServerFrame] = []
@@ -443,6 +446,30 @@ async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, *, o
443446
)
444447

445448

449+
async def test_client_listen_raises_on_aexit(monkeypatch: pytest.MonkeyPatch) -> None:
450+
monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0))
451+
452+
connection_class, _ = create_spying_connection(*get_read_frames_with_lifespan([]))
453+
connection_class.connect = mock.AsyncMock(side_effect=[connection_class(), None, None, None]) # type: ignore[method-assign]
454+
455+
async def close_connection_soon(client: stompman.Client) -> None:
456+
await asyncio.sleep(0)
457+
client._connection_manager._clear_active_connection_state()
458+
459+
with pytest.raises(ExceptionGroup) as exc_info: # noqa: PT012
460+
async with asyncio.TaskGroup() as task_group, EnrichedClient(connection_class=connection_class) as client:
461+
await client.subscribe(FAKER.pystr(), noop_message_handler, on_suppressed_exception=noop_error_handler)
462+
task_group.create_task(close_connection_soon(client))
463+
464+
assert len(exc_info.value.exceptions) == 1
465+
inner_group = exc_info.value.exceptions[0]
466+
467+
assert isinstance(inner_group, ExceptionGroup)
468+
assert len(inner_group.exceptions) == 1
469+
470+
assert isinstance(inner_group.exceptions[0], FailedAllConnectAttemptsError)
471+
472+
446473
async def test_send_message_and_enter_transaction_ok(monkeypatch: pytest.MonkeyPatch) -> None:
447474
body, destination, expires, content_type = FAKER.binary(), FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
448475

0 commit comments

Comments
 (0)