Skip to content

Commit 4120df9

Browse files
authored
Suppress ConnectionLostError in shutdown to avoid dirty tracebacks that sometime include ExceptionGroups (#29)
1 parent 42839fc commit 4120df9

File tree

8 files changed

+72
-57
lines changed

8 files changed

+72
-57
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,16 @@ ignore = [
6666
"S101",
6767
"SLF001",
6868
"CPY001",
69+
"S104",
6970
]
71+
extend-per-file-ignores = { "tests/*" = [
72+
"ARG002",
73+
"ARG003",
74+
"PLR6301",
75+
"ANN401",
76+
"PLC2801",
77+
] }
78+
7079

7180
[tool.pytest.ini_options]
7281
addopts = "--cov -s"

stompman/client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
3-
from contextlib import AsyncExitStack, asynccontextmanager
3+
from contextlib import AsyncExitStack, asynccontextmanager, suppress
44
from dataclasses import dataclass, field
55
from types import TracebackType
66
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
@@ -10,6 +10,7 @@
1010
from stompman.connection import AbstractConnection, Connection
1111
from stompman.errors import (
1212
ConnectionConfirmationTimeoutError,
13+
ConnectionLostError,
1314
FailedAllConnectAttemptsError,
1415
UnsupportedProtocolVersionError,
1516
)
@@ -199,7 +200,10 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
199200

200201
async def send_heartbeats_forever() -> None:
201202
while True:
202-
self._connection.write_heartbeat()
203+
try:
204+
self._connection.write_heartbeat()
205+
except ConnectionLostError:
206+
return
203207
await asyncio.sleep(heartbeat_interval)
204208

205209
async with asyncio.TaskGroup() as task_group:
@@ -209,10 +213,11 @@ async def send_heartbeats_forever() -> None:
209213
finally:
210214
task.cancel()
211215

212-
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
213-
await self._connection.read_frame_of_type(
214-
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
215-
)
216+
with suppress(ConnectionLostError):
217+
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
218+
await self._connection.read_frame_of_type(
219+
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
220+
)
216221

217222
@asynccontextmanager
218223
async def enter_transaction(self) -> AsyncGenerator[str, None]:

stompman/connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import socket
33
from collections.abc import AsyncGenerator, Generator, Iterator
4-
from contextlib import contextmanager
4+
from contextlib import contextmanager, suppress
55
from dataclasses import dataclass
66
from typing import Protocol, Self, TypeVar, cast
77

@@ -52,7 +52,7 @@ async def connect(cls, host: str, port: int, timeout: int) -> Self | None:
5252

5353
async def close(self) -> None:
5454
self.writer.close()
55-
with _reraise_connection_lost(ConnectionError):
55+
with suppress(ConnectionError):
5656
await self.writer.wait_closed()
5757

5858
def write_heartbeat(self) -> None:

testing/consumer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
async def main() -> None:
77
async with (
8-
stompman.Client(servers=[stompman.ConnectionParameters("0.0.0.0", 61616, "admin", "admin")]) as client, # noqa: S104
8+
stompman.Client(servers=[stompman.ConnectionParameters("0.0.0.0", 61616, "admin", "=123")]) as client,
99
client.subscribe("DLQ"),
1010
):
1111
async for event in client.listen():

testing/producer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
async def main() -> None:
77
async with (
8-
stompman.Client(servers=[stompman.ConnectionParameters("0.0.0.0", 61616, "admin", "admin")]) as client, # noqa: S104
8+
stompman.Client(servers=[stompman.ConnectionParameters("0.0.0.0", 61616, "admin", "admin")]) as client,
99
client.enter_transaction() as transaction,
1010
):
1111
for _ in range(10):

tests/integration.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,28 +58,19 @@ async def closed_client(server: stompman.ConnectionParameters) -> AsyncGenerator
5858
yield client
5959

6060

61-
async def test_raises_connection_lost_error_in_aexit(server: stompman.ConnectionParameters) -> None:
62-
with pytest.raises(ConnectionLostError):
63-
async with closed_client(server):
64-
pass
61+
async def test_not_raises_connection_lost_error_in_aexit(server: stompman.ConnectionParameters) -> None:
62+
async with closed_client(server):
63+
pass
6564

6665

67-
async def test_raises_connection_lost_error_in_write_frame(server: stompman.ConnectionParameters) -> None:
68-
client = await closed_client(server).__aenter__() # noqa: PLC2801
69-
70-
with pytest.raises(ConnectionLostError):
71-
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))
72-
73-
with pytest.raises(ConnectionLostError):
74-
await client.__aexit__(None, None, None)
66+
async def test_not_raises_connection_lost_error_in_write_frame(server: stompman.ConnectionParameters) -> None:
67+
async with closed_client(server) as client:
68+
with pytest.raises(ConnectionLostError):
69+
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))
7570

