Skip to content

Commit a1c8ed9

Browse files
authored
Add property testing for ser-de (#27)
1 parent fc2d7e3 commit a1c8ed9

File tree

7 files changed

+124
-39
lines changed

7 files changed

+124
-39
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.coverage
2+
.hypothesis
23
.venv
34
dist
45
uv.lock

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@ dev-dependencies = [
2525
"mypy~=1.10.0",
2626
"pytest-cov~=5.0.0",
2727
"pytest~=8.2.2",
28-
"ruff~=0.4.8",
28+
"ruff~=0.4.9",
2929
"uvloop~=0.19.0",
30+
"hypothesis~=6.103.2",
3031
]
3132

3233
[build-system]

stompman/serde.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import struct
22
from collections import deque
33
from collections.abc import Iterator
4+
from contextlib import suppress
45
from dataclasses import dataclass, field
56
from typing import Any, Final, cast
67

@@ -32,13 +33,12 @@
3233
"\n": "\\n",
3334
":": "\\c",
3435
"\\": "\\\\",
35-
"\r": "\\r",
36+
"\r": "", # [\r]\n is newline, therefore can't be used in header
3637
}
3738
HEADER_UNESCAPE_CHARS: Final = {
38-
b"n": b"\n",
39+
b"n": NEWLINE,
3940
b"c": b":",
4041
b"\\": b"\\",
41-
b"r": b"\r",
4242
}
4343

4444

@@ -89,51 +89,62 @@ def dump_frame(frame: AnyClientFrame | AnyServerFrame) -> bytes:
8989

9090

9191
def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None:
92+
if previous_byte == b"\\":
93+
return HEADER_UNESCAPE_CHARS.get(byte)
9294
if byte == b"\\":
9395
return None
94-
95-
if previous_byte == b"\\":
96-
return HEADER_UNESCAPE_CHARS.get(byte, byte)
97-
9896
return byte
9997

10098

101-
def parse_headers(buffer: list[bytes]) -> tuple[str, str] | None:
99+
def parse_header(buffer: list[bytes]) -> tuple[str, str] | None:
102100
key_buffer: list[bytes] = []
103101
key_parsed = False
104102
value_buffer: list[bytes] = []
105103

106104
previous_byte = None
105+
just_escaped_line = False
106+
107107
for byte in buffer:
108108
if byte == b":":
109109
if key_parsed:
110110
return None
111111
key_parsed = True
112-
113-
elif (unescaped_byte := unescape_byte(byte, previous_byte)) is not None:
112+
elif just_escaped_line:
113+
just_escaped_line = False
114+
if byte != b"\\":
115+
(value_buffer if key_parsed else key_buffer).append(byte)
116+
elif unescaped_byte := unescape_byte(byte, previous_byte):
117+
just_escaped_line = True
114118
(value_buffer if key_parsed else key_buffer).append(unescaped_byte)
115119

116120
previous_byte = byte
117121

118-
return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode()) if key_parsed else None
122+
if key_parsed:
123+
with suppress(UnicodeDecodeError):
124+
return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode())
125+
126+
return None
127+
128+
129+
def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
130+
frame_type = COMMANDS_TO_FRAMES[command]
131+
return (
132+
frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg]
133+
if frame_type in FRAMES_WITH_BODY
134+
else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg]
135+
)
119136

120137

121138
def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame:
122139
command = b"".join(lines.popleft())
123140
headers = {}
124141

125142
while line := lines.popleft():
126-
header = parse_headers(line)
143+
header = parse_header(line)
127144
if header and header[0] not in headers:
128145
headers[header[0]] = header[1]
129146
body = b"".join(lines.popleft()) if lines else b""
130-
131-
frame_type = COMMANDS_TO_FRAMES[command]
132-
return (
133-
frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg]
134-
if frame_type in FRAMES_WITH_BODY
135-
else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg]
136-
)
147+
return make_frame_from_parts(command=command, headers=headers, body=body)
137148

138149

