From c557bd9ae88eacac764349bbb3c16c08552d2bf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 1 Sep 2023 10:40:35 +0000 Subject: [PATCH 01/14] Add resp encoder class to tests --- tests/resp.py | 93 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_resp.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 tests/resp.py create mode 100644 tests/test_resp.py diff --git a/tests/resp.py b/tests/resp.py new file mode 100644 index 0000000000..d28cd7a6fe --- /dev/null +++ b/tests/resp.py @@ -0,0 +1,93 @@ +import itertools +from types import NoneType +from typing import Any, Optional + + +class RespEncoder: + """ + A class for simple RESP protocol encodign for unit tests + """ + + def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None: + self.protocol = protocol + self.encoding = encoding + + def encode(self, data: Any, hint: Optional[str] = None) -> bytes: + if isinstance(data, dict): + if self.protocol > 2: + result = f"%{len(data)}\r\n".encode() + for key, val in data.items(): + result += self.encode(key) + self.encode(val) + return result + else: + # Automatically encode dicts as flattened key, value arrays + mylist = list( + itertools.chain(*((key, val) for (key, val) in data.items())) + ) + return self.encode(mylist) + + elif isinstance(data, list): + result = f"*{len(data)}\r\n".encode() + for val in data: + result += self.encode(val) + return result + + elif isinstance(data, set): + if self.protocol > 2: + result = f"~{len(data)}\r\n".encode() + for val in data: + result += self.encode(val) + return result + else: + return self.encode(list(data)) + + elif isinstance(data, str): + enc = data.encode(self.encoding) + # long strings or strings with control characters must be encoded as bulk + # strings + if hint or len(enc) > 20 or b"\r" in enc or b"\n" in enc: + return self.encode_bulkstr(enc, hint) + return b"+" + enc + b"\r\n" + + elif 
isinstance(data, bytes): + return self.encode_bulkstr(data, hint) + + elif isinstance(data, bool): + if self.protocol == 2: + return b":1\r\n" if data else b":0\r\n" + else: + return b"t\r\n" if data else b"f\r\n" + + elif isinstance(data, int): + if (data > 2**63 - 1) or (data < -(2**63)): + if self.protocol > 2: + return f"({data}\r\n".encode() # resp3 big int + else: + return f"+{data}\r\n".encode() # force to simple string + return f":{data}\r\n".encode() + elif isinstance(data, float): + if self.protocol > 2: + return f",{data}\r\n".encode() # resp3 double + else: + return f"+{data}\r\n".encode() # simple string + + elif isinstance(data, NoneType): + if self.protocol > 2: + return b"_\r\n" # resp3 null + else: + return b"$-1\r\n" # Null bulk string + # some commands return null array: b"*-1\r\n" + + else: + raise NotImplementedError + + def encode_bulkstr(self, bstr: bytes, hint: Optional[str]) -> bytes: + if self.protocol > 2 and hint is not None: + # a resp3 verbatim string + return f"={len(bstr)}\r\n{hint}:".encode() + bstr + b"\r\n" + else: + return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n" + + +def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes: + return RespEncoder(protocol).encode(value, hint) diff --git a/tests/test_resp.py b/tests/test_resp.py new file mode 100644 index 0000000000..3aa5028c53 --- /dev/null +++ b/tests/test_resp.py @@ -0,0 +1,97 @@ +from .resp import encode + +import pytest + + +@pytest.fixture(params=[2, 3]) +def resp_version(request): + return request.param + + +class TestEncoder: + def test_simple_str(self): + assert encode("foo") == b"+foo\r\n" + + def test_long_str(self): + text = "fooling around with the sword in the mud" + assert len(text) == 40 + assert encode(text) == b"$40\r\n" + text.encode() + b"\r\n" + + # test strings with control characters + def test_str_with_ctrl_chars(self): + text = "foo\r\nbar" + assert encode(text) == b"$8\r\nfoo\r\nbar\r\n" + + def test_bytes(self): + assert 
encode(b"foo") == b"$3\r\nfoo\r\n" + + def test_int(self): + assert encode(123) == b":123\r\n" + + def test_float(self, resp_version): + data = encode(1.23, protocol=resp_version) + if resp_version == 2: + assert data == b"+1.23\r\n" + else: + assert data == b",1.23\r\n" + + def test_large_int(self, resp_version): + data = encode(2**63, protocol=resp_version) + if resp_version == 2: + assert data == b"+9223372036854775808\r\n" + else: + assert data == b"(9223372036854775808\r\n" + + def test_array(self): + assert encode([1, 2, 3]) == b"*3\r\n:1\r\n:2\r\n:3\r\n" + + def test_set(self, resp_version): + data = encode({1, 2, 3}, protocol=resp_version) + if resp_version == 2: + assert data == b"*3\r\n:1\r\n:2\r\n:3\r\n" + else: + assert data == b"~3\r\n:1\r\n:2\r\n:3\r\n" + + def test_map(self, resp_version): + data = encode({1: 2, 3: 4}, protocol=resp_version) + if resp_version == 2: + assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + else: + assert data == b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + + def test_nested_array(self): + assert encode([1, [2, 3]]) == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n" + + def test_nested_map(self, resp_version): + data = encode({1: {2: 3}}, protocol=resp_version) + if resp_version == 2: + assert data == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n" + else: + assert data == b"%1\r\n:1\r\n%1\r\n:2\r\n:3\r\n" + + def test_null(self, resp_version): + data = encode(None, protocol=resp_version) + if resp_version == 2: + assert data == b"$-1\r\n" + else: + assert data == b"_\r\n" + + def test_mixed_array(self, resp_version): + data = encode([1, "foo", 2.3, None, True], protocol=resp_version) + if resp_version == 2: + assert data == b"*5\r\n:1\r\n+foo\r\n+2.3\r\n$-1\r\n:1\r\n" + else: + assert data == b"*5\r\n:1\r\n+foo\r\n,2.3\r\n_\r\nt\r\n" + + def test_bool(self, resp_version): + data = encode(True, protocol=resp_version) + if resp_version == 2: + assert data == b":1\r\n" + else: + assert data == b"t\r\n" + + data = encode(False, resp_version) + if resp_version == 
2: + assert data == b":0\r\n" + else: + assert data == b"f\r\n" From 9f4e05f02ff73bf5136b0e1bce98e62fbb3707b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Fri, 1 Sep 2023 17:00:27 +0000 Subject: [PATCH 02/14] Add resp parser, tests --- tests/resp.py | 264 ++++++++++++++++++++++++++++++++++++++++++--- tests/test_resp.py | 103 +++++++++++++++++- 2 files changed, 350 insertions(+), 17 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index d28cd7a6fe..57b792ea62 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -1,11 +1,38 @@ import itertools +from contextlib import closing from types import NoneType -from typing import Any, Optional +from typing import Any, Generator, List, Optional, Tuple, Union + +CRNL = b"\r\n" + + +class VerbatimString(bytes): + """ + A string that is encoded as a resp3 verbatim string + """ + + def __new__(cls, value: bytes, hint: str) -> "VerbatimString": + return bytes.__new__(cls, value) + + def __init__(self, value: bytes, hint: str) -> None: + self.hint = hint + + def __repr__(self) -> str: + return f"VerbatimString({super().__repr__()}, {self.hint!r})" + + +class PushData(list): + """ + A special type of list indicating data from a push response + """ + + def __repr__(self) -> str: + return f"PushData({super().__repr__()})" class RespEncoder: """ - A class for simple RESP protocol encodign for unit tests + A class for simple RESP protocol encoding for unit tests """ def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None: @@ -27,7 +54,10 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes: return self.encode(mylist) elif isinstance(data, list): - result = f"*{len(data)}\r\n".encode() + if isinstance(data, PushData) and self.protocol > 2: + result = f">{len(data)}\r\n".encode() + else: + result = f"*{len(data)}\r\n".encode() for val in data: result += self.encode(val) return result @@ -55,39 +85,243 @@ def encode(self, data: Any, hint: Optional[str] = None) -> 
bytes: elif isinstance(data, bool): if self.protocol == 2: return b":1\r\n" if data else b":0\r\n" - else: - return b"t\r\n" if data else b"f\r\n" + return b"t\r\n" if data else b"f\r\n" elif isinstance(data, int): if (data > 2**63 - 1) or (data < -(2**63)): if self.protocol > 2: return f"({data}\r\n".encode() # resp3 big int - else: - return f"+{data}\r\n".encode() # force to simple string + return f"+{data}\r\n".encode() # force to simple string return f":{data}\r\n".encode() elif isinstance(data, float): if self.protocol > 2: return f",{data}\r\n".encode() # resp3 double - else: - return f"+{data}\r\n".encode() # simple string + return f"+{data}\r\n".encode() # simple string elif isinstance(data, NoneType): if self.protocol > 2: return b"_\r\n" # resp3 null - else: - return b"$-1\r\n" # Null bulk string - # some commands return null array: b"*-1\r\n" + return b"$-1\r\n" # Null bulk string + # some commands return null array: b"*-1\r\n" else: - raise NotImplementedError + raise NotImplementedError(f"encode not implemented for {type(data)}") def encode_bulkstr(self, bstr: bytes, hint: Optional[str]) -> bytes: if self.protocol > 2 and hint is not None: # a resp3 verbatim string return f"={len(bstr)}\r\n{hint}:".encode() + bstr + b"\r\n" - else: - return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n" + # regular bulk string + return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n" def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes: + """ + Encode a value using the RESP protocol + """ return RespEncoder(protocol).encode(value, hint) + + +# a stateful RESP parser implemented via a generator +def resp_parse( + buffer: bytes, +) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]: + """ + A stateful, generator based, RESP parser. + Returns a generator producing at most a single top-level primitive. + Yields tuple of (data_item, unparsed), or None if more data is needed. 
+ It is fed more data with generator.send() + """ + # Read the first line of resp or yield to get more data + while CRNL not in buffer: + incoming = yield None + assert incoming is not None + buffer += incoming + cmd, rest = buffer.split(CRNL, 1) + + code, arg = cmd[:1], cmd[1:] + + if code == b":" or code == b"(": # integer, resp3 large int + yield int(arg), rest + + elif code == b"t": # resp3 true + yield True, rest + + elif code == b"f": # resp3 false + yield False, rest + + elif code == b"_": # resp3 null + yield None, rest + + elif code == b",": # resp3 double + yield float(arg), rest + + elif code == b"+": # simple string + # we decode them automatically + yield arg.decode(), rest + + elif code == b"$": # bulk string + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + # bulk strings are not decoded, could contain binary data + yield bulkstr, rest[expect:] + + elif code == b"=": # verbatim strings + count = int(arg) + expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + hint = rest[:3] + result = rest[4 : (count + 4)] + # verbatim strings are not decoded, could contain binary data + yield VerbatimString(result, hint.decode()), rest[expect:] + + elif code in b"*>": # array or push data + count = int(arg) + result_array = [] + for _ in range(count): + # recursively parse the next array item + with closing(resp_parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_array.append(value) + if code == b">": + yield PushData(result_array), rest + else: + yield result_array, rest + + elif code == b"~": # set + count = int(arg) + result_set = set() + for _ in range(count): + # recursively parse the 
next set item + with closing(resp_parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_set.add(value) + yield result_set, rest + + elif code == b"%": # map + count = int(arg) + result_map = {} + for _ in range(count): + # recursively parse the next key, and value + with closing(resp_parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + key, rest = parsed + with closing(resp_parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_map[key] = value + yield result_map, rest + else: + if code in b"-!": + raise NotImplementedError(f"resp opcode '{code.decode()}' not implemented") + raise ValueError(f"Unknown opcode '{code.decode()}'") + + +class NeedMoreData(RuntimeError): + """ + Raised when more data is needed to complete a parse + """ + + +class RespParser: + """ + A class for simple RESP protocol decoding for unit tests + """ + + def __init__(self) -> None: + self.parser: Optional[ + Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None] + ] = None + # which has not resulted in a parsed value + self.consumed: List[bytes] = [] + + def parse(self, buffer: bytes) -> Optional[Any]: + """ + Parse a buffer of data, return a tuple of a single top-level primitive and the + remaining buffer or raise NeedMoreData if more data is needed + """ + if self.parser is None: + # create a new parser generator, initializing it with + # any unparsed data from previous calls + buffer = b"".join(self.consumed) + buffer + del self.consumed[:] + self.parser = resp_parse(buffer) + parsed = self.parser.send(None) + else: + # sen more data to the parser + parsed = self.parser.send(buffer) + + if parsed is None: + self.consumed.append(buffer) + raise NeedMoreData() + + # got a value, 
close the parser, store the remaining buffer + self.parser.close() + self.parser = None + value, remaining = parsed + self.consumed = [remaining] + return value + + def get_unparsed(self) -> bytes: + return b"".join(self.consumed) + + def close(self) -> None: + if self.parser is not None: + self.parser.close() + self.parser = None + del self.consumed[:] + + +def parse_all(buffer: bytes) -> Tuple[List[Any], bytes]: + """ + Parse all the data in the buffer, returning the list of top-level objects and the + remaining buffer + """ + with closing(RespParser()) as parser: + result: List[Any] = [] + while True: + try: + result.append(parser.parse(buffer)) + buffer = b"" + except NeedMoreData: + return result, parser.get_unparsed() + + +def parse_chunks(buffers: List[bytes]) -> Tuple[List[Any], bytes]: + """ + Parse all the data in the buffers, returning the list of top-level objects and the + remaining buffer. + Used primarily for testing, since it will parse the data in chunks + """ + result: List[Any] = [] + with closing(RespParser()) as parser: + for buffer in buffers: + while True: + try: + result.append(parser.parse(buffer)) + buffer = b"" + except NeedMoreData: + break + return result, parser.get_unparsed() diff --git a/tests/test_resp.py b/tests/test_resp.py index 3aa5028c53..5b28dae961 100644 --- a/tests/test_resp.py +++ b/tests/test_resp.py @@ -1,7 +1,7 @@ -from .resp import encode - import pytest +from .resp import PushData, VerbatimString, encode, parse_all, parse_chunks + @pytest.fixture(params=[2, 3]) def resp_version(request): @@ -45,6 +45,13 @@ def test_large_int(self, resp_version): def test_array(self): assert encode([1, 2, 3]) == b"*3\r\n:1\r\n:2\r\n:3\r\n" + def test_push_data(self, resp_version): + data = encode(PushData([1, 2, 3]), protocol=resp_version) + if resp_version == 2: + assert data == b"*3\r\n:1\r\n:2\r\n:3\r\n" + else: + assert data == b">3\r\n:1\r\n:2\r\n:3\r\n" + def test_set(self, resp_version): data = encode({1, 2, 3}, 
protocol=resp_version) if resp_version == 2: @@ -95,3 +102,95 @@ def test_bool(self, resp_version): assert data == b":0\r\n" else: assert data == b"f\r\n" + + +@pytest.mark.parametrize("chunk_size", [0, 1, 2, -2]) +class TestParser: + def breakup_bytes(self, data, chunk_size=2): + insert_empty = False + if chunk_size < 0: + insert_empty = True + chunk_size = -chunk_size + chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] + if insert_empty: + empty = len(chunks) * [b""] + chunks = [item for pair in zip(chunks, empty) for item in pair] + return chunks + + def parse_data(self, chunk_size, data): + """helper to parse either a single blob, or a list of chunks""" + if chunk_size == 0: + return parse_all(data) + else: + return parse_chunks(self.breakup_bytes(data, chunk_size)) + + def test_int(self, chunk_size): + parsed = self.parse_data(chunk_size, b":123\r\n") + assert parsed == ([123], b"") + + parsed = self.parse_data(chunk_size, b":123\r\nfoo") + assert parsed == ([123], b"foo") + + def test_double(self, chunk_size): + parsed = self.parse_data(chunk_size, b",1.23\r\njunk") + assert parsed == ([1.23], b"junk") + + def test_array(self, chunk_size): + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n:3\r\n") + assert parsed == ([[1, 2, 3]], b"") + + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n:3\r\nfoo") + assert parsed == ([[1, 2, 3]], b"foo") + + def test_push_data(self, chunk_size): + parsed = self.parse_data(chunk_size, b">3\r\n:1\r\n:2\r\n:3\r\n") + assert isinstance(parsed[0][0], PushData) + assert parsed == ([[1, 2, 3]], b"") + + def test_incomplete_list(self, chunk_size): + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n") + assert parsed == ([], b"*3\r\n:1\r\n:2\r\n") + + def test_invalid_token(self, chunk_size): + with pytest.raises(ValueError): + self.parse_data(chunk_size, b")foo\r\n") + with pytest.raises(NotImplementedError): + self.parse_data(chunk_size, b"!foo\r\n") + + def 
test_multiple_ints(self, chunk_size): + parsed = self.parse_data(chunk_size, b":1\r\n:2\r\n:3\r\n") + assert parsed == ([1, 2, 3], b"") + + def test_multiple_ints_and_junk(self, chunk_size): + parsed = self.parse_data(chunk_size, b":1\r\n:2\r\n:3\r\n*3\r\n:1\r\n:2\r\n") + assert parsed == ([1, 2, 3], b"*3\r\n:1\r\n:2\r\n") + + def test_set(self, chunk_size): + parsed = self.parse_data(chunk_size, b"~3\r\n:1\r\n:2\r\n:3\r\n") + assert parsed == ([{1, 2, 3}], b"") + + def test_list_of_sets(self, chunk_size): + parsed = self.parse_data( + chunk_size, b"*2\r\n~3\r\n:1\r\n:2\r\n:3\r\n~2\r\n:4\r\n:5\r\n" + ) + assert parsed == ([[{1, 2, 3}, {4, 5}]], b"") + + def test_map(self, chunk_size): + parsed = self.parse_data(chunk_size, b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n") + assert parsed == ([{1: 2, 3: 4}], b"") + + def test_simple_string(self, chunk_size): + parsed = self.parse_data(chunk_size, b"+foo\r\n") + assert parsed == (["foo"], b"") + + def test_bulk_string(self, chunk_size): + parsed = parse_all(b"$3\r\nfoo\r\nbar") + assert parsed == ([b"foo"], b"bar") + + def test_bulk_string_with_ctrl_chars(self, chunk_size): + parsed = self.parse_data(chunk_size, b"$8\r\nfoo\r\nbar\r\n") + assert parsed == ([b"foo\r\nbar"], b"") + + def test_verbatim_string(self, chunk_size): + parsed = self.parse_data(chunk_size, b"=3\r\ntxt:foo\r\nbar") + assert parsed == ([VerbatimString(b"foo", "txt")], b"bar") From ee762576f1657f60d4b5f92038a077568f0ca95d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 10 Sep 2023 11:51:38 +0000 Subject: [PATCH 03/14] Add errors, use strings --- tests/resp.py | 113 ++++++++++++++++++++++++++++++++++++--------- tests/test_resp.py | 59 +++++++++++++++++++---- 2 files changed, 140 insertions(+), 32 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 57b792ea62..65330dcaee 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -6,19 +6,37 @@ CRNL = b"\r\n" -class VerbatimString(bytes): +class VerbatimStr(str): """ A 
string that is encoded as a resp3 verbatim string """ - def __new__(cls, value: bytes, hint: str) -> "VerbatimString": - return bytes.__new__(cls, value) + def __new__(cls, value: str, hint: str) -> "VerbatimStr": + return str.__new__(cls, value) - def __init__(self, value: bytes, hint: str) -> None: + def __init__(self, value: str, hint: str) -> None: self.hint = hint def __repr__(self) -> str: - return f"VerbatimString({super().__repr__()}, {self.hint!r})" + return f"VerbatimStr({super().__repr__()}, {self.hint!r})" + + +class ErrorStr(str): + """ + A string to be encoded as a resp3 error + """ + + def __new__(cls, code: str, value: str) -> "ErrorStr": + return str.__new__(cls, value) + + def __init__(self, code: str, value: str) -> None: + self.code = code.upper() + + def __repr__(self) -> str: + return f"ErrorString({self.code!r}, {super().__repr__()})" + + def __str__(self): + return f"{self.code} {super().__str__()}" class PushData(list): @@ -30,19 +48,43 @@ def __repr__(self) -> str: return f"PushData({super().__repr__()})" +class Attribute(dict): + """ + A special type of map indicating data from a attribute response + """ + + def __repr__(self) -> str: + return f"Attribute({super().__repr__()})" + + class RespEncoder: """ - A class for simple RESP protocol encoding for unit tests + A class for simple RESP protocol encoder for unit tests """ - def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None: + def __init__( + self, protocol: int = 2, encoding: str = "utf-8", errorhander="strict" + ) -> None: self.protocol = protocol self.encoding = encoding + self.errorhandler = errorhander + + def apply_encoding(self, value: str) -> bytes: + return value.encode(self.encoding, errors=self.errorhandler) + + def has_crnl(self, value: bytes) -> bool: + """check if either cr or nl is in the value""" + return b"\r" in value or b"\n" in value + + def escape_crln(self, value: bytes) -> bytes: + """remove any cr or nl from the value""" + return 
value.replace(b"\r", b"\\r").replace(b"\n", b"\\n") def encode(self, data: Any, hint: Optional[str] = None) -> bytes: if isinstance(data, dict): if self.protocol > 2: - result = f"%{len(data)}\r\n".encode() + code = "|" if isinstance(data, Attribute) else "%" + result = f"{code}{len(data)}\r\n".encode() for key, val in data.items(): result += self.encode(key) + self.encode(val) return result @@ -54,10 +96,8 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes: return self.encode(mylist) elif isinstance(data, list): - if isinstance(data, PushData) and self.protocol > 2: - result = f">{len(data)}\r\n".encode() - else: - result = f"*{len(data)}\r\n".encode() + code = ">" if isinstance(data, PushData) and self.protocol > 2 else "*" + result = f"{code}{len(data)}\r\n".encode() for val in data: result += self.encode(val) return result @@ -71,11 +111,18 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes: else: return self.encode(list(data)) + elif isinstance(data, ErrorStr): + enc = self.apply_encoding(str(data)) + if self.protocol > 2: + if len(enc) > 80 or self.has_crnl(enc): + return f"!{len(enc)}\r\n".encode() + enc + b"\r\n" + return b"-" + self.escape_crln(enc) + b"\r\n" + elif isinstance(data, str): - enc = data.encode(self.encoding) + enc = self.apply_encoding(data) # long strings or strings with control characters must be encoded as bulk # strings - if hint or len(enc) > 20 or b"\r" in enc or b"\n" in enc: + if hint or len(enc) > 80 or self.has_crnl(enc): return self.encode_bulkstr(enc, hint) return b"+" + enc + b"\r\n" @@ -158,7 +205,7 @@ def resp_parse( elif code == b"+": # simple string # we decode them automatically - yield arg.decode(), rest + yield arg.decode(errors="surrogateescape"), rest elif code == b"$": # bulk string count = int(arg) @@ -168,8 +215,9 @@ def resp_parse( assert incoming is not None rest += incoming bulkstr = rest[:count] - # bulk strings are not decoded, could contain binary data - yield bulkstr, 
rest[expect:] + # we decode them automatically. Can be encoded + # back to binary if necessary with "surrogatescape" + yield bulkstr.decode(errors="surrogateescape"), rest[expect:] elif code == b"=": # verbatim strings count = int(arg) @@ -179,9 +227,9 @@ def resp_parse( assert incoming is not None rest += incoming hint = rest[:3] - result = rest[4 : (count + 4)] - # verbatim strings are not decoded, could contain binary data - yield VerbatimString(result, hint.decode()), rest[expect:] + result = rest[4: (count + 4)] + yield VerbatimStr(result.decode(errors="surrogateescape"), + hint.decode()), rest[expect:] elif code in b"*>": # array or push data count = int(arg) @@ -214,7 +262,7 @@ def resp_parse( result_set.add(value) yield result_set, rest - elif code == b"%": # map + elif code in b"%|": # map or attribute count = int(arg) result_map = {} for _ in range(count): @@ -232,10 +280,29 @@ def resp_parse( parsed = parser.send(incoming) value, rest = parsed result_map[key] = value + if code == b"|": + yield Attribute(result_map), rest yield result_map, rest + + elif code == b"-": # error + # we decode them automatically + decoded = arg.decode(errors="surrogateescape") + code, value = decoded.split(" ", 1) + yield ErrorStr(code, value), rest + + elif code == b"!": # resp3 error + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + decoded = bulkstr.decode(errors="surrogateescape") + code, value = decoded.split(" ", 1) + yield ErrorStr(code, value), rest[expect:] + else: - if code in b"-!": - raise NotImplementedError(f"resp opcode '{code.decode()}' not implemented") raise ValueError(f"Unknown opcode '{code.decode()}'") diff --git a/tests/test_resp.py b/tests/test_resp.py index 5b28dae961..4706699a4f 100644 --- a/tests/test_resp.py +++ b/tests/test_resp.py @@ -1,6 +1,14 @@ import pytest -from .resp import PushData, 
VerbatimString, encode, parse_all, parse_chunks +from .resp import ( + Attribute, + ErrorStr, + PushData, + VerbatimStr, + encode, + parse_all, + parse_chunks, +) @pytest.fixture(params=[2, 3]) @@ -13,9 +21,9 @@ def test_simple_str(self): assert encode("foo") == b"+foo\r\n" def test_long_str(self): - text = "fooling around with the sword in the mud" - assert len(text) == 40 - assert encode(text) == b"$40\r\n" + text.encode() + b"\r\n" + text = 3 * "fooling around with the sword in the mud" + assert len(text) == 120 + assert encode(text) == b"$120\r\n" + text.encode() + b"\r\n" # test strings with control characters def test_str_with_ctrl_chars(self): @@ -66,6 +74,13 @@ def test_map(self, resp_version): else: assert data == b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + def test_attribute(self, resp_version): + data = encode(Attribute({1: 2, 3: 4}), protocol=resp_version) + if resp_version == 2: + assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + else: + assert data == b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + def test_nested_array(self): assert encode([1, [2, 3]]) == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n" @@ -103,6 +118,14 @@ def test_bool(self, resp_version): else: assert data == b"f\r\n" + def test_errorstr(self, resp_version): + err = ErrorStr("foo", "bar\r\nbaz") + data = encode(err, protocol=resp_version) + if resp_version == 2: + assert data == b"-FOO bar\\r\\nbaz\r\n" + else: + assert data == b"!12\r\nFOO bar\r\nbaz\r\n" + @pytest.mark.parametrize("chunk_size", [0, 1, 2, -2]) class TestParser: @@ -154,7 +177,7 @@ def test_incomplete_list(self, chunk_size): def test_invalid_token(self, chunk_size): with pytest.raises(ValueError): self.parse_data(chunk_size, b")foo\r\n") - with pytest.raises(NotImplementedError): + with pytest.raises(ValueError): self.parse_data(chunk_size, b"!foo\r\n") def test_multiple_ints(self, chunk_size): @@ -185,12 +208,30 @@ def test_simple_string(self, chunk_size): def test_bulk_string(self, chunk_size): parsed = parse_all(b"$3\r\nfoo\r\nbar") - assert 
parsed == ([b"foo"], b"bar") + assert parsed == (["foo"], b"bar") def test_bulk_string_with_ctrl_chars(self, chunk_size): parsed = self.parse_data(chunk_size, b"$8\r\nfoo\r\nbar\r\n") - assert parsed == ([b"foo\r\nbar"], b"") + assert parsed == (["foo\r\nbar"], b"") - def test_verbatim_string(self, chunk_size): + def test_verbatimstr(self, chunk_size): parsed = self.parse_data(chunk_size, b"=3\r\ntxt:foo\r\nbar") - assert parsed == ([VerbatimString(b"foo", "txt")], b"bar") + assert parsed == ([VerbatimStr("foo", "txt")], b"bar") + + def test_errorstr(self, chunk_size): + parsed = self.parse_data(chunk_size, b"-FOO bar\r\nbaz") + assert parsed == ([ErrorStr("foo", "bar")], b"baz") + + def test_errorstr_resp3(self, chunk_size): + parsed = self.parse_data(chunk_size, b"!12\r\nFOO bar\r\nbaz\r\n") + assert parsed == ([ErrorStr("foo", "bar\r\nbaz")], b"") + + def test_attribute_map(self, chunk_size): + parsed = self.parse_data(chunk_size, b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n") + assert parsed == ([Attribute({1: 2, 3: 4})], b"") + + def test_surrogateescape(self, chunk_size): + data = b"foo\xff" + parsed = self.parse_data(chunk_size, b"$4\r\n" + data + b"\r\nbar") + assert parsed == ([data.decode(errors="surrogateescape")], b"bar") + assert parsed[0][0].encode("utf-8", "surrogateescape") == data From e2af798312999853f7aef512d4c6e7bc886b2cd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 16 Sep 2023 13:56:00 +0000 Subject: [PATCH 04/14] Add a class around the parser, to hold parsing rules. 
--- tests/resp.py | 304 +++++++++++++++++++++++++++----------------------- 1 file changed, 166 insertions(+), 138 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 65330dcaee..1aefbdddfb 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -169,141 +169,167 @@ def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes: return RespEncoder(protocol).encode(value, hint) -# a stateful RESP parser implemented via a generator -def resp_parse( - buffer: bytes, -) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]: +class RespGeneratorParser: """ - A stateful, generator based, RESP parser. - Returns a generator producing at most a single top-level primitive. - Yields tuple of (data_item, unparsed), or None if more data is needed. - It is fed more data with generator.send() + A wrapper class around a stateful RESP parsing generator, + allowing custom string decoding rules. """ - # Read the first line of resp or yield to get more data - while CRNL not in buffer: - incoming = yield None - assert incoming is not None - buffer += incoming - cmd, rest = buffer.split(CRNL, 1) - code, arg = cmd[:1], cmd[1:] - - if code == b":" or code == b"(": # integer, resp3 large int - yield int(arg), rest - - elif code == b"t": # resp3 true - yield True, rest - - elif code == b"f": # resp3 false - yield False, rest - - elif code == b"_": # resp3 null - yield None, rest - - elif code == b",": # resp3 double - yield float(arg), rest + def __init__(self, encoding: str = "utf-8", errorhandler: str = "surrogateescape"): + """ + Create a new parser, optionally specifying the encoding and errorhandler. + If `encoding` is None, bytes will be returned as-is. + The default settings are utf-8 encoding and surrogateescape errorhandler, + which can decode all possible byte sequences, + allowing decoded data to be re-encoded back to bytes. 
+ """ + self.encoding = encoding + self.errorhandler = errorhandler - elif code == b"+": # simple string - # we decode them automatically - yield arg.decode(errors="surrogateescape"), rest + def decode_bytes(self, data: bytes) -> str: + """ + decode the data as a string, + """ + return data.decode(self.encoding, errors=self.errorhandler) - elif code == b"$": # bulk string - count = int(arg) - expect = count + 2 # +2 for the trailing CRNL - while len(rest) < expect: - incoming = yield (None) - assert incoming is not None - rest += incoming - bulkstr = rest[:count] - # we decode them automatically. Can be encoded - # back to binary if necessary with "surrogatescape" - yield bulkstr.decode(errors="surrogateescape"), rest[expect:] - - elif code == b"=": # verbatim strings - count = int(arg) - expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL - while len(rest) < expect: - incoming = yield (None) - assert incoming is not None - rest += incoming - hint = rest[:3] - result = rest[4: (count + 4)] - yield VerbatimStr(result.decode(errors="surrogateescape"), - hint.decode()), rest[expect:] - - elif code in b"*>": # array or push data - count = int(arg) - result_array = [] - for _ in range(count): - # recursively parse the next array item - with closing(resp_parse(rest)) as parser: - parsed = parser.send(None) - while parsed is None: - incoming = yield None - parsed = parser.send(incoming) - value, rest = parsed - result_array.append(value) - if code == b">": - yield PushData(result_array), rest - else: - yield result_array, rest - - elif code == b"~": # set - count = int(arg) - result_set = set() - for _ in range(count): - # recursively parse the next set item - with closing(resp_parse(rest)) as parser: - parsed = parser.send(None) - while parsed is None: - incoming = yield None - parsed = parser.send(incoming) - value, rest = parsed - result_set.add(value) - yield result_set, rest - - elif code in b"%|": # map or attribute - count = int(arg) - result_map = 
{} - for _ in range(count): - # recursively parse the next key, and value - with closing(resp_parse(rest)) as parser: - parsed = parser.send(None) - while parsed is None: - incoming = yield None - parsed = parser.send(incoming) - key, rest = parsed - with closing(resp_parse(rest)) as parser: - parsed = parser.send(None) - while parsed is None: - incoming = yield None - parsed = parser.send(incoming) - value, rest = parsed - result_map[key] = value - if code == b"|": - yield Attribute(result_map), rest - yield result_map, rest - - elif code == b"-": # error - # we decode them automatically - decoded = arg.decode(errors="surrogateescape") - code, value = decoded.split(" ", 1) - yield ErrorStr(code, value), rest - - elif code == b"!": # resp3 error - count = int(arg) - expect = count + 2 # +2 for the trailing CRNL - while len(rest) < expect: - incoming = yield (None) + # a stateful RESP parser implemented via a generator + def parse( + self, + buffer: bytes, + ) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]: + """ + A stateful, generator based, RESP parser. + Returns a generator producing at most a single top-level primitive. + Yields tuple of (data_item, unparsed), or None if more data is needed. 
+ It is fed more data with generator.send() + """ + # Read the first line of resp or yield to get more data + while CRNL not in buffer: + incoming = yield None assert incoming is not None - rest += incoming - bulkstr = rest[:count] - decoded = bulkstr.decode(errors="surrogateescape") - code, value = decoded.split(" ", 1) - yield ErrorStr(code, value), rest[expect:] + buffer += incoming + cmd, rest = buffer.split(CRNL, 1) + + code, arg = cmd[:1], cmd[1:] + + if code == b":" or code == b"(": # integer, resp3 large int + yield int(arg), rest + + elif code == b"t": # resp3 true + yield True, rest + + elif code == b"f": # resp3 false + yield False, rest + + elif code == b"_": # resp3 null + yield None, rest + + elif code == b",": # resp3 double + yield float(arg), rest + + elif code == b"+": # simple string + # we decode them automatically + yield self.decode_bytes(arg), rest + + elif code == b"$": # bulk string + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + yield self.decode_bytes(bulkstr), rest[expect:] + + elif code == b"=": # verbatim strings + count = int(arg) + expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + string = self.decode_bytes(rest[: (count + 4)]) + if string[3] != ":": + raise ValueError(f"Expected colon after hint, got {bulkstr[3]}") + hint = string[:3] + string = string[4 : (count + 4)] + yield VerbatimStr(string, hint), rest[expect:] + + elif code in b"*>": # array or push data + count = int(arg) + result_array = [] + for _ in range(count): + # recursively parse the next array item + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_array.append(value) 
+ if code == b">": + yield PushData(result_array), rest + else: + yield result_array, rest + + elif code == b"~": # set + count = int(arg) + result_set = set() + for _ in range(count): + # recursively parse the next set item + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_set.add(value) + yield result_set, rest + + elif code in b"%|": # map or attribute + count = int(arg) + result_map = {} + for _ in range(count): + # recursively parse the next key, and value + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + key, rest = parsed + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_map[key] = value + if code == b"|": + yield Attribute(result_map), rest + yield result_map, rest + + elif code == b"-": # error + # we decode them automatically + decoded = self.decode_bytes(arg) + assert isinstance(decoded, str) + code, value = decoded.split(" ", 1) + yield ErrorStr(code, value), rest + + elif code == b"!": # resp3 error + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + decoded = self.decode_bytes(bulkstr) + assert isinstance(decoded, str) + code, value = decoded.split(" ", 1) + yield ErrorStr(code, value), rest[expect:] - else: - raise ValueError(f"Unknown opcode '{code.decode()}'") + else: + raise ValueError(f"Unknown opcode '{code.decode()}'") class NeedMoreData(RuntimeError): @@ -315,10 +341,12 @@ class NeedMoreData(RuntimeError): class RespParser: """ A class for simple RESP protocol decoding for unit tests + Uses a RespGeneratorParser to produce data. 
""" def __init__(self) -> None: - self.parser: Optional[ + self.parser = RespGeneratorParser() + self.generator: Optional[ Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None] ] = None # which has not resulted in a parsed value @@ -329,24 +357,24 @@ def parse(self, buffer: bytes) -> Optional[Any]: Parse a buffer of data, return a tuple of a single top-level primitive and the remaining buffer or raise NeedMoreData if more data is needed """ - if self.parser is None: + if self.generator is None: # create a new parser generator, initializing it with # any unparsed data from previous calls buffer = b"".join(self.consumed) + buffer del self.consumed[:] - self.parser = resp_parse(buffer) - parsed = self.parser.send(None) + self.generator = self.parser.parse(buffer) + parsed = self.generator.send(None) else: # sen more data to the parser - parsed = self.parser.send(buffer) + parsed = self.generator.send(buffer) if parsed is None: self.consumed.append(buffer) raise NeedMoreData() # got a value, close the parser, store the remaining buffer - self.parser.close() - self.parser = None + self.generator.close() + self.generator = None value, remaining = parsed self.consumed = [remaining] return value @@ -355,9 +383,9 @@ def get_unparsed(self) -> bytes: return b"".join(self.consumed) def close(self) -> None: - if self.parser is not None: - self.parser.close() - self.parser = None + if self.generator is not None: + self.generator.close() + self.generator = None del self.consumed[:] From 314f9cdf3e059325ea3c9514d55bacbb3e66898b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 9 Sep 2023 12:03:52 +0000 Subject: [PATCH 05/14] Use common parser and server for non-async test_connect.py --- tests/resp.py | 22 ++++++++++++++ tests/test_connect.py | 69 ++++++++++++++++--------------------------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 1aefbdddfb..3af6e52e5c 100644 --- 
a/tests/resp.py +++ b/tests/resp.py @@ -420,3 +420,25 @@ def parse_chunks(buffers: List[bytes]) -> Tuple[List[Any], bytes]: except NeedMoreData: break return result, parser.get_unparsed() + + +class RespServer: + """A simple, dummy, REDIS server for unit tests. + Accepts RESP commands and returns RESP responses. + """ + + _CLIENT_NAME = "test-suite-client" + _SUCCESS_RESP = b"+OK" + CRNL + _ERROR_RESP = b"-ERR" + CRNL + _SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + def command(self, cmd: Any) -> bytes: + """Process a single command and return the response""" + if not isinstance(cmd, list): + return f"-ERR unknown command {cmd!r}\r\n".encode() + + # currently supports only a single command + command = " ".join(cmd) + if command in self._SUPPORTED_CMDS: + return self._SUPPORTED_CMDS[command] + return self._ERROR_RESP diff --git a/tests/test_connect.py b/tests/test_connect.py index 696e69ceea..574693ee5f 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1,5 +1,4 @@ import logging -import re import socket import socketserver import ssl @@ -8,16 +7,13 @@ import pytest from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from . 
import resp from .ssl_utils import get_ssl_filename _logger = logging.getLogger(__name__) _CLIENT_NAME = "test-suite-client" -_CMD_SEP = b"\r\n" -_SUCCESS_RESP = b"+OK" + _CMD_SEP -_ERROR_RESP = b"-ERR" + _CMD_SEP -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @pytest.fixture @@ -148,44 +144,31 @@ def finish(self): _logger.info("%s disconnected", self.client_address) def handle(self): + parser = resp.RespParser() + server = resp.RespServer() buffer = b"" - command = None - command_ptr = None - fragment_length = None - while self.server.is_serving() or buffer: - try: - buffer += self.request.recv(1024) - except socket.timeout: - continue - if not buffer: - continue - parts = re.split(_CMD_SEP, buffer) - buffer = parts[-1] - for fragment in parts[:-1]: - fragment = fragment.decode() - _logger.info("Command fragment: %s", fragment) - - if fragment.startswith("*") and command is None: - command = [None for _ in range(int(fragment[1:]))] - command_ptr = 0 - fragment_length = None - continue - - if fragment.startswith("$") and command[command_ptr] is None: - fragment_length = int(fragment[1:]) - continue - - assert len(fragment) == fragment_length - command[command_ptr] = fragment - command_ptr += 1 - - if command_ptr < len(command): + try: + # if client performs pipelining, we may need + # to adjust this code to not block when sending + # responses. 
+ while self.server.is_serving(): + try: + command = parser.parse(buffer) + buffer = b"" + except resp.NeedMoreData: + try: + buffer = self.request.recv(1024) + except socket.timeout: + buffer = b"" + continue + if not buffer: + break # EOF continue - - command = " ".join(command) _logger.info("Command %s", command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) - _logger.info("Response %s", resp) - self.request.sendall(resp) - command = None - _logger.info("Exit handler") + response = server.command(command) + _logger.info("Response %s", response) + self.request.sendall(response) + except Exception: + _logger.exception("Exception in handler") + finally: + _logger.info("Exit handler") From f61ad28cf27ed715137525a0b9350669fce2f0ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 9 Sep 2023 12:19:09 +0000 Subject: [PATCH 06/14] use common parser and server for test_asyncio/test_connect.py --- tests/test_asyncio/test_connect.py | 70 ++++++++++++------------------ 1 file changed, 27 insertions(+), 43 deletions(-) diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 5e6b120fb3..a827186d24 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -1,6 +1,5 @@ import asyncio import logging -import re import socket import ssl @@ -11,16 +10,13 @@ UnixDomainSocketConnection, ) +from .. 
import resp from ..ssl_utils import get_ssl_filename _logger = logging.getLogger(__name__) _CLIENT_NAME = "test-suite-client" -_CMD_SEP = b"\r\n" -_SUCCESS_RESP = b"+OK" + _CMD_SEP -_ERROR_RESP = b"-ERR" + _CMD_SEP -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @pytest.fixture @@ -102,46 +98,34 @@ async def _handler(reader, writer): async def _redis_request_handler(reader, writer, stop_event): + parser = resp.RespParser() + server = resp.RespServer() buffer = b"" - command = None - command_ptr = None - fragment_length = None - while not stop_event.is_set() or buffer: - _logger.info(str(stop_event.is_set())) - try: - buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) - except TimeoutError: - continue - if not buffer: - continue - parts = re.split(_CMD_SEP, buffer) - buffer = parts[-1] - for fragment in parts[:-1]: - fragment = fragment.decode() - _logger.info("Command fragment: %s", fragment) - - if fragment.startswith("*") and command is None: - command = [None for _ in range(int(fragment[1:]))] - command_ptr = 0 - fragment_length = None - continue - - if fragment.startswith("$") and command[command_ptr] is None: - fragment_length = int(fragment[1:]) - continue - - assert len(fragment) == fragment_length - command[command_ptr] = fragment - command_ptr += 1 - - if command_ptr < len(command): + try: + # if client performs pipelining, we may need + # to adjust this code to not block when sending + # responses. 
+ while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + command = parser.parse(buffer) + buffer = b"" + except resp.NeedMoreData: + try: + buffer = await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + buffer = b"" + continue + if not buffer: + break # EOF continue - command = " ".join(command) _logger.info("Command %s", command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) - _logger.info("Response from %s", resp) - writer.write(resp) + response = server.command(command) + _logger.info("Response %s", response) + writer.write(response) await writer.drain() - command = None - _logger.info("Exit handler") + except Exception: + _logger.exception("Error in handler") + finally: + _logger.info("Exit handler") From da8c5f3974e010f43498fdcc7b228aa000a90ac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 10 Sep 2023 17:03:17 +0000 Subject: [PATCH 07/14] expand the RespServer --- tests/resp.py | 118 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 103 insertions(+), 15 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 3af6e52e5c..2f21ebf62f 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -192,7 +192,6 @@ def decode_bytes(self, data: bytes) -> str: """ return data.decode(self.encoding, errors=self.errorhandler) - # a stateful RESP parser implemented via a generator def parse( self, buffer: bytes, @@ -340,8 +339,9 @@ class NeedMoreData(RuntimeError): class RespParser: """ - A class for simple RESP protocol decoding for unit tests - Uses a RespGeneratorParser to produce data. + A class for simple RESP protocol decoding for unit tests. + Uses a RespGeneratorParser to produce data, and can + produce top-level objects for as long as there is data available. 
""" def __init__(self) -> None: @@ -355,7 +355,8 @@ def __init__(self) -> None: def parse(self, buffer: bytes) -> Optional[Any]: """ Parse a buffer of data, return a tuple of a single top-level primitive and the - remaining buffer or raise NeedMoreData if more data is needed + remaining buffer or raise NeedMoreData if more data is needed to + produce a value. """ if self.generator is None: # create a new parser generator, initializing it with @@ -372,7 +373,7 @@ def parse(self, buffer: bytes) -> Optional[Any]: self.consumed.append(buffer) raise NeedMoreData() - # got a value, close the parser, store the remaining buffer + # got a value, close the generator, store the remaining buffer self.generator.close() self.generator = None value, remaining = parsed @@ -427,18 +428,105 @@ class RespServer: Accepts RESP commands and returns RESP responses. """ - _CLIENT_NAME = "test-suite-client" - _SUCCESS_RESP = b"+OK" + CRNL - _ERROR_RESP = b"-ERR" + CRNL - _SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + handlers = {} + + def __init__(self): + self.protocol = 2 + self.server_ver = self.get_server_version() + self.auth = [] + self.client_name = "" + + # patchable methods for testing + + def get_server_version(self): + return 6 + + def on_auth(self, auth): + pass + + def on_setname(self, name): + pass + + def on_protocol(self, proto): + pass def command(self, cmd: Any) -> bytes: """Process a single command and return the response""" + result = self._command(cmd) + return RespEncoder(self.protocol).encode(result) + + def _command(self, cmd: Any) -> Any: if not isinstance(cmd, list): - return f"-ERR unknown command {cmd!r}\r\n".encode() + return ErrorStr("ERR", "unknown command {cmd!r}") + + # handle registered commands + command = cmd[0].upper() + args = cmd[1:] + if command in self.handlers: + return self.handlers[command](self, args) + + return ErrorStr("ERR", "unknown command {cmd!r}") + + def handle_auth(self, args): + self.auth = args[:] + 
self.on_auth(self.auth) + expect = 2 if self.server_ver >= 6 else 1 + if len(args) != expect: + return ErrorStr("ERR", "wrong number of arguments" " for 'AUTH' command") + return "OK" + + handlers["AUTH"] = handle_auth + + def handle_client(self, args): + if args[0] == "SETNAME": + return self.handle_setname(args[1:]) + return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") + + handlers["CLIENT"] = handle_client + + def handle_setname(self, args): + if len(args) != 1: + return ErrorStr("ERR", "wrong number of arguments") + self.client_name = args[0] + self.on_setname(self.client_name) + return "OK" + + def handle_hello(self, args): + if self.server_ver < 6: + return ErrorStr("ERR", "unknown command 'HELLO'") + proto = self.protocol + if args: + proto = args.pop(0) + if str(proto) not in ["2", "3"]: + return ErrorStr( + "NOPROTO", "sorry this protocol version is not supported" + ) - # currently supports only a single command - command = " ".join(cmd) - if command in self._SUPPORTED_CMDS: - return self._SUPPORTED_CMDS[command] - return self._ERROR_RESP + while args: + cmd = args.pop(0).upper() + if cmd == "AUTH": + auth_args = args[:2] + args = args[2:] + res = self.handle_auth(auth_args) + if res != "OK": + return res + continue + if cmd == "SETNAME": + setname_args = args[:1] + args = args[1:] + res = self.handle_setname(setname_args) + if res != "OK": + return res + continue + return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") + + self.protocol = int(proto) + self.on_protocol(self.protocol) + result = { + "server": "redistester", + "version": "0.0.1", + "proto": self.protocol, + } + return result + + handlers["HELLO"] = handle_hello From c4e656ecac36d725408d43117e296c1a616b06e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 10 Sep 2023 17:52:42 +0000 Subject: [PATCH 08/14] Add tests for various connection handshakes, async --- tests/test_asyncio/test_connect.py | 86
++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index a827186d24..b59310578a 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -2,10 +2,12 @@ import logging import socket import ssl +from unittest.mock import patch import pytest from redis.asyncio.connection import ( Connection, + ResponseError, SSLConnection, UnixDomainSocketConnection, ) @@ -61,6 +63,90 @@ async def test_tcp_ssl_connect(tcp_address): await conn.disconnect() +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +async def test_tcp_auth( + tcp_address, use_protocol, use_auth, use_server_ver, use_client_name +): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if 
use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + await _assert_connect(conn, tcp_address) + return + + await _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == _CLIENT_NAME + else: + assert got_name is None + finally: + await conn.disconnect() + + async def _assert_connect(conn, server_address, certfile=None, keyfile=None): stop_event = asyncio.Event() finished = asyncio.Event() From 6528dff94d3174ae32f1d353e02722f6208c69a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 10 Sep 2023 17:59:40 +0000 Subject: [PATCH 09/14] add connection handshake tests for non-async version --- tests/test_connect.py | 90 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 1 deletion(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index 574693ee5f..49c3abe506 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -3,9 +3,15 @@ import socketserver import ssl import threading +from unittest.mock import patch import pytest -from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection +from redis.connection import ( + Connection, + ResponseError, + SSLConnection, + UnixDomainSocketConnection, +) from . 
import resp from .ssl_utils import get_ssl_filename @@ -55,6 +61,88 @@ def test_tcp_ssl_connect(tcp_address): _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + _assert_connect(conn, tcp_address) + return + + _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == 
_CLIENT_NAME + else: + assert got_name is None + finally: + conn.disconnect() + + def _assert_connect(conn, server_address, certfile=None, keyfile=None): if isinstance(server_address, str): if not _RedisUDSServer: From 7cafea7ff5a21564163eb17cef02b2337103eac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 17 Sep 2023 09:53:35 +0000 Subject: [PATCH 10/14] Don't use NoneType, for 3.7 compatibility --- tests/resp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 2f21ebf62f..89e8661b1d 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -1,6 +1,5 @@ import itertools from contextlib import closing -from types import NoneType from typing import Any, Generator, List, Optional, Tuple, Union CRNL = b"\r\n" @@ -145,7 +144,7 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes: return f",{data}\r\n".encode() # resp3 double return f"+{data}\r\n".encode() # simple string - elif isinstance(data, NoneType): + elif data is None: if self.protocol > 2: return b"_\r\n" # resp3 null return b"$-1\r\n" # Null bulk string From 2972989db9ca16b3a8019812dea94015ffbff15e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 17 Sep 2023 09:55:51 +0000 Subject: [PATCH 11/14] Fix typing for mypy --- tests/resp.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index 89e8661b1d..abf4b69b05 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -1,6 +1,6 @@ import itertools from contextlib import closing -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union CRNL = b"\r\n" @@ -34,11 +34,11 @@ def __init__(self, code: str, value: str) -> None: def __repr__(self) -> str: return f"ErrorString({self.code!r}, {super().__repr__()})" - def __str__(self): + def __str__(self) ->
str: return f"{self.code} {super().__str__()}" -class PushData(list): +class PushData(List[Any]): """ A special type of list indicating data from a push response """ @@ -47,7 +47,7 @@ def __repr__(self) -> str: return f"PushData({super().__repr__()})" -class Attribute(dict): +class Attribute(Dict[Any, Any]): """ A special type of map indicating data from a attribute response """ @@ -62,7 +62,7 @@ class RespEncoder: """ def __init__( - self, protocol: int = 2, encoding: str = "utf-8", errorhander="strict" + self, protocol: int = 2, encoding: str = "utf-8", errorhander: str = "strict" ) -> None: self.protocol = protocol self.encoding = encoding @@ -248,7 +248,7 @@ def parse( rest += incoming string = self.decode_bytes(rest[: (count + 4)]) if string[3] != ":": - raise ValueError(f"Expected colon after hint, got {bulkstr[3]}") + raise ValueError(f"Expected colon after hint, got {string[3]}") hint = string[:3] string = string[4 : (count + 4)] yield VerbatimStr(string, hint), rest[expect:] @@ -310,8 +310,8 @@ def parse( # we decode them automatically decoded = self.decode_bytes(arg) assert isinstance(decoded, str) - code, value = decoded.split(" ", 1) - yield ErrorStr(code, value), rest + err, value = decoded.split(" ", 1) + yield ErrorStr(err, value), rest elif code == b"!": # resp3 error count = int(arg) @@ -323,8 +323,8 @@ def parse( bulkstr = rest[:count] decoded = self.decode_bytes(bulkstr) assert isinstance(decoded, str) - code, value = decoded.split(" ", 1) - yield ErrorStr(code, value), rest[expect:] + err, value = decoded.split(" ", 1) + yield ErrorStr(err, value), rest[expect:] else: raise ValueError(f"Unknown opcode '{code.decode()}'") @@ -427,26 +427,26 @@ class RespServer: Accepts RESP commands and returns RESP responses. 
""" - handlers = {} + handlers: Dict[str, Callable[..., Any]] = {} - def __init__(self): + def __init__(self) -> None: self.protocol = 2 self.server_ver = self.get_server_version() - self.auth = [] + self.auth: List[Any] = [] self.client_name = "" # patchable methods for testing - def get_server_version(self): + def get_server_version(self) -> int: return 6 - def on_auth(self, auth): + def on_auth(self, auth: List[Any]) -> None: pass - def on_setname(self, name): + def on_setname(self, name: str) -> None: pass - def on_protocol(self, proto): + def on_protocol(self, proto: int) -> None: pass def command(self, cmd: Any) -> bytes: @@ -466,7 +466,7 @@ def _command(self, cmd: Any) -> Any: return ErrorStr("ERR", "unknown command {cmd!r}") - def handle_auth(self, args): + def handle_auth(self, args: List[Any]) -> Union[str, ErrorStr]: self.auth = args[:] self.on_auth(self.auth) expect = 2 if self.server_ver >= 6 else 1 @@ -476,21 +476,21 @@ def handle_auth(self, args): handlers["AUTH"] = handle_auth - def handle_client(self, args): + def handle_client(self, args: List[Any]) -> Union[str, ErrorStr]: if args[0] == "SETNAME": return self.handle_setname(args[1:]) return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") handlers["CLIENT"] = handle_client - def handle_setname(self, args): + def handle_setname(self, args: List[Any]) -> Union[str, ErrorStr]: if len(args) != 1: return ErrorStr("ERR", "wrong number of arguments") self.client_name = args[0] self.on_setname(self.client_name) return "OK" - def handle_hello(self, args): + def handle_hello(self, args: List[Any]) -> Union[ErrorStr, Dict[str, Any]]: if self.server_ver < 6: return ErrorStr("ERR", "unknown command 'HELLO'") proto = self.protocol @@ -507,14 +507,14 @@ def handle_hello(self, args): auth_args = args[:2] args = args[2:] res = self.handle_auth(auth_args) - if res != "OK": + if isinstance(res, ErrorStr): return res continue if cmd == "SETNAME": setname_args = args[:1] args = args[1:] res = 
self.handle_setname(setname_args) - if res != "OK": + if isinstance(res, ErrorStr): return res continue return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") From 57cd79e225b32c900c8a0cb075097ec384bd6ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 17 Sep 2023 10:18:26 +0000 Subject: [PATCH 12/14] use list.clear() --- tests/resp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/resp.py b/tests/resp.py index abf4b69b05..c0b8895527 100644 --- a/tests/resp.py +++ b/tests/resp.py @@ -361,7 +361,7 @@ def parse(self, buffer: bytes) -> Optional[Any]: # create a new parser generator, initializing it with # any unparsed data from previous calls buffer = b"".join(self.consumed) + buffer - del self.consumed[:] + self.consumed.clear() self.generator = self.parser.parse(buffer) parsed = self.generator.send(None) else: @@ -386,7 +386,7 @@ def close(self) -> None: if self.generator is not None: self.generator.close() self.generator = None - del self.consumed[:] + self.consumed.clear() def parse_all(buffer: bytes) -> Tuple[List[Any], bytes]: From 09b51d47349c77854e1d3dae8f03c3f481ae261a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Tue, 17 Oct 2023 11:24:56 +0000 Subject: [PATCH 13/14] Skip testing resp3 parsing for async/py3.7 --- tests/test_asyncio/test_connect.py | 7 +++++-- tests/test_connect.py | 2 -- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index b59310578a..7c402bbdc1 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -2,6 +2,7 @@ import logging import socket import ssl +import sys from unittest.mock import patch import pytest @@ -19,6 +20,7 @@ _CLIENT_NAME = "test-suite-client" +PY37 = sys.version_info[:2] == (3, 7) @pytest.fixture @@ -77,14 +79,15 @@ async def test_tcp_ssl_connect(tcp_address): (6, 3, True, 
True), ], ) -# @pytest.mark.parametrize("use_protocol", [2, 3]) -# @pytest.mark.parametrize("use_auth", [False, True]) async def test_tcp_auth( tcp_address, use_protocol, use_auth, use_server_ver, use_client_name ): """ Test that various initial handshake cases are handled correctly by the client """ + if use_protocol == 3 and PY37: + pytest.skip("Python 3.7 does not support protocol 3 for asyncio") + got_auth = [] got_protocol = None got_name = None diff --git a/tests/test_connect.py b/tests/test_connect.py index 49c3abe506..090b5af953 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -75,8 +75,6 @@ def test_tcp_ssl_connect(tcp_address): (6, 3, True, True), ], ) -# @pytest.mark.parametrize("use_protocol", [2, 3]) -# @pytest.mark.parametrize("use_auth", [False, True]) def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name): """ Test that various initial handshake cases are handled correctly by the client From 5a79252766fc174b19f9e93d8cc75ba2ce22edcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 13 Nov 2023 11:49:39 +0000 Subject: [PATCH 14/14] Revert "Skip testing resp3 parsing for async/py3.7", now that parsing is fixed. This reverts commit 09b51d47349c77854e1d3dae8f03c3f481ae261a. 
--- tests/test_asyncio/test_connect.py | 7 ++----- tests/test_connect.py | 2 ++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 7c402bbdc1..b59310578a 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -2,7 +2,6 @@ import logging import socket import ssl -import sys from unittest.mock import patch import pytest @@ -20,7 +19,6 @@ _CLIENT_NAME = "test-suite-client" -PY37 = sys.version_info[:2] == (3, 7) @pytest.fixture @@ -79,15 +77,14 @@ async def test_tcp_ssl_connect(tcp_address): (6, 3, True, True), ], ) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) async def test_tcp_auth( tcp_address, use_protocol, use_auth, use_server_ver, use_client_name ): """ Test that various initial handshake cases are handled correctly by the client """ - if use_protocol == 3 and PY37: - pytest.skip("Python 3.7 does not support protocol 3 for asyncio") - got_auth = [] got_protocol = None got_name = None diff --git a/tests/test_connect.py b/tests/test_connect.py index 090b5af953..49c3abe506 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -75,6 +75,8 @@ def test_tcp_ssl_connect(tcp_address): (6, 3, True, True), ], ) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name): """ Test that various initial handshake cases are handled correctly by the client