Skip to content

Commit db05208

Browse files
authored
Fix graceful shutdown when using uvloop (#25)
1 parent 9730d08 commit db05208

File tree

6 files changed

+61
-47
lines changed

6 files changed

+61
-47
lines changed

stompman/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ async def listen(self) -> AsyncIterator["AnyListeningEvent"]:
269269
class MessageEvent:
270270
body: bytes = field(init=False)
271271
_frame: MessageFrame
272-
_client: "Client" = field(repr=False)
272+
_client: Client = field(repr=False)
273273

274274
def __post_init__(self) -> None:
275275
self.body = self._frame.body
@@ -314,7 +314,7 @@ class ErrorEvent:
314314
body: bytes = field(init=False)
315315
"""Long description of the error."""
316316
_frame: ErrorFrame
317-
_client: "Client" = field(repr=False)
317+
_client: Client = field(repr=False)
318318

319319
def __post_init__(self) -> None:
320320
self.message_header = self._frame.headers["message"]
@@ -324,7 +324,7 @@ def __post_init__(self) -> None:
324324
@dataclass
325325
class HeartbeatEvent:
326326
_frame: HeartbeatFrame
327-
_client: "Client" = field(repr=False)
327+
_client: Client = field(repr=False)
328328

329329

330330
AnyListeningEvent = MessageEvent | ErrorEvent | HeartbeatEvent

stompman/connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def write_heartbeat(self) -> None:
5959
return self.writer.write(NEWLINE)
6060

6161
async def write_frame(self, frame: AnyClientFrame) -> None:
62-
self.writer.write(dump_frame(frame))
62+
with _reraise_connection_lost(RuntimeError):
63+
self.writer.write(dump_frame(frame))
6364
with _reraise_connection_lost(ConnectionError):
6465
await self.writer.drain()
6566

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio+uvloop"),
77
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
88
],
9-
autouse=True,
109
)
1110
def anyio_backend(request: pytest.FixtureRequest) -> object:
1211
return request.param

tests/integration.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,20 @@
22
import os
33
from uuid import uuid4
44

5+
import pytest
6+
57
import stompman
8+
from stompman.errors import ConnectionLostError
9+
10+
pytestmark = pytest.mark.anyio
11+
612

13+
@pytest.fixture()
14+
def server() -> stompman.ConnectionParameters:
15+
return stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="admin")
716

8-
async def test_integration() -> None:
9-
server = stompman.ConnectionParameters(host=os.environ["ARTEMIS_HOST"], port=61616, login="admin", passcode="admin")
17+
18+
async def test_ok(server: stompman.ConnectionParameters) -> None:
1019
destination = "DLQ"
1120
messages = [str(uuid4()).encode() for _ in range(10000)]
1221

@@ -38,3 +47,9 @@ async def consume() -> None:
3847
):
3948
task_group.create_task(consume())
4049
task_group.create_task(produce())
50+
51+
52+
async def test_raises_connection_lost_error(server: stompman.ConnectionParameters) -> None:
53+
with pytest.raises(ConnectionLostError):
54+
async with stompman.Client(servers=[server], read_timeout=10, connection_confirmation_timeout=10) as consumer:
55+
await consumer._connection.close()

tests/test_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
)
3535
from stompman.client import ConnectionParameters, ErrorEvent, HeartbeatEvent, MessageEvent
3636

37+
pytestmark = pytest.mark.anyio
38+
3739

3840
@dataclass
3941
class BaseMockConnection(AbstractConnection):

tests/test_connection.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,27 @@
77

88
import pytest
99

10-
from stompman import (
11-
AnyServerFrame,
12-
ConnectedFrame,
13-
Connection,
14-
ConnectionLostError,
15-
HeartbeatFrame,
16-
)
10+
from stompman import AnyServerFrame, ConnectedFrame, Connection, ConnectionLostError, HeartbeatFrame
1711
from stompman.frames import BeginFrame, CommitFrame
1812

13+
pytestmark = pytest.mark.anyio
14+
1915

2016
async def make_connection() -> Connection | None:
2117
return await Connection.connect(host="localhost", port=12345, timeout=2)
2218

2319

