diff --git a/packages/stompman/stompman/serde.py b/packages/stompman/stompman/serde.py index fbfaa43..4491a05 100644 --- a/packages/stompman/stompman/serde.py +++ b/packages/stompman/stompman/serde.py @@ -1,8 +1,7 @@ import struct -from collections import deque from collections.abc import Iterator from contextlib import suppress -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Final, cast from stompman.frames import ( @@ -141,53 +140,83 @@ def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: byte return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg] -def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame: - command = bytes(lines.popleft()) - headers = {} - - while line := lines.popleft(): - header = parse_header(line) - if header and header[0] not in headers: - headers[header[0]] = header[1] - body = bytes(lines.popleft()) if lines else b"" - return make_frame_from_parts(command=command, headers=headers, body=body) - - -@dataclass(kw_only=True, slots=True) +@dataclass(kw_only=True, slots=True, init=False) class FrameParser: - _lines: deque[bytearray] = field(default_factory=deque, init=False) - _current_line: bytearray = field(default_factory=bytearray, init=False) - _previous_byte: bytes = field(default=b"", init=False) - _headers_processed: bool = field(default=False, init=False) + _current_buf: bytearray + _previous_byte: bytes | None + _headers_processed: bool + _command: bytes | None + _headers: dict[str, str] + _content_length: int | None + + def __init__(self) -> None: + self._previous_byte = None + self._reset() def _reset(self) -> None: + self._current_buf = bytearray() self._headers_processed = False - self._lines.clear() - self._current_line = bytearray() + self._command = None + self._headers = {} + self._content_length = None + + def _handle_null_byte(self) -> Iterator[AnyClientFrame | AnyServerFrame]: + if not self._command or not self._headers_processed: + self._reset() + return + if self._content_length is not None and self._content_length != len(self._current_buf): + self._current_buf += NULL + return + yield make_frame_from_parts(command=self._command, headers=self._headers, body=bytes(self._current_buf)) + self._reset() + + def _handle_newline_byte(self) -> Iterator[HeartbeatFrame]: + if not self._current_buf and not self._command: + yield HeartbeatFrame() + return + if self._previous_byte == CARRIAGE: + self._current_buf.pop() + self._headers_processed = not self._current_buf # extra empty line after headers + + if self._command: + self._process_header() + else: + self._process_command() + + def _process_command(self) -> None: + current_buf_bytes = bytes(self._current_buf) + if current_buf_bytes not in COMMANDS_TO_FRAMES: + self._reset() + else: + self._command = current_buf_bytes + self._current_buf = bytearray() + + def _process_header(self) -> None: + header = parse_header(self._current_buf) + if not header: + self._current_buf = bytearray() + return + header_key, header_value = header + if header_key not in self._headers: + self._headers[header_key] = header_value + if header_key.lower() == "content-length": + with suppress(ValueError): + self._content_length = int(header_value) + self._current_buf = bytearray() + + def _handle_body_byte(self, byte: bytes) -> None: + if self._content_length is None or self._content_length != len(self._current_buf): + self._current_buf += byte def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame]: for byte in iter_bytes(chunk): if byte == NULL: - if self._headers_processed: - self._lines.append(self._current_line) - yield parse_lines_into_frame(self._lines) - self._reset() - - elif not self._headers_processed and byte == NEWLINE: - if self._current_line or self._lines: - if self._previous_byte == CARRIAGE: - self._current_line.pop() - self._headers_processed = not self._current_line # extra empty line after headers - - if not self._lines and bytes(self._current_line) not in COMMANDS_TO_FRAMES: - self._reset() - else: - self._lines.append(self._current_line) - self._current_line = bytearray() - else: - yield HeartbeatFrame() - + yield from self._handle_null_byte() + elif self._headers_processed: + self._handle_body_byte(byte) + elif byte == NEWLINE: + yield from self._handle_newline_byte() else: - self._current_line += byte + self._current_buf += byte self._previous_byte = byte diff --git a/packages/stompman/test_stompman/test_frame_serde.py b/packages/stompman/test_stompman/test_frame_serde.py index b9a8652..392d993 100644 --- a/packages/stompman/test_stompman/test_frame_serde.py +++ b/packages/stompman/test_stompman/test_frame_serde.py @@ -222,6 +222,26 @@ def test_dump_frame(frame: AnyClientFrame, dumped_frame: bytes) -> None: ConnectedFrame(headers={"header": "1.2"}), ], ), + # Correct content-length with body containing NULL byte + ( + b"MESSAGE\ncontent-length:5\n\nBod\x00y\x00", + [MessageFrame(headers={"content-length": "5"}, body=b"Bod\x00y")], + ), + # Content-length shorter than actual body (should only read up to content-length) + ( + b"MESSAGE\ncontent-length:4\n\nBody\x00 with extra\x00\n", + [MessageFrame(headers={"content-length": "4"}, body=b"Body"), HeartbeatFrame()], + ), + # Content-length longer than actual body (should wait for more data) + ( + b"MESSAGE\ncontent-length:10\n\nShort", + [], + ), + # Content-length longer than actual body, then more data comes with NULL terminator + ( + b"MESSAGE\ncontent-length:10\n\nShortMOREDATA\x00", + [MessageFrame(headers={"content-length": "10"}, body=b"ShortMORED")], + ), ], ) def test_load_frames(raw_frames: bytes, loaded_frames: list[AnyServerFrame]) -> None: