Skip to content

Commit 4a701eb

Browse files
authored
Avoid sending cleanup frames if ConnectionLostError was raised, also avoid raising it in ack/nack (#30)
1 parent 4120df9 commit 4a701eb

File tree

3 files changed

+109
-54
lines changed

3 files changed

+109
-54
lines changed

stompman/client.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
3-
from contextlib import AsyncExitStack, asynccontextmanager, suppress
3+
from contextlib import AsyncExitStack, asynccontextmanager
44
from dataclasses import dataclass, field
55
from types import TracebackType
66
from typing import Any, ClassVar, NamedTuple, Self, TypedDict
@@ -145,8 +145,7 @@ async def _connect_to_any_server(self) -> None:
145145
for maybe_connection_future in asyncio.as_completed(
146146
[self._connect_to_one_server(server) for server in self.servers]
147147
):
148-
maybe_result = await maybe_connection_future
149-
if maybe_result:
148+
if maybe_result := await maybe_connection_future:
150149
self._connection, self._connection_parameters = maybe_result
151150
return
152151
raise FailedAllConnectAttemptsError(
@@ -199,10 +198,12 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
199198
)
200199

201200
async def send_heartbeats_forever() -> None:
202-
while True:
201+
while self._connection.active:
203202
try:
204203
self._connection.write_heartbeat()
205204
except ConnectionLostError:
205+
# Avoid raising the error in an exception group.
206+
# ConnectionLostError should be raised in a way that user expects it.
206207
return
207208
await asyncio.sleep(heartbeat_interval)
208209

@@ -213,8 +214,9 @@ async def send_heartbeats_forever() -> None:
213214
finally:
214215
task.cancel()
215216

216-
with suppress(ConnectionLostError):
217+
if self._connection.active:
217218
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
219+
if self._connection.active:
218220
await self._connection.read_frame_of_type(
219221
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
220222
)
@@ -227,10 +229,12 @@ async def enter_transaction(self) -> AsyncGenerator[str, None]:
227229
try:
228230
yield transaction_id
229231
except Exception:
230-
await self._connection.write_frame(AbortFrame(headers={"transaction": transaction_id}))
232+
if self._connection.active:
233+
await self._connection.write_frame(AbortFrame(headers={"transaction": transaction_id}))
231234
raise
232235
else:
233-
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))
236+
if self._connection.active:
237+
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))
234238

235239
async def send( # noqa: PLR0913
236240
self,
@@ -258,7 +262,8 @@ async def subscribe(self, destination: str) -> AsyncGenerator[None, None]:
258262
try:
259263
yield
260264
finally:
261-
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))
265+
if self._connection.active:
266+
await self._connection.write_frame(UnsubscribeFrame(headers={"id": subscription_id}))
262267

263268
async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
264269
async for frame in self._connection.read_frames(
@@ -285,18 +290,26 @@ def __post_init__(self) -> None:
285290
self.body = self._frame.body
286291

287292
async def ack(self) -> None:
288-
await self._client._connection.write_frame(
289-
AckFrame(
290-
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]},
293+
if self._client._connection.active:
294+
await self._client._connection.write_frame(
295+
AckFrame(
296+
headers={
297+
"id": self._frame.headers["message-id"],
298+
"subscription": self._frame.headers["subscription"],
299+
},
300+
)
291301
)
292-
)
293302

294303
async def nack(self) -> None:
295-
await self._client._connection.write_frame(
296-
NackFrame(
297-
headers={"id": self._frame.headers["message-id"], "subscription": self._frame.headers["subscription"]}
304+
if self._client._connection.active:
305+
await self._client._connection.write_frame(
306+
NackFrame(
307+
headers={
308+
"id": self._frame.headers["message-id"],
309+
"subscription": self._frame.headers["subscription"],
310+
}
311+
)
298312
)
299-
)
300313

301314
async def with_auto_ack(
302315
self,

stompman/connection.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
@dataclass
1616
class AbstractConnection(Protocol):
17+
active: bool = True
18+
1719
@classmethod
1820
async def connect(cls, host: str, port: int, timeout: int) -> Self | None: ...
1921
async def close(self) -> None: ...
@@ -28,15 +30,7 @@ async def read_frame_of_type(self, type_: type[FrameType], max_chunk_size: int,
2830
return frame
2931

3032

31-
@contextmanager
32-
def _reraise_connection_lost(*causes: type[Exception]) -> Generator[None, None, None]:
33-
try:
34-
yield
35-
except causes as exception:
36-
raise ConnectionLostError from exception
37-
38-
39-
@dataclass
33+
@dataclass(kw_only=True)
4034
class Connection(AbstractConnection):
4135
reader: asyncio.StreamReader
4236
writer: asyncio.StreamWriter
@@ -54,15 +48,24 @@ async def close(self) -> None:
5448
self.writer.close()
5549
with suppress(ConnectionError):
5650
await self.writer.wait_closed()
51+
self.active = False
52+
53+
@contextmanager
54+
def _reraise_connection_lost(self, *causes: type[Exception]) -> Generator[None, None, None]:
55+
try:
56+
yield
57+
except causes as exception:
58+
self.active = False
59+
raise ConnectionLostError from exception
5760

5861
def write_heartbeat(self) -> None:
59-
with _reraise_connection_lost(RuntimeError):
62+
with self._reraise_connection_lost(RuntimeError):
6063
return self.writer.write(NEWLINE)
6164

6265
async def write_frame(self, frame: AnyClientFrame) -> None:
63-
with _reraise_connection_lost(RuntimeError):
66+
with self._reraise_connection_lost(RuntimeError):
6467
self.writer.write(dump_frame(frame))
65-
with _reraise_connection_lost(ConnectionError):
68+
with self._reraise_connection_lost(ConnectionError):
6669
await self.writer.drain()
6770

6871
async def _read_non_empty_bytes(self, max_chunk_size: int) -> bytes:
@@ -74,7 +77,7 @@ async def read_frames(self, max_chunk_size: int, timeout: int) -> AsyncGenerator
7477
parser = FrameParser()
7578

7679
while True:
77-
with _reraise_connection_lost(ConnectionError, TimeoutError):
80+
with self._reraise_connection_lost(ConnectionError, TimeoutError):
7881
raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(max_chunk_size), timeout=timeout)
7982

8083
for frame in cast(Iterator[AnyServerFrame], parser.parse_frames_from_chunk(raw_frames)):

tests/integration.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,27 @@
1212
pytestmark = pytest.mark.anyio
1313

1414

15+
@asynccontextmanager
16+
async def create_client() -> AsyncGenerator[stompman.Client, None]:
17+
server = stompman.ConnectionParameters(
18+
host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="%3D123"
19+
)
20+
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as client:
21+
yield client
22+
23+
1524
@pytest.fixture()
16-
def server() -> stompman.ConnectionParameters:
17-
return stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="%3D123")
25+
async def client() -> AsyncGenerator[stompman.Client, None]:
26+
async with create_client() as client:
27+
yield client
1828

1929

20-
async def test_ok(server: stompman.ConnectionParameters) -> None:
21-
destination = "DLQ"
22-
messages = [str(uuid4()).encode() for _ in range(10000)]
30+
@pytest.fixture()
31+
def destination() -> str:
32+
return "DLQ"
2333

34+
35+
async def test_ok(destination: str) -> None:
2436
async def produce() -> None:
2537
async with producer.enter_transaction() as transaction:
2638
for message in messages:
@@ -42,35 +54,62 @@ async def consume() -> None:
4254

4355
assert sorted(received_messages) == sorted(messages)
4456

45-
async with (
46-
stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as consumer,
47-
stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as producer,
48-
asyncio.TaskGroup() as task_group,
49-
):
57+
messages = [str(uuid4()).encode() for _ in range(10000)]
58+
59+
async with create_client() as consumer, create_client() as producer, asyncio.TaskGroup() as task_group:
5060
task_group.create_task(consume())
5161
task_group.create_task(produce())
5262

5363

54-
@asynccontextmanager
55-
async def closed_client(server: stompman.ConnectionParameters) -> AsyncGenerator[stompman.Client, None]:
56-
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as client:
64+
async def test_not_raises_connection_lost_error_in_aexit(client: stompman.Client) -> None:
65+
await client._connection.close()
66+
67+
68+
async def test_not_raises_connection_lost_error_in_write_frame(client: stompman.Client) -> None:
69+
await client._connection.close()
70+
71+
with pytest.raises(ConnectionLostError):
72+
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))
73+
74+
75+
@pytest.mark.parametrize("anyio_backend", [("asyncio", {"use_uvloop": True})])
76+
async def test_not_raises_connection_lost_error_in_write_heartbeat(client: stompman.Client) -> None:
77+
await client._connection.close()
78+
79+
with pytest.raises(ConnectionLostError):
80+
client._connection.write_heartbeat()
81+
82+
83+
async def test_not_raises_connection_lost_error_in_subscription(client: stompman.Client, destination: str) -> None:
84+
async with client.subscribe(destination):
5785
await client._connection.close()
58-
yield client
5986

6087

61-
async def test_not_raises_connection_lost_error_in_aexit(server: stompman.ConnectionParameters) -> None:
62-
async with closed_client(server):
63-
pass
88+
async def test_not_raises_connection_lost_error_in_transaction_without_send(client: stompman.Client) -> None:
89+
async with client.enter_transaction():
90+
await client._connection.close()
6491

6592

66-
async def test_not_raises_connection_lost_error_in_write_frame(server: stompman.ConnectionParameters) -> None:
67-
async with closed_client(server) as client:
93+
async def test_not_raises_connection_lost_error_in_transaction_with_send(
94+
client: stompman.Client, destination: str
95+
) -> None:
96+
async with client.enter_transaction() as transaction:
97+
await client.send(b"first", destination=destination, transaction=transaction)
98+
await client._connection.close()
99+
68100
with pytest.raises(ConnectionLostError):
69-
await client._connection.write_frame(stompman.ConnectFrame(headers={"accept-version": "", "host": ""}))
101+
await client.send(b"second", destination=destination, transaction=transaction)
70102

71103

72-
@pytest.mark.parametrize("anyio_backend", [("asyncio", {"use_uvloop": True})])
73-
async def test_not_raises_connection_lost_error_in_write_heartbeat(server: stompman.ConnectionParameters) -> None:
74-
async with closed_client(server) as client:
75-
with pytest.raises(ConnectionLostError):
76-
client._connection.write_heartbeat()
104+
async def test_raises_connection_lost_error_in_send(client: stompman.Client, destination: str) -> None:
105+
await client._connection.close()
106+
107+
with pytest.raises(ConnectionLostError):
108+
await client.send(b"first", destination=destination)
109+
110+
111+
async def test_raises_connection_lost_error_in_listen(client: stompman.Client) -> None:
112+
await client._connection.close()
113+
client.read_timeout = 0
114+
with pytest.raises(ConnectionLostError):
115+
[event async for event in client.listen()]

0 commit comments

Comments
 (0)