Skip to content

Commit 6df3de3

Browse files
authored
Refactor ser-de (#32)
1 parent a1c8ed9 commit 6df3de3

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

stompman/serde.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,19 @@
2929
NEWLINE: Final = b"\n"
3030
CARRIAGE: Final = b"\r"
3131
NULL: Final = b"\x00"
32+
BACKSLASH = b"\\"
33+
COLON_ = b":"
34+
3235
HEADER_ESCAPE_CHARS: Final = {
33-
"\n": "\\n",
34-
":": "\\c",
35-
"\\": "\\\\",
36-
"\r": "", # [\r]\n is newline, therefore can't be used in header
36+
NEWLINE.decode(): "\\n",
37+
COLON_.decode(): "\\c",
38+
BACKSLASH.decode(): "\\\\",
39+
CARRIAGE.decode(): "", # [\r]\n is newline, therefore can't be used in header
3740
}
3841
HEADER_UNESCAPE_CHARS: Final = {
3942
b"n": NEWLINE,
40-
b"c": b":",
41-
b"\\": b"\\",
43+
b"c": COLON_,
44+
BACKSLASH: BACKSLASH,
4245
}
4346

4447

@@ -89,39 +92,39 @@ def dump_frame(frame: AnyClientFrame | AnyServerFrame) -> bytes:
8992

9093

9194
def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None:
92-
if previous_byte == b"\\":
95+
if previous_byte == BACKSLASH:
9396
return HEADER_UNESCAPE_CHARS.get(byte)
94-
if byte == b"\\":
97+
if byte == BACKSLASH:
9598
return None
9699
return byte
97100

98101

99-
def parse_header(buffer: list[bytes]) -> tuple[str, str] | None:
100-
key_buffer: list[bytes] = []
102+
def parse_header(buffer: bytearray) -> tuple[str, str] | None:
103+
key_buffer = bytearray()
104+
value_buffer = bytearray()
101105
key_parsed = False
102-
value_buffer: list[bytes] = []
103106

104107
previous_byte = None
105108
just_escaped_line = False
106109

107-
for byte in buffer:
108-
if byte == b":":
110+
for byte in iter_bytes(buffer):
111+
if byte == COLON_:
109112
if key_parsed:
110113
return None
111114
key_parsed = True
112115
elif just_escaped_line:
113116
just_escaped_line = False
114-
if byte != b"\\":
115-
(value_buffer if key_parsed else key_buffer).append(byte)
117+
if byte != BACKSLASH:
118+
(value_buffer if key_parsed else key_buffer).extend(byte)
116119
elif unescaped_byte := unescape_byte(byte, previous_byte):
117120
just_escaped_line = True
118-
(value_buffer if key_parsed else key_buffer).append(unescaped_byte)
121+
(value_buffer if key_parsed else key_buffer).extend(unescaped_byte)
119122

120123
previous_byte = byte
121124

122125
if key_parsed:
123126
with suppress(UnicodeDecodeError):
124-
return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode())
127+
return key_buffer.decode(), value_buffer.decode()
125128

126129
return None
127130

@@ -135,29 +138,29 @@ def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes)
135138
)
136139

137140

138-
def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame:
139-
command = b"".join(lines.popleft())
141+
def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:
142+
command = bytes(lines.popleft())
140143
headers = {}
141144

142145
while line := lines.popleft():
143146
header = parse_header(line)
144147
if header and header[0] not in headers:
145148
headers[header[0]] = header[1]
146-
body = b"".join(lines.popleft()) if lines else b""
149+
body = lines.popleft() if lines else b""
147150
return make_frame_from_parts(command=command, headers=headers, body=body)
148151

149152

150153
@dataclass
151154
class FrameParser:
152-
_lines: deque[list[bytes]] = field(default_factory=deque, init=False)
153-
_current_line: list[bytes] = field(default_factory=list, init=False)
155+
_lines: deque[bytearray] = field(default_factory=deque, init=False)
156+
_current_line: bytearray = field(default_factory=bytearray, init=False)
154157
_previous_byte: bytes = field(default=b"", init=False)
155158
_headers_processed: bool = field(default=False, init=False)
156159

157160
def _reset(self) -> None:
158161
self._headers_processed = False
159162
self._lines.clear()
160-
self._current_line = []
163+
self._current_line = bytearray()
161164

162165
def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | AnyServerFrame | HeartbeatFrame]:
163166
for byte in iter_bytes(chunk):
@@ -173,15 +176,15 @@ def parse_frames_from_chunk(self, chunk: bytes) -> Iterator[AnyClientFrame | Any
173176
self._current_line.pop()
174177
self._headers_processed = not self._current_line # extra empty line after headers
175178

176-
if not self._lines and self._current_line not in COMMANDS_BYTES_LISTS:
179+
if not self._lines and bytes(self._current_line) not in COMMANDS_TO_FRAMES:
177180
self._reset()
178181
else:
179182
self._lines.append(self._current_line)
180-
self._current_line = []
183+
self._current_line = bytearray()
181184
else:
182185
yield HeartbeatFrame()
183186

184187
else:
185-
self._current_line.append(byte)
188+
self._current_line += byte
186189

187190
self._previous_byte = byte

tests/integration.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
FrameParser,
1818
dump_frame,
1919
dump_header,
20-
iter_bytes,
2120
make_frame_from_parts,
2221
parse_header,
2322
)
@@ -157,7 +156,7 @@ def bytes_not_contains(*avoided: bytes) -> Callable[[bytes], bool]:
157156
lambda headers: dict(
158157
parsed_header
159158
for header in starmap(dump_header, headers.items())
160-
if (parsed_header := parse_header(list(iter_bytes(header))))
159+
if (parsed_header := parse_header(bytearray(header)))
161160
)
162161
)
163162
frame_strategy = strategies.just(HeartbeatFrame()) | strategies.builds(

0 commit comments

Comments
 (0)