139150
@dataclass
@@ -149,10 +160,7 @@ def _reset(self) -> None:
149160
self._current_line = []
150161

151162
def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame | HeartbeatFrame]:
152-
buffer = deque(iter_bytes(chunk))
153-
while buffer:
154-
byte = buffer.popleft()
155-
163+
for byte in iter_bytes(chunk):
156164
if byte == NULL:
157165
if self._headers_processed:
158166
self._lines.append(self._current_line)

tests/integration.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
11
import asyncio
22
import os
3-
from collections.abc import AsyncGenerator
3+
from collections.abc import AsyncGenerator, Callable
44
from contextlib import asynccontextmanager
5+
from itertools import starmap
56
from uuid import uuid4
67

78
import pytest
9+
from hypothesis import given, strategies
810

911
import stompman
10-
from stompman.errors import ConnectionLostError
12+
from stompman import AnyClientFrame, AnyServerFrame, ConnectionLostError, HeartbeatFrame
13+
from stompman.serde import (
14+
COMMANDS_TO_FRAMES,
15+
NEWLINE,
16+
NULL,
17+
FrameParser,
18+
dump_frame,
19+
dump_header,
20+
iter_bytes,
21+
make_frame_from_parts,
22+
parse_header,
23+
)
1124

1225
pytestmark = pytest.mark.anyio
1326

@@ -113,3 +126,55 @@ async def test_raises_connection_lost_error_in_listen(client: stompman.Client) -
113126
client.read_timeout = 0
114127
with pytest.raises(ConnectionLostError):
115128
[event async for event in client.listen()]
129+
130+
131+
def generate_frames(
132+
cases: list[tuple[bytes, list[AnyClientFrame | AnyServerFrame | HeartbeatFrame]]],
133+
) -> tuple[list[bytes], list[AnyClientFrame | AnyServerFrame | HeartbeatFrame]]:
134+
all_bytes, all_frames = [], []
135+
136+
for noise, frames in cases:
137+
current_all_bytes = []
138+
if noise:
139+
current_all_bytes.append(noise + NEWLINE)
140+
141+
for frame in frames:
142+
current_all_bytes.append(NEWLINE if isinstance(frame, HeartbeatFrame) else dump_frame(frame))
143+
all_frames.append(frame)
144+
145+
all_bytes.append(b"".join(current_all_bytes))
146+
147+
return all_bytes, all_frames
148+
149+
150+
def bytes_not_contains(*avoided: bytes) -> Callable[[bytes], bool]:
151+
return lambda checked: all(item not in checked for item in avoided)
152+
153+
154+
noise_bytes_strategy = strategies.binary().filter(bytes_not_contains(NEWLINE, NULL))
155+
header_value_strategy = strategies.text().filter(lambda text: "\x00" not in text)
156+
headers_strategy = strategies.dictionaries(header_value_strategy, header_value_strategy).map(
157+
lambda headers: dict(
158+
parsed_header
159+
for header in starmap(dump_header, headers.items())
160+
if (parsed_header := parse_header(list(iter_bytes(header))))
161+
)
162+
)
163+
frame_strategy = strategies.just(HeartbeatFrame()) | strategies.builds(
164+
make_frame_from_parts,
165+
command=strategies.sampled_from(tuple(COMMANDS_TO_FRAMES.keys())),
166+
headers=headers_strategy,
167+
body=strategies.binary().filter(bytes_not_contains(NULL)),
168+
)
169+
170+
171+
@given(
172+
strategies.builds(
173+
generate_frames,
174+
strategies.lists(strategies.tuples(noise_bytes_strategy, strategies.lists(frame_strategy))),
175+
),
176+
)
177+
def test_parsing(case: tuple[list[bytes], list[AnyClientFrame | AnyServerFrame | HeartbeatFrame]]) -> None:
178+
stream_chunks, expected_frames = case
179+
parser = FrameParser()
180+
assert [frame for chunk in stream_chunks for frame in parser.parse_frames_from_chunk(chunk)] == expected_frames

