7
7
8
8
import pytest
9
9
10
- from stompman import (
11
- AnyServerFrame ,
12
- ConnectedFrame ,
13
- Connection ,
14
- ConnectionLostError ,
15
- HeartbeatFrame ,
16
- )
10
+ from stompman import AnyServerFrame , ConnectedFrame , Connection , ConnectionLostError , HeartbeatFrame
17
11
from stompman .frames import BeginFrame , CommitFrame
18
12
13
+ pytestmark = pytest .mark .anyio
14
+
19
15
20
16
async def make_connection () -> Connection | None :
21
17
return await Connection .connect (host = "localhost" , port = 12345 , timeout = 2 )
22
18
23
19
20
+ async def make_mocked_connection (
21
+ monkeypatch : pytest .MonkeyPatch ,
22
+ reader : Any , # noqa: ANN401
23
+ writer : Any , # noqa: ANN401
24
+ ) -> Connection :
25
+ monkeypatch .setattr ("asyncio.open_connection" , mock .AsyncMock (return_value = (reader , writer )))
26
+ connection = await make_connection ()
27
+ assert connection
28
+ return connection
29
+
30
+
24
31
def mock_wait_for (monkeypatch : pytest .MonkeyPatch ) -> None :
25
32
async def mock_impl (future : Awaitable [Any ], timeout : int ) -> Any : # noqa: ANN401, ARG001
26
33
return await original_wait_for (future , timeout = 0 )
@@ -57,19 +64,21 @@ class MockWriter:
57
64
b"som" ,
58
65
b"e server\n version:1.2\n \n \x00 " ,
59
66
]
67
+ expected_frames = [
68
+ HeartbeatFrame (),
69
+ HeartbeatFrame (),
70
+ HeartbeatFrame (),
71
+ ConnectedFrame (headers = {"heart-beat" : "0,0" , "version" : "1.2" , "server" : "some server" }),
72
+ ]
73
+ max_chunk_size = 1024
60
74
61
75
class MockReader :
62
76
read = mock .AsyncMock (side_effect = read_bytes )
63
77
64
- monkeypatch .setattr ("asyncio.open_connection" , mock .AsyncMock (return_value = (MockReader (), MockWriter ())))
65
- connection = await make_connection ()
66
- assert connection
67
-
78
+ connection = await make_mocked_connection (monkeypatch , MockReader (), MockWriter ())
68
79
connection .write_heartbeat ()
69
80
await connection .write_frame (CommitFrame (headers = {"transaction" : "transaction" }))
70
81
71
- max_chunk_size = 1024
72
-
73
82
async def take_frames (count : int ) -> list [AnyServerFrame ]:
74
83
frames = []
75
84
async for frame in connection .read_frames (max_chunk_size = max_chunk_size , timeout = 1 ):
@@ -79,12 +88,6 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
79
88
80
89
return frames
81
90
82
- expected_frames = [
83
- HeartbeatFrame (),
84
- HeartbeatFrame (),
85
- HeartbeatFrame (),
86
- ConnectedFrame (headers = {"heart-beat" : "0,0" , "version" : "1.2" , "server" : "some server" }),
87
- ]
88
91
assert await take_frames (len (expected_frames )) == expected_frames
89
92
await connection .close ()
90
93
@@ -100,10 +103,7 @@ class MockWriter:
100
103
close = mock .Mock ()
101
104
wait_closed = mock .AsyncMock (side_effect = ConnectionError )
102
105
103
- monkeypatch .setattr ("asyncio.open_connection" , mock .AsyncMock (return_value = (mock .Mock (), MockWriter ())))
104
- connection = await make_connection ()
105
- assert connection
106
-
106
+ connection = await make_mocked_connection (monkeypatch , mock .Mock (), MockWriter ())
107
107
with pytest .raises (ConnectionLostError ):
108
108
await connection .close ()
109
109
@@ -113,10 +113,17 @@ class MockWriter:
113
113
write = mock .Mock ()
114
114
drain = mock .AsyncMock (side_effect = ConnectionError )
115
115
116
- monkeypatch . setattr ( "asyncio.open_connection" , mock .AsyncMock ( return_value = ( mock . Mock (), MockWriter ()) ))
117
- connection = await make_connection ()
118
- assert connection
116
+ connection = await make_mocked_connection ( monkeypatch , mock .Mock (), MockWriter ())
117
+ with pytest . raises ( ConnectionLostError ):
118
+ await connection . write_frame ( BeginFrame ( headers = { "transaction" : "" }))
119
119
120
+
121
+ async def test_connection_write_frame_runtime_error (monkeypatch : pytest .MonkeyPatch ) -> None :
122
+ class MockWriter :
123
+ write = mock .Mock (side_effect = RuntimeError )
124
+ drain = mock .AsyncMock ()
125
+
126
+ connection = await make_mocked_connection (monkeypatch , mock .Mock (), MockWriter ())
120
127
with pytest .raises (ConnectionLostError ):
121
128
await connection .write_frame (BeginFrame (headers = {"transaction" : "" }))
122
129
@@ -133,27 +140,17 @@ async def test_connection_connect_connection_error(monkeypatch: pytest.MonkeyPat
133
140
134
141
135
142
async def test_read_frames_timeout_error (monkeypatch : pytest .MonkeyPatch ) -> None :
136
- monkeypatch .setattr (
137
- "asyncio.open_connection" ,
138
- mock .AsyncMock (return_value = (mock .AsyncMock (read = partial (asyncio .sleep , 5 )), mock .AsyncMock ())),
143
+ connection = await make_mocked_connection (
144
+ monkeypatch , mock .AsyncMock (read = partial (asyncio .sleep , 5 )), mock .AsyncMock ()
139
145
)
140
- connection = await make_connection ()
141
- assert connection
142
-
143
146
mock_wait_for (monkeypatch )
144
147
with pytest .raises (ConnectionLostError ):
145
148
[frame async for frame in connection .read_frames (1024 , 1 )]
146
149
147
150
148
151
async def test_read_frames_connection_error (monkeypatch : pytest .MonkeyPatch ) -> None :
149
- monkeypatch .setattr (
150
- "asyncio.open_connection" ,
151
- mock .AsyncMock (
152
- return_value = (mock .AsyncMock (read = mock .AsyncMock (side_effect = BrokenPipeError )), mock .AsyncMock ())
153
- ),
152
+ connection = await make_mocked_connection (
153
+ monkeypatch , mock .AsyncMock (read = mock .AsyncMock (side_effect = BrokenPipeError )), mock .AsyncMock ()
154
154
)
155
- connection = await make_connection ()
156
- assert connection
157
-
158
155
with pytest .raises (ConnectionLostError ):
159
156
[frame async for frame in connection .read_frames (1024 , 1 )]
0 commit comments