Skip to content

Commit fc2d7e3

Browse files
authored
Avoid raising ConnectionLostError in ack/nack (#31)
1 parent 4a701eb commit fc2d7e3

File tree

2 files changed

+36
-16
lines changed

2 files changed

+36
-16
lines changed

stompman/client.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
3-
from contextlib import AsyncExitStack, asynccontextmanager
3+
from contextlib import AsyncExitStack, asynccontextmanager, suppress
44
from dataclasses import dataclass, field
55
from types import TracebackType
66
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
@@ -291,25 +291,27 @@ def __post_init__(self) -> None:
291291

292292
async def ack(self) -> None:
293293
if self._client._connection.active:
294-
await self._client._connection.write_frame(
295-
AckFrame(
296-
headers={
297-
"id": self._frame.headers["message-id"],
298-
"subscription": self._frame.headers["subscription"],
299-
},
294+
with suppress(ConnectionLostError):
295+
await self._client._connection.write_frame(
296+
AckFrame(
297+
headers={
298+
"id": self._frame.headers["message-id"],
299+
"subscription": self._frame.headers["subscription"],
300+
},
301+
)
300302
)
301-
)
302303

303304
async def nack(self) -> None:
304305
if self._client._connection.active:
305-
await self._client._connection.write_frame(
306-
NackFrame(
307-
headers={
308-
"id": self._frame.headers["message-id"],
309-
"subscription": self._frame.headers["subscription"],
310-
}
306+
with suppress(ConnectionLostError):
307+
await self._client._connection.write_frame(
308+
NackFrame(
309+
headers={
310+
"id": self._frame.headers["message-id"],
311+
"subscription": self._frame.headers["subscription"],
312+
}
313+
)
311314
)
312-
)
313315

314316
async def with_auto_ack(
315317
self,

tests/test_client.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ async def test_client_listen_to_events_unreachable(frame: ConnectedFrame | Recei
368368
[event async for event in client.listen()]
369369

370370

371-
async def test_ack_nack() -> None:
371+
async def test_ack_nack_ok() -> None:
372372
subscription = "subscription-id"
373373
message_id = "message-id"
374374

@@ -391,6 +391,24 @@ async def test_ack_nack() -> None:
391391
assert_frames_between_lifespan_match(collected_frames, [message_frame, nack_frame, ack_frame])
392392

393393

394+
async def test_ack_nack_connection_lost_error() -> None:
395+
message_frame = MessageFrame(headers={"subscription": "", "message-id": "", "destination": ""}, body=b"")
396+
connection_class, _ = create_spying_connection(get_read_frames_with_lifespan([[message_frame]]))
397+
398+
class MockConnection(connection_class): # type: ignore[valid-type, misc]
399+
async def write_frame(self, frame: AnyClientFrame) -> None:
400+
if isinstance(frame, AckFrame | NackFrame):
401+
raise ConnectionLostError
402+
403+
async with EnrichedClient(connection_class=MockConnection) as client:
404+
events = [event async for event in client.listen()]
405+
event = events[0]
406+
assert isinstance(event, MessageEvent)
407+
408+
await event.nack()
409+
await event.ack()
410+
411+
394412
def get_mocked_message_event() -> tuple[MessageEvent, mock.AsyncMock, mock.AsyncMock, mock.Mock]:
395413
ack_mock, nack_mock, on_suppressed_exception_mock = mock.AsyncMock(), mock.AsyncMock(), mock.Mock()
396414

0 commit comments

Comments
 (0)