tests/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
ConnectedFrame,
2121
ConnectFrame,
2222
ConnectionConfirmationTimeoutError,
23+
ConnectionParameters,
2324
DisconnectFrame,
25+
ErrorEvent,
2426
ErrorFrame,
2527
FailedAllConnectAttemptsError,
28+
HeartbeatEvent,
2629
HeartbeatFrame,
30+
MessageEvent,
2731
MessageFrame,
2832
NackFrame,
2933
ReceiptFrame,
@@ -32,7 +36,6 @@
3236
UnsubscribeFrame,
3337
UnsupportedProtocolVersionError,
3438
)
35-
from stompman.client import ConnectionParameters, ErrorEvent, HeartbeatEvent, MessageEvent
3639
from stompman.errors import ConnectionLostError
3740

3841
pytestmark = pytest.mark.anyio

tests/test_connection.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,16 @@
77

88
import pytest
99

10-
from stompman import AnyServerFrame, ConnectedFrame, Connection, ConnectionLostError, HeartbeatFrame
11-
from stompman.frames import BeginFrame, CommitFrame
10+
from stompman import (
11+
AnyServerFrame,
12+
BeginFrame,
13+
CommitFrame,
14+
ConnectedFrame,
15+
Connection,
16+
ConnectionLostError,
17+
HeartbeatFrame,
18+
)
19+
from stompman.serde import NEWLINE
1220

1321
pytestmark = pytest.mark.anyio
1422

@@ -95,7 +103,7 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
95103
MockWriter.wait_closed.assert_called_once_with()
96104
MockWriter.drain.assert_called_once_with()
97105
MockReader.read.mock_calls = [mock.call(max_chunk_size)] * len(read_bytes) # type: ignore[assignment]
98-
assert MockWriter.write.mock_calls == [mock.call(b"\n"), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")]
106+
assert MockWriter.write.mock_calls == [mock.call(NEWLINE), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")]
99107

100108

101109
async def test_connection_close_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:

tests/test_frame_serde.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import pytest
22

33
from stompman import (
4+
AckFrame,
5+
AnyClientFrame,
6+
AnyServerFrame,
47
ConnectedFrame,
58
ConnectFrame,
69
ErrorFrame,
710
HeartbeatFrame,
811
MessageFrame,
912
)
10-
from stompman.frames import AckFrame, AnyClientFrame, AnyServerFrame
11-
from stompman.serde import FrameParser, dump_frame
13+
from stompman.serde import NEWLINE, FrameParser, dump_frame
1214

1315

1416
@pytest.mark.parametrize(
@@ -18,15 +20,12 @@
1820
(ConnectedFrame(headers={"version": "1.1"}), (b"CONNECTED\nversion:1.1\n\n\x00")),
1921
(
2022
MessageFrame(
21-
headers={"destination": "me:123", "message-id": "you\nmore\rextra\\here", "subscription": "hi"},
23+
headers={"destination": "me:123", "message-id": "you\nmoreextra\\here", "subscription": "hi"},
2224
body=b"I Am The Walrus",
2325
),
2426
(
25-
b"MESSAGE\n"
26-
b"destination:me\\c123\n"
27-
b"message-id:you\\nmore\\rextra\\\\here\nsubscription:hi\n\n"
28-
b"I Am The Walrus"
29-
b"\x00"
27+
b"MESSAGE\ndestination:me\\c123\nmessage-id:you\\nmoreextra\\\\here\nsubscription:hi\n\n"
28+
b"I Am The Walrus\x00"
3029
),
3130
),
3231
],
@@ -181,7 +180,7 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None:
181180
HeartbeatFrame(),
182181
],
183182
),
184-
(b"\n", [HeartbeatFrame()]),
183+
(NEWLINE, [HeartbeatFrame()]),
185184
# Two headers: only first should be accepted
186185
(
187186
b"CONNECTED\naccept-version:1.0\naccept-version:1.1\n\n\x00",

0 commit comments

Comments
 (0)