diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3f35fdd59e..ce4d69f8c1 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -224,6 +224,7 @@ def __init__( encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, + check_ready: bool = False, retry_on_timeout: bool = False, retry_on_error: Optional[list] = None, ssl: bool = False, @@ -291,6 +292,7 @@ def __init__( "encoding": encoding, "encoding_errors": encoding_errors, "decode_responses": decode_responses, + "check_ready": check_ready, "retry_on_timeout": retry_on_timeout, "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 28fcd3aa23..08c841b505 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -258,6 +258,7 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, # Connection related kwargs + check_ready: bool = False, health_check_interval: float = 0, socket_connect_timeout: Optional[float] = None, socket_keepalive: bool = False, @@ -313,6 +314,7 @@ def __init__( "encoding_errors": encoding_errors, "decode_responses": decode_responses, # Connection related kwargs + "check_ready": check_ready, "health_check_interval": health_check_interval, "socket_connect_timeout": socket_connect_timeout, "socket_keepalive": socket_keepalive, diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 7404f3d6f8..cb816ed7ff 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -148,6 +148,7 @@ def __init__( encoding_errors: str = "strict", decode_responses: bool = False, parser_class: Type[BaseParser] = DefaultParser, + check_ready: bool = False, socket_read_size: int = 65536, health_check_interval: float = 0, client_name: Optional[str] = None, @@ -204,6 +205,7 @@ def __init__( self.health_check_interval = health_check_interval self.next_health_check: float = -1 self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self.check_ready = check_ready self.redis_connect_func = redis_connect_func self._reader: Optional[asyncio.StreamReader] = None self._writer: Optional[asyncio.StreamWriter] = None @@ -295,14 +297,48 @@ async def connect(self): """Connects to the Redis server if not already connected""" await self.connect_check_health(check_health=True) + async def _connect_check_ready(self): + await self._connect() + + # Doing handshake since connect and send operations work even when Redis is not ready + if self.check_ready: + try: + ping_cmd = self.pack_command("PING") + if self.socket_timeout: + await asyncio.wait_for( + self._send_packed_command(ping_cmd), self.socket_timeout + ) + else: + await self._send_packed_command(ping_cmd) + + if self.socket_timeout is not None: + async with async_timeout(self.socket_timeout): + response = str_if_bytes(await self._reader.read(1024)) + else: + response = str_if_bytes(await self._reader.read(1024)) + + if not (response.startswith("+PONG") or response.startswith("-NOAUTH")): + raise ResponseError(f"Invalid PING response: {response}") + except ( + socket.timeout, + asyncio.TimeoutError, + ResponseError, + ConnectionResetError, + ) as e: + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + self._close() + raise ConnectionError(self._error_message(e)) + async def connect_check_health(self, check_health: bool = True): if self.is_connected: return try: await self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect() + lambda: self._connect_check_ready(), lambda error: self.disconnect() ) except asyncio.CancelledError: + self._close() raise # in 3.7 and earlier, this is an Exception, not BaseException except (socket.timeout, asyncio.TimeoutError): raise TimeoutError("Timeout connecting to server") @@ -526,8 +562,7 @@ async def send_packed_command( self._send_packed_command(command), self.socket_timeout ) else: - self._writer.writelines(command) - await self._writer.drain() + await self._send_packed_command(command) except asyncio.TimeoutError: await self.disconnect(nowait=True) raise TimeoutError("Timeout writing to socket") from None @@ -774,7 +809,7 @@ async def _connect(self): except (OSError, TypeError): # `socket_keepalive_options` might contain invalid options # causing an error. Do not leave the connection open. - writer.close() + self._close() raise def _host_error(self) -> str: @@ -933,7 +968,6 @@ async def _connect(self): reader, writer = await asyncio.open_unix_connection(path=self.path) self._reader = reader self._writer = writer - await self.on_connect() def _host_error(self) -> str: return self.path diff --git a/redis/client.py b/redis/client.py index fda927507a..52edbb634b 100755 --- a/redis/client.py +++ b/redis/client.py @@ -203,6 +203,7 @@ def __init__( encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, + check_ready: bool = False, retry_on_timeout: bool = False, retry_on_error: Optional[List[Type[Exception]]] = None, ssl: bool = False, @@ -265,6 +266,7 @@ def __init__( "encoding": encoding, "encoding_errors": encoding_errors, "decode_responses": decode_responses, + "check_ready": check_ready, "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), "max_connections": max_connections, diff --git a/redis/connection.py b/redis/connection.py index 08e980e866..7ef7124097 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -236,6 +236,7 @@ def __init__( encoding: str = "utf-8", encoding_errors: str = "strict", decode_responses: bool = False, + check_ready: bool = False, parser_class=DefaultParser, socket_read_size: int = 65536, health_check_interval: int = 0, @@ -302,6 +303,7 @@ def __init__( self.redis_connect_func = redis_connect_func self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.handshake_metadata = None + self.check_ready = check_ready self._sock = None self._socket_read_size = socket_read_size self.set_parser(parser_class) @@ -378,12 +380,35 @@ def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) + def _connect_check_ready(self): + sock = self._connect() + + # Doing handshake since connect and send operations work even when Redis is not ready + if self.check_ready: + try: + ping_parts = self._command_packer.pack("PING") + for part in ping_parts: + sock.sendall(part) + + response = str_if_bytes(sock.recv(1024)) + if not (response.startswith("+PONG") or response.startswith("-NOAUTH")): + raise ResponseError(f"Invalid PING response: {response}") + except (ConnectionResetError, ResponseError) as err: + try: + sock.shutdown(socket.SHUT_RDWR) # ensure a clean close + except OSError: + pass + sock.close() + raise ConnectionError(self._error_message(err)) + return sock + def connect_check_health(self, check_health: bool = True): if self._sock: return try: sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) + lambda: self._connect_check_ready(), + lambda error: self.disconnect(error), ) except socket.timeout: raise TimeoutError("Timeout connecting to server") diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index a0429152ec..916a24e541 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -716,7 +716,7 @@ async def test_reading_with_load_balancing_strategies( Connection, send_command=mock.DEFAULT, read_response=mock.DEFAULT, - _connect=mock.DEFAULT, + _connect_check_ready=mock.DEFAULT, can_read_destructive=mock.DEFAULT, on_connect=mock.DEFAULT, ) as mocks: @@ -748,7 +748,7 @@ def execute_command_mock_third(self, *args, **options): execute_command.side_effect = execute_command_mock_first mocks["send_command"].return_value = True mocks["read_response"].return_value = "OK" - mocks["_connect"].return_value = True + mocks["_connect_check_ready"].return_value = True mocks["can_read_destructive"].return_value = False mocks["on_connect"].return_value = True @@ -3090,13 +3090,17 @@ async def execute_command(self, *args, **kwargs): return _create_client + @pytest.mark.parametrize("check_ready", [True, False]) async def test_ssl_connection_without_ssl( - self, create_client: Callable[..., Awaitable[RedisCluster]] + self, create_client: Callable[..., Awaitable[RedisCluster]], check_ready ) -> None: with pytest.raises(RedisClusterException) as e: - await create_client(mocked=False, ssl=False) + await create_client(mocked=False, ssl=False, check_ready=check_ready) e = e.value.__cause__ - assert "Connection closed by server" in str(e) + if check_ready: + assert "Invalid PING response" in str(e) + else: + assert "Connection closed by server" in str(e) async def test_ssl_with_invalid_cert( self, create_client: Callable[..., Awaitable[RedisCluster]] diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 6c4b3c33d7..6705c8cb3c 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -2,6 +2,8 @@ import re import socket import ssl +import struct +import sys import pytest from redis.asyncio.connection import ( @@ -16,6 +18,7 @@ _CLIENT_NAME = "test-suite-client" _CMD_SEP = b"\r\n" _SUCCESS_RESP = b"+OK" + _CMD_SEP +_PONG_RESP = b"+PONG" + _CMD_SEP _ERROR_RESP = b"-ERR" + _CMD_SEP _SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @@ -32,18 +35,30 @@ def uds_address(tmpdir): return tmpdir / "uds.sock" -async def test_tcp_connect(tcp_address): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +async def test_tcp_connect(tcp_address, check_ready): host, port = tcp_address - conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) - await _assert_connect(conn, tcp_address) + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME, + socket_timeout=10, + check_ready=check_ready, + ) + await _assert_connect(conn, tcp_address, check_ready) -async def test_uds_connect(uds_address): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +async def test_uds_connect(uds_address, check_ready): path = str(uds_address) conn = UnixDomainSocketConnection( - path=path, client_name=_CLIENT_NAME, socket_timeout=10 + path=path, client_name=_CLIENT_NAME, socket_timeout=10, check_ready=check_ready ) - await _assert_connect(conn, path) + await _assert_connect(conn, path, check_ready) @pytest.mark.ssl @@ -86,7 +101,10 @@ async def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers): ), ], ) -async def test_tcp_ssl_connect(tcp_address, ssl_min_version): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +async def test_tcp_ssl_connect(tcp_address, ssl_min_version, check_ready): host, port = tcp_address server_certs = get_tls_certificates(cert_type=CertificateType.server) @@ -98,13 +116,140 @@ async def test_tcp_ssl_connect(tcp_address, ssl_min_version): ssl_ca_certs=server_certs.ca_certfile, socket_timeout=10, ssl_min_version=ssl_min_version, + check_ready=check_ready, ) await _assert_connect( - conn, tcp_address, certfile=server_certs.certfile, keyfile=server_certs.keyfile + conn, + tcp_address, + check_ready, + certfile=server_certs.certfile, + keyfile=server_certs.keyfile, ) await conn.disconnect() +@pytest.mark.asyncio +async def test_connect_check_ready_asyncio_timeout_error(tcp_address): + """ + Demonstrates a scenario where redis-py hits an `asyncio.TimeoutError` internally + (via `asyncio.wait_for(...)` or `async_timeout(...)`). Redis-py catches that + and re-raises `ConnectionError`. + """ + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name="test-suite-client", + socket_timeout=0.5, + check_ready=True, + ) + + async def _no_response_handler(reader, writer): + # Accept the connection + buffer = await reader.read(1024) + assert "PING" in buffer.decode() + # do nothing (no response to PING). + # The client code will eventually hit asyncio.TimeoutError on read. + await asyncio.sleep(1) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(_no_response_handler, host=host, port=port) + async with server: + await server.start_serving() + # We expect ConnectionError due to the underlying asyncio.TimeoutError + # from lack of a timely PONG. + with pytest.raises(ConnectionError): + await conn.connect() + + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_connect_check_ready_invalid_ping(tcp_address): + """ + Demonstrates a scenario where redis-py hits an `ResponseError` internally + due to an invalid response to the PING command. + Redis-py catches that and re-raises `ConnectionError`. + """ + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name="test-suite-client", + socket_timeout=5, + check_ready=True, + ) + + async def _no_response_handler(reader, writer): + # Accept the connection + buffer = await reader.read(1024) + assert "PING" in buffer.decode() + # Send wrong answer back + writer.write(_ERROR_RESP) + await writer.drain() + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(_no_response_handler, host=host, port=port) + async with server: + await server.start_serving() + # We expect ConnectionError due to a wrong response to PING. + with pytest.raises(ConnectionError, match="Invalid PING response"): + await conn.connect() + + server.close() + await server.wait_closed() + + +@pytest.mark.asyncio +async def test_connect_check_ready_connection_reset(tcp_address): + """ + Demonstrates a scenario where the server accepts the connection and receives the PING command, + but abruptly resets the connection (sends a TCP RST). This causes the client to raise a + ConnectionResetError internally, which redis-py catches and re-raises as ConnectionError. + """ + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name="test-suite-client", + socket_timeout=5, + check_ready=True, + ) + + async def _no_response_handler(reader, writer): + # Accept the connection + buffer = await reader.read(1024) + assert "PING" in buffer.decode() + + sock = writer.transport.get_extra_info("socket") + + # Configure socket for abrupt close (RST packet) + linger_struct = struct.pack("ii", 1, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_struct) + + # Close immediately to trigger ConnectionResetError on client side + sock.close() + + server = await asyncio.start_server(_no_response_handler, host=host, port=port) + async with server: + await server.start_serving() + + if sys.version_info < (3, 11, 3): + # Python 3.11.3+ handles ConnectionResetError differently + # and overloads with TimeoutError. + with pytest.raises(ConnectionError): + await conn.connect() + else: + with pytest.raises(ConnectionError, match="Connection reset by peer"): + await conn.connect() + + server.close() + await server.wait_closed() + + @pytest.mark.ssl @pytest.mark.skipif(not ssl.HAS_TLSv1_3, reason="requires TLSv1.3") async def test_tcp_ssl_version_mismatch(tcp_address): @@ -132,6 +277,7 @@ async def test_tcp_ssl_version_mismatch(tcp_address): async def _assert_connect( conn, server_address, + check_ready=False, certfile=None, keyfile=None, minimum_ssl_version=ssl.TLSVersion.TLSv1_2, @@ -142,7 +288,7 @@ async def _assert_connect( async def _handler(reader, writer): try: - return await _redis_request_handler(reader, writer, stop_event) + return await _redis_request_handler(reader, writer, stop_event, check_ready) finally: writer.close() await writer.wait_closed() @@ -176,7 +322,7 @@ async def _handler(reader, writer): await finished.wait() -async def _redis_request_handler(reader, writer, stop_event): +async def _redis_request_handler(reader, writer, stop_event, check_ready): command = None command_ptr = None fragment_length = None @@ -208,7 +354,10 @@ async def _redis_request_handler(reader, writer, stop_event): continue command = " ".join(command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + if check_ready and command == "PING": + resp = _PONG_RESP + else: + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) writer.write(resp) await writer.drain() command = None diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py index 677e165fc6..e1cfea2cde 100644 --- a/tests/test_asyncio/test_cwe_404.py +++ b/tests/test_asyncio/test_cwe_404.py @@ -173,7 +173,7 @@ async def op(pipe): with pytest.raises(asyncio.CancelledError): await t - # we have now cancelled the pieline in the middle of a request, + # we have now cancelled the pipeline in the middle of a request, # make sure that the connection is still usable pipe.get("bar") pipe.ping() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index d96342f87a..07e30e3a28 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -625,7 +625,7 @@ def test_reading_with_load_balancing_strategies( Connection, send_command=DEFAULT, read_response=DEFAULT, - _connect=DEFAULT, + _connect_check_ready=DEFAULT, can_read=DEFAULT, on_connect=DEFAULT, ) as mocks: @@ -654,7 +654,7 @@ def parse_response_mock_third(connection, *args, **options): parse_response.side_effect = parse_response_mock_first mocks["send_command"].return_value = True mocks["read_response"].return_value = "OK" - mocks["_connect"].return_value = True + mocks["_connect_check_ready"].return_value = True mocks["can_read"].return_value = False mocks["on_connect"].return_value = True diff --git a/tests/test_connect.py b/tests/test_connect.py index f3c02b330f..cae1446d2b 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -2,17 +2,19 @@ import socket import socketserver import ssl +import struct import threading import pytest from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection -from redis.exceptions import RedisError +from redis.exceptions import RedisError, ConnectionError from .ssl_utils import CertificateType, get_tls_certificates _CLIENT_NAME = "test-suite-client" _CMD_SEP = b"\r\n" _SUCCESS_RESP = b"+OK" + _CMD_SEP +_PONG_RESP = b"+PONG" + _CMD_SEP _ERROR_RESP = b"-ERR" + _CMD_SEP _SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @@ -29,16 +31,30 @@ def uds_address(tmpdir): return tmpdir / "uds.sock" -def test_tcp_connect(tcp_address): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +def test_tcp_connect(tcp_address, check_ready): host, port = tcp_address - conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) - _assert_connect(conn, tcp_address) + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME, + socket_timeout=10, + check_ready=check_ready, + ) + _assert_connect(conn, tcp_address, check_ready) -def test_uds_connect(uds_address): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +def test_uds_connect(uds_address, check_ready): path = str(uds_address) - conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) - _assert_connect(conn, path) + conn = UnixDomainSocketConnection( + path, client_name=_CLIENT_NAME, socket_timeout=10, check_ready=check_ready + ) + _assert_connect(conn, path, check_ready) @pytest.mark.ssl @@ -52,7 +68,10 @@ def test_uds_connect(uds_address): ), ], ) -def test_tcp_ssl_connect(tcp_address, ssl_min_version): +@pytest.mark.parametrize( + "check_ready", [True, False], ids=["check_ready", "no_check_ready"] +) +def test_tcp_ssl_connect(tcp_address, ssl_min_version, check_ready): host, port = tcp_address server_certs = get_tls_certificates(cert_type=CertificateType.server) conn = SSLConnection( @@ -62,9 +81,14 @@ def test_tcp_ssl_connect(tcp_address, ssl_min_version): ssl_ca_certs=server_certs.ca_certfile, socket_timeout=10, ssl_min_version=ssl_min_version, + check_ready=check_ready, ) _assert_connect( - conn, tcp_address, certfile=server_certs.certfile, keyfile=server_certs.keyfile + conn, + tcp_address, + check_ready, + certfile=server_certs.certfile, + keyfile=server_certs.keyfile, ) @@ -136,13 +160,83 @@ def test_tcp_ssl_version_mismatch(tcp_address): ) -def _assert_connect(conn, server_address, **tcp_kw): +def test_connect_check_ready_connection_reset(tcp_address): + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME, + socket_timeout=10, + check_ready=True, + ) + + class _CloseConnectionRequestHandler(socketserver.BaseRequestHandler): + def handle(self): + data = self.request.recv(1024) + assert "PING" in data.decode() + + # Configure socket for abrupt close (RST packet) + linger_struct = struct.pack("ii", 1, 0) + self.request.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_struct) + + # Close immediately to trigger ConnectionResetError on client side + self.request.close() + + server = _RedisTCPServer( + (host, port), _CloseConnectionRequestHandler, check_ready=True + ) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + with pytest.raises(ConnectionError, match="Connection reset by peer"): + conn.connect() + finally: + aserver.stop() + t.join(timeout=5) + + +def test_connect_check_ready_invalid_ping(tcp_address): + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME, + socket_timeout=10, + check_ready=True, + ) + + class _InvalidPingRequestHandler(socketserver.BaseRequestHandler): + def handle(self): + data = self.request.recv(1024) + assert "PING" in data.decode() + self.request.sendall(_ERROR_RESP) + + server = _RedisTCPServer((host, port), _InvalidPingRequestHandler, check_ready=True) + with server as aserver: + t = threading.Thread(target=aserver.serve_forever) + t.start() + try: + aserver.wait_online() + with pytest.raises(ConnectionError, match="Invalid PING response"): + conn.connect() + finally: + aserver.stop() + t.join(timeout=5) + + +def _assert_connect(conn, server_address, check_ready=False, **tcp_kw): if isinstance(server_address, str): if not _RedisUDSServer: pytest.skip("Unix domain sockets are not supported on this platform") - server = _RedisUDSServer(server_address, _RedisRequestHandler) + server = _RedisUDSServer( + server_address, _RedisRequestHandler, check_ready=check_ready + ) else: - server = _RedisTCPServer(server_address, _RedisRequestHandler, **tcp_kw) + server = _RedisTCPServer( + server_address, _RedisRequestHandler, check_ready=check_ready, **tcp_kw + ) with server as aserver: t = threading.Thread(target=aserver.serve_forever) t.start() @@ -163,6 +257,7 @@ def __init__( keyfile=None, minimum_ssl_version=ssl.TLSVersion.TLSv1_2, maximum_ssl_version=ssl.TLSVersion.TLSv1_3, + check_ready=False, **kw, ) -> None: self._ready_event = threading.Event() @@ -171,6 +266,7 @@ def __init__( self._keyfile = keyfile self._minimum_ssl_version = minimum_ssl_version self._maximum_ssl_version = maximum_ssl_version + self.check_ready = check_ready super().__init__(*args, **kw) def service_actions(self): @@ -201,9 +297,10 @@ def get_request(self): if hasattr(socketserver, "UnixStreamServer"): class _RedisUDSServer(socketserver.UnixStreamServer): - def __init__(self, *args, **kw) -> None: + def __init__(self, *args, check_ready=False, **kw) -> None: self._ready_event = threading.Event() self._stop_requested = False + self.check_ready = check_ready super().__init__(*args, **kw) def service_actions(self): @@ -223,13 +320,7 @@ def is_serving(self): _RedisUDSServer = None -class _RedisRequestHandler(socketserver.StreamRequestHandler): - def setup(self): - pass - - def finish(self): - pass - +class _RedisRequestHandler(socketserver.BaseRequestHandler): def handle(self): buffer = b"" command = None @@ -265,6 +356,9 @@ def handle(self): continue command = " ".join(command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + if self.server.check_ready and command == "PING": + resp = _PONG_RESP + else: + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) self.request.sendall(resp) command = None diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 2a945ac287..8befd7a3d6 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -41,14 +41,19 @@ def test_ssl_connection(self, request): assert r.ping() r.close() - def test_ssl_connection_without_ssl(self, request): + @pytest.mark.parametrize("check_ready", [True, False]) + def test_ssl_connection_without_ssl(self, request, check_ready): ssl_url = request.config.option.redis_ssl_url p = urlparse(ssl_url)[1].split(":") - r = redis.Redis(host=p[0], port=p[1], ssl=False) + r = redis.Redis(host=p[0], port=p[1], ssl=False, check_ready=check_ready) with pytest.raises(ConnectionError) as e: r.ping() - assert "Connection closed by server" in str(e) + + if check_ready: + assert "Invalid PING response" in str(e) + else: + assert "Connection closed by server" in str(e) r.close() def test_validating_self_signed_certificate(self, request):