|
1 | 1 | import asyncio
|
2 | 2 | from collections.abc import AsyncGenerator, Callable, Coroutine
|
3 | 3 | from contextlib import suppress
|
| 4 | +from functools import partial |
4 | 5 | from typing import TYPE_CHECKING, Any, get_args
|
5 | 6 | from unittest import mock
|
6 | 7 |
|
|
12 | 13 | AbortFrame,
|
13 | 14 | AbstractConnection,
|
14 | 15 | AckFrame,
|
| 16 | + AckMode, |
15 | 17 | AnyClientFrame,
|
16 | 18 | AnyServerFrame,
|
17 | 19 | BeginFrame,
|
|
23 | 25 | ConnectionParameters,
|
24 | 26 | DisconnectFrame,
|
25 | 27 | ErrorFrame,
|
| 28 | + FailedAllConnectAttemptsError, |
26 | 29 | HeartbeatFrame,
|
27 | 30 | MessageFrame,
|
28 | 31 | NackFrame,
|
|
32 | 35 | UnsubscribeFrame,
|
33 | 36 | UnsupportedProtocolVersionError,
|
34 | 37 | )
|
35 |
| -from stompman.frames import AckMode |
36 | 38 | from tests.conftest import (
|
37 | 39 | BaseMockConnection,
|
38 | 40 | EnrichedClient,
|
@@ -62,6 +64,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
|
62 | 64 | for frame in next(read_frames_iterator):
|
63 | 65 | collected_frames.append(frame)
|
64 | 66 | yield frame
|
| 67 | + await asyncio.Future() |
65 | 68 |
|
66 | 69 | read_frames_iterator = iter(read_frames_yields)
|
67 | 70 | collected_frames: list[AnyClientFrame | AnyServerFrame] = []
|
@@ -443,6 +446,30 @@ async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, *, o
|
443 | 446 | )
|
444 | 447 |
|
445 | 448 |
|
| 449 | +async def test_client_listen_raises_on_aexit(monkeypatch: pytest.MonkeyPatch) -> None: |
| 450 | + monkeypatch.setattr("asyncio.sleep", partial(asyncio.sleep, 0)) |
| 451 | + |
| 452 | + connection_class, _ = create_spying_connection(*get_read_frames_with_lifespan([])) |
| 453 | + connection_class.connect = mock.AsyncMock(side_effect=[connection_class(), None, None, None]) # type: ignore[method-assign] |
| 454 | + |
| 455 | + async def close_connection_soon(client: stompman.Client) -> None: |
| 456 | + await asyncio.sleep(0) |
| 457 | + client._connection_manager._clear_active_connection_state() |
| 458 | + |
| 459 | + with pytest.raises(ExceptionGroup) as exc_info: # noqa: PT012 |
| 460 | + async with asyncio.TaskGroup() as task_group, EnrichedClient(connection_class=connection_class) as client: |
| 461 | + await client.subscribe(FAKER.pystr(), noop_message_handler, on_suppressed_exception=noop_error_handler) |
| 462 | + task_group.create_task(close_connection_soon(client)) |
| 463 | + |
| 464 | + assert len(exc_info.value.exceptions) == 1 |
| 465 | + inner_group = exc_info.value.exceptions[0] |
| 466 | + |
| 467 | + assert isinstance(inner_group, ExceptionGroup) |
| 468 | + assert len(inner_group.exceptions) == 1 |
| 469 | + |
| 470 | + assert isinstance(inner_group.exceptions[0], FailedAllConnectAttemptsError) |
| 471 | + |
| 472 | + |
446 | 473 | async def test_send_message_and_enter_transaction_ok(monkeypatch: pytest.MonkeyPatch) -> None:
|
447 | 474 | body, destination, expires, content_type = FAKER.binary(), FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
|
448 | 475 |
|
|
0 commit comments