Skip to content

Commit 65735e1

Browse files
authored
Fix parsing stream chunks that do not contain frames (#23)
1 parent a1d778f commit 65735e1

File tree

3 files changed

+47
-31
lines changed

3 files changed

+47
-31
lines changed

Justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ install:
55
uv -q sync
66

77
test *args:
8-
uv -q run pytest {{args}}
8+
uv -q run pytest -- {{args}}
99

1010
lint:
1111
uv -q run ruff check .

stompman/protocol.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
CARRIAGE_NEWLINE_CARRIAGE_NEWLINE = (CARRIAGE, NEWLINE, CARRIAGE, NEWLINE)
3232

3333

34+
def iter_bytes(bytes_: bytes) -> tuple[bytes, ...]:
35+
return struct.unpack(f"{len(bytes_)!s}c", bytes_)
36+
37+
38+
VALID_COMMANDS = [list(iter_bytes(command)) for command in COMMANDS_TO_FRAMES]
39+
40+
3441
def dump_header(key: str, value: str) -> bytes:
3542
escaped_key = "".join(ESCAPE_CHARS.get(char, char) for char in key)
3643
escaped_value = "".join(ESCAPE_CHARS.get(char, char) for char in value)
@@ -79,7 +86,7 @@ def parse_headers(buffer: list[bytes]) -> tuple[str, str] | None:
7986
return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode()) if key_parsed else None
8087

8188

82-
def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame | None:
89+
def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame:
8390
command = b"".join(lines.popleft())
8491
headers = {}
8592

@@ -89,9 +96,7 @@ def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnySer
8996
headers[header[0]] = header[1]
9097
body = b"".join(lines.popleft()) if lines else b""
9198

92-
if known_frame_type := COMMANDS_TO_FRAMES.get(command):
93-
return known_frame_type(headers=cast(Any, headers), body=body)
94-
return None
99+
return COMMANDS_TO_FRAMES[command](headers=cast(Any, headers), body=body)
95100

96101

97102
@dataclass
@@ -101,29 +106,33 @@ class Parser:
101106
_previous_byte: bytes = field(default=b"", init=False)
102107
_headers_processed: bool = field(default=False, init=False)
103108

109+
def _reset(self) -> None:
110+
self._headers_processed = False
111+
self._lines.clear()
112+
self._current_line = []
113+
104114
def load_frames(self, raw_frames: bytes) -> Iterator[AnyClientFrame | AnyServerFrame | HeartbeatFrame]:
105-
buffer = deque(struct.unpack(f"{len(raw_frames)!s}c", raw_frames))
115+
buffer = deque(iter_bytes(raw_frames))
106116
while buffer:
107117
byte = buffer.popleft()
108118

109-
if self._headers_processed and byte == NULL:
110-
self._lines.append(self._current_line)
111-
if parsed_frame := parse_lines_into_frame(self._lines):
112-
yield parsed_frame
113-
self._headers_processed = False
114-
self._lines.clear()
115-
self._current_line = []
119+
if byte == NULL:
120+
if self._headers_processed:
121+
self._lines.append(self._current_line)
122+
yield parse_lines_into_frame(self._lines)
123+
self._reset()
116124

117125
elif not self._headers_processed and byte == NEWLINE:
118126
if self._current_line or self._lines:
119-
if not self._current_line: # extra empty line after headers
120-
self._headers_processed = True
121-
122127
if self._previous_byte == b"\r":
123128
self._current_line.pop()
129+
self._headers_processed = not self._current_line # extra empty line after headers
124130

125-
self._lines.append(self._current_line)
126-
self._current_line = []
131+
if not self._lines and self._current_line not in VALID_COMMANDS:
132+
self._reset()
133+
else:
134+
self._lines.append(self._current_line)
135+
self._current_line = []
127136
else:
128137
yield HeartbeatFrame()
129138

tests/test_protocol.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,19 +187,16 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None:
187187
HeartbeatFrame(),
188188
],
189189
),
190-
(
191-
b"\n",
192-
[HeartbeatFrame()],
193-
),
190+
(b"\n", [HeartbeatFrame()]),
194191
# Two headers: only first should be accepted
195192
(
196193
b"CONNECTED\naccept-version:1.0\naccept-version:1.1\n\n\x00",
197194
[ConnectedFrame(headers={"accept-version": "1.0"})],
198195
),
199196
# no end of line after command
200-
(b"SOME_COMMAND", []),
201-
(b"SOME_COMMAND\n", []),
202-
(b"SOME_COMMAND\x00", []),
197+
(b"CONNECTED", []),
198+
(b"CONNECTED\n", []),
199+
(b"CONNECTED\x00", []),
203200
# \r\n after command
204201
(b"CONNECTED\r\n\n\n\x00", [ConnectedFrame(headers={}, body=b"\n")]),
205202
(b"CONNECTED\r\nheader:1.0\n\n\x00", [ConnectedFrame(headers={"header": "1.0"})]),
@@ -213,14 +210,24 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None:
213210
# header value with :
214211
(b"CONNECTED\nheader:what:?\n\n\x00", [ConnectedFrame(headers={})]),
215212
# no NULL
216-
(b"SOME_COMMAND\nheader:what:?\n\nhello", []),
213+
(b"CONNECTED\nheader:what:?\n\nhello", []),
217214
# header never end
218-
(b"SOME_COMMAND\nheader:hello", []),
219-
(b"SOME_COMMAND\nheader:hello\n", []),
220-
(b"SOME_COMMAND\nheader:hello\n\x00", []),
221-
(b"SOME_COMMAND\nn", []),
215+
(b"CONNECTED\nheader:hello", []),
216+
(b"CONNECTED\nheader:hello\n", []),
217+
(b"CONNECTED\nheader:hello\n\x00", []),
218+
(b"CONNECTED\nn", []),
219+
# unknown command
220+
(b"SOME_COMMAND\nhead:\nheader:1.1\n\n\x00", [HeartbeatFrame()]),
222221
# unknown command
223-
(b"SOME_COMMAND\nhead:\nheader:1.1\n\n\x00", []),
222+
(
223+
b"whatever\nWHATEVER\nheader:1.1\n\n\x00CONNECTED\nheader:1.1\n\n\x00\nwhatever\nCONNECTED\nheader:1.2\n\n\x00",
224+
[
225+
HeartbeatFrame(),
226+
ConnectedFrame(headers={"header": "1.1"}, body=b""),
227+
HeartbeatFrame(),
228+
ConnectedFrame(headers={"header": "1.2"}, body=b""),
229+
],
230+
),
224231
],
225232
)
226233
def test_load_frames(raw_frames: bytes, loaded_frames: list[AnyServerFrame]) -> None:

0 commit comments

Comments
 (0)