20+
async def make_mocked_connection(
21+
monkeypatch: pytest.MonkeyPatch,
22+
reader: Any, # noqa: ANN401
23+
writer: Any, # noqa: ANN401
24+
) -> Connection:
25+
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(reader, writer)))
26+
connection = await make_connection()
27+
assert connection
28+
return connection
29+
30+
2431
def mock_wait_for(monkeypatch: pytest.MonkeyPatch) -> None:
2532
async def mock_impl(future: Awaitable[Any], timeout: int) -> Any: # noqa: ANN401, ARG001
2633
return await original_wait_for(future, timeout=0)
@@ -57,19 +64,21 @@ class MockWriter:
5764
b"som",
5865
b"e server\nversion:1.2\n\n\x00",
5966
]
67+
expected_frames = [
68+
HeartbeatFrame(),
69+
HeartbeatFrame(),
70+
HeartbeatFrame(),
71+
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
72+
]
73+
max_chunk_size = 1024
6074

6175
class MockReader:
6276
read = mock.AsyncMock(side_effect=read_bytes)
6377

64-
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(MockReader(), MockWriter())))
65-
connection = await make_connection()
66-
assert connection
67-
78+
connection = await make_mocked_connection(monkeypatch, MockReader(), MockWriter())
6879
connection.write_heartbeat()
6980
await connection.write_frame(CommitFrame(headers={"transaction": "transaction"}))
7081

71-
max_chunk_size = 1024
72-
7382
async def take_frames(count: int) -> list[AnyServerFrame]:
7483
frames = []
7584
async for frame in connection.read_frames(max_chunk_size=max_chunk_size, timeout=1):
@@ -79,12 +88,6 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
7988

8089
return frames
8190

82-
expected_frames = [
83-
HeartbeatFrame(),
84-
HeartbeatFrame(),
85-
HeartbeatFrame(),
86-
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
87-
]
8891
assert await take_frames(len(expected_frames)) == expected_frames
8992
await connection.close()
9093

@@ -100,10 +103,7 @@ class MockWriter:
100103
close = mock.Mock()
101104
wait_closed = mock.AsyncMock(side_effect=ConnectionError)
102105

103-
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
104-
connection = await make_connection()
105-
assert connection
106-
106+
connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
107107
with pytest.raises(ConnectionLostError):
108108
await connection.close()
109109

@@ -113,10 +113,17 @@ class MockWriter:
113113
write = mock.Mock()
114114
drain = mock.AsyncMock(side_effect=ConnectionError)
115115

116-
monkeypatch.setattr("asyncio.open_connection", mock.AsyncMock(return_value=(mock.Mock(), MockWriter())))
117-
connection = await make_connection()
118-
assert connection
116+
connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
117+
with pytest.raises(ConnectionLostError):
118+
await connection.write_frame(BeginFrame(headers={"transaction": ""}))
119119

120+
121+
async def test_connection_write_frame_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None:
122+
class MockWriter:
123+
write = mock.Mock(side_effect=RuntimeError)
124+
drain = mock.AsyncMock()
125+
126+
connection = await make_mocked_connection(monkeypatch, mock.Mock(), MockWriter())
120127
with pytest.raises(ConnectionLostError):
121128
await connection.write_frame(BeginFrame(headers={"transaction": ""}))
122129

@@ -133,27 +140,17 @@ async def test_connection_connect_connection_error(monkeypatch: pytest.MonkeyPat
133140

134141

135142
async def test_read_frames_timeout_error(monkeypatch: pytest.MonkeyPatch) -> None:
136-
monkeypatch.setattr(
137-
"asyncio.open_connection",
138-
mock.AsyncMock(return_value=(mock.AsyncMock(read=partial(asyncio.sleep, 5)), mock.AsyncMock())),
143+
connection = await make_mocked_connection(
144+
monkeypatch, mock.AsyncMock(read=partial(asyncio.sleep, 5)), mock.AsyncMock()
139145
)
140-
connection = await make_connection()
141-
assert connection
142-
143146
mock_wait_for(monkeypatch)
144147
with pytest.raises(ConnectionLostError):
145148
[frame async for frame in connection.read_frames(1024, 1)]
146149

147150

148151
async def test_read_frames_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
149-
monkeypatch.setattr(
150-
"asyncio.open_connection",
151-
mock.AsyncMock(
152-
return_value=(mock.AsyncMock(read=mock.AsyncMock(side_effect=BrokenPipeError)), mock.AsyncMock())
153-
),
152+
connection = await make_mocked_connection(
153+
monkeypatch, mock.AsyncMock(read=mock.AsyncMock(side_effect=BrokenPipeError)), mock.AsyncMock()
154154
)
155-
connection = await make_connection()
156-
assert connection
157-
158155
with pytest.raises(ConnectionLostError):
159156
[frame async for frame in connection.read_frames(1024, 1)]

0 commit comments

Comments
 (0)