7671

7772
@pytest.mark.parametrize("anyio_backend", [("asyncio", {"use_uvloop": True})])
78-
async def test_raises_connection_lost_error_in_write_heartbeat(server: stompman.ConnectionParameters) -> None:
79-
client = await closed_client(server).__aenter__() # noqa: PLC2801
80-
81-
with pytest.raises(ConnectionLostError):
82-
client._connection.write_heartbeat()
83-
84-
with pytest.raises(ConnectionLostError):
85-
await client.__aexit__(None, None, None)
73+
async def test_not_raises_connection_lost_error_in_write_heartbeat(server: stompman.ConnectionParameters) -> None:
74+
async with closed_client(server) as client:
75+
with pytest.raises(ConnectionLostError):
76+
client._connection.write_heartbeat()

tests/test_client.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
UnsupportedProtocolVersionError,
3434
)
3535
from stompman.client import ConnectionParameters, ErrorEvent, HeartbeatEvent, MessageEvent
36+
from stompman.errors import ConnectionLostError
3637

3738
pytestmark = pytest.mark.anyio
3839

@@ -42,19 +43,19 @@ class BaseMockConnection(AbstractConnection):
4243
@classmethod
4344
async def connect(
4445
cls,
45-
host: str, # noqa: ARG003
46-
port: int, # noqa: ARG003
47-
timeout: int, # noqa: ARG003
46+
host: str,
47+
port: int,
48+
timeout: int,
4849
) -> Self | None:
4950
return cls()
5051

5152
async def close(self) -> None: ...
5253
def write_heartbeat(self) -> None: ...
5354
async def write_frame(self, frame: AnyClientFrame) -> None: ...
54-
async def read_frames( # noqa: PLR6301
55+
async def read_frames(
5556
self,
56-
max_chunk_size: int, # noqa: ARG002
57-
timeout: int, # noqa: ARG002
57+
max_chunk_size: int,
58+
timeout: int,
5859
) -> AsyncGenerator[AnyServerFrame, None]: # pragma: no cover
5960
await asyncio.Future()
6061
yield # type: ignore[misc]
@@ -65,13 +66,13 @@ def create_spying_connection(
6566
) -> tuple[type[AbstractConnection], list[AnyClientFrame | AnyServerFrame | HeartbeatFrame]]:
6667
@dataclass
6768
class BaseCollectingConnection(BaseMockConnection):
68-
async def write_frame(self, frame: AnyClientFrame) -> None: # noqa: PLR6301
69+
async def write_frame(self, frame: AnyClientFrame) -> None:
6970
collected_frames.append(frame)
7071

71-
async def read_frames( # noqa: PLR6301
72+
async def read_frames(
7273
self,
73-
max_chunk_size: int, # noqa: ARG002
74-
timeout: int, # noqa: ARG002
74+
max_chunk_size: int,
75+
timeout: int,
7576
) -> AsyncGenerator[AnyServerFrame, None]:
7677
for frame in next(read_frames_iterator):
7778
collected_frames.append(frame)
@@ -147,9 +148,9 @@ class MockConnection(BaseMockConnection):
147148
@classmethod
148149
async def connect(
149150
cls,
150-
host: str, # noqa: ARG003
151-
port: int, # noqa: ARG003
152-
timeout: int, # noqa: ARG003
151+
host: str,
152+
port: int,
153+
timeout: int,
153154
) -> Self | None:
154155
return None
155156

@@ -185,9 +186,9 @@ class MockConnection(BaseMockConnection):
185186
@classmethod
186187
async def connect(
187188
cls,
188-
host: str, # noqa: ARG003
189-
port: int, # noqa: ARG003
190-
timeout: int, # noqa: ARG003
189+
host: str,
190+
port: int,
191+
timeout: int,
191192
) -> Self | None:
192193
return None
193194

@@ -241,7 +242,7 @@ class MockConnection(connection_class): # type: ignore[valid-type, misc]
241242

242243

243244
async def test_client_lifespan_connection_not_confirmed(monkeypatch: pytest.MonkeyPatch) -> None:
244-
async def timeout(future: Awaitable[Any], timeout: float) -> Any: # noqa: ANN401
245+
async def timeout(future: Awaitable[Any], timeout: float) -> Any:
245246
assert timeout == client.connection_confirmation_timeout
246247
return await original_wait_for(future, 0)
247248

@@ -250,7 +251,7 @@ async def timeout(future: Awaitable[Any], timeout: float) -> Any: # noqa: ANN40
250251

251252
client = EnrichedClient(connection_class=BaseMockConnection)
252253
with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info:
253-
await client.__aenter__() # noqa: PLC2801
254+
await client.__aenter__()
254255

255256
assert exc_info.value == ConnectionConfirmationTimeoutError(client.connection_confirmation_timeout)
256257

@@ -263,7 +264,7 @@ async def test_client_lifespan_unsupported_protocol_version() -> None:
263264

264265
client = EnrichedClient(connection_class=connection_class)
265266
with pytest.raises(UnsupportedProtocolVersionError) as exc_info:
266-
await client.__aenter__() # noqa: PLC2801
267+
await client.__aenter__()
267268

268269
assert exc_info.value == UnsupportedProtocolVersionError(
269270
given_version=given_version, supported_version=client.PROTOCOL_VERSION
@@ -319,6 +320,16 @@ class MockConnection(connection_class): # type: ignore[valid-type, misc]
319320
assert write_heartbeat_mock.mock_calls == [mock.call(), mock.call(), mock.call()]
320321

321322

323+
async def test_client_heartbeat_not_raises_connection_lost() -> None:
324+
connection_class, _ = create_spying_connection(get_read_frames_with_lifespan([]))
325+
326+
class MockConnection(connection_class): # type: ignore[valid-type, misc]
327+
write_heartbeat = mock.Mock(side_effect=ConnectionLostError)
328+
329+
async with EnrichedClient(connection_class=MockConnection):
330+
await asyncio.sleep(0)
331+
332+
322333
async def test_client_listen_to_events_ok() -> None:
323334
message_frame = MessageFrame(headers={"destination": "", "message-id": "", "subscription": ""}, body=b"hello")
324335
error_frame = ErrorFrame(headers={"message": "short description"})
@@ -423,11 +434,11 @@ async def test_message_event_with_auto_ack_ack_raises() -> None:
423434
event, ack, nack, on_suppressed_exception = get_mocked_message_event()
424435

425436
async def func() -> None: # noqa: RUF029
426-
raise Exception # noqa: TRY002
437+
raise ImportError
427438

428-
with suppress(Exception):
439+
with suppress(ImportError):
429440
await event.with_auto_ack(
430-
func(), supressed_exception_classes=(RuntimeError,), on_suppressed_exception=on_suppressed_exception
441+
func(), supressed_exception_classes=(ModuleNotFoundError,), on_suppressed_exception=on_suppressed_exception
431442
)
432443

433444
ack.assert_called_once_with()

tests/test_connection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ async def make_connection() -> Connection | None:
1919

2020
async def make_mocked_connection(
2121
monkeypatch: pytest.MonkeyPatch,
22-
reader: Any, # noqa: ANN401
23-
writer: Any, # noqa: ANN401
22+
reader: Any,
23+
writer: Any,
2424
) -> Connection:
2525
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(reader, writer)))
2626
connection = await make_connection()
@@ -29,7 +29,7 @@ async def make_mocked_connection(
2929

3030

3131
def mock_wait_for(monkeypatch: pytest.MonkeyPatch) -> None:
32-
async def mock_impl(future: Awaitable[Any], timeout: int) -> Any: # noqa: ANN401, ARG001
32+
async def mock_impl(future: Awaitable[Any], timeout: int) -> Any: # noqa: ARG001
3333
return await original_wait_for(future, timeout=0)
3434

3535
original_wait_for = asyncio.wait_for
@@ -104,8 +104,7 @@ class MockWriter:
104104
wait_closed = mock.AsyncMock(side_effect=ConnectionError)
105105

106106
connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
107-
with pytest.raises(ConnectionLostError):
108-
await connection.close()
107+
await connection.close()
109108

110109

111110
async def test_connection_write_heartbeat_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None:

0 commit comments

Comments
 (0)