Skip to content

Commit a5837b8

Browse files
Add Redis readiness verification (redis#3555)
1 parent 04589d4 commit a5837b8

File tree

11 files changed

+407
-64
lines changed

11 files changed

+407
-64
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def __init__(
224224
encoding: str = "utf-8",
225225
encoding_errors: str = "strict",
226226
decode_responses: bool = False,
227+
check_server_ready: bool = False,
227228
retry_on_timeout: bool = False,
228229
retry_on_error: Optional[list] = None,
229230
ssl: bool = False,
@@ -255,6 +256,10 @@ def __init__(
255256
`retry_on_error` to a list of the error/s to retry on, then set
256257
`retry` to a valid `Retry` object.
257258
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
259+
260+
Args:
261+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
262+
connect and send operations work even when Redis server is not ready.
258263
"""
259264
kwargs: Dict[str, Any]
260265
if event_dispatcher is None:
@@ -291,6 +296,7 @@ def __init__(
291296
"encoding": encoding,
292297
"encoding_errors": encoding_errors,
293298
"decode_responses": decode_responses,
299+
"check_server_ready": check_server_ready,
294300
"retry_on_timeout": retry_on_timeout,
295301
"retry_on_error": retry_on_error,
296302
"retry": copy.deepcopy(retry),

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def __init__(
258258
encoding_errors: str = "strict",
259259
decode_responses: bool = False,
260260
# Connection related kwargs
261+
check_server_ready: bool = False,
261262
health_check_interval: float = 0,
262263
socket_connect_timeout: Optional[float] = None,
263264
socket_keepalive: bool = False,
@@ -313,6 +314,7 @@ def __init__(
313314
"encoding_errors": encoding_errors,
314315
"decode_responses": decode_responses,
315316
# Connection related kwargs
317+
"check_server_ready": check_server_ready,
316318
"health_check_interval": health_check_interval,
317319
"socket_connect_timeout": socket_connect_timeout,
318320
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
encoding_errors: str = "strict",
149149
decode_responses: bool = False,
150150
parser_class: Type[BaseParser] = DefaultParser,
151+
check_server_ready: bool = False,
151152
socket_read_size: int = 65536,
152153
health_check_interval: float = 0,
153154
client_name: Optional[str] = None,
@@ -204,6 +205,7 @@ def __init__(
204205
self.health_check_interval = health_check_interval
205206
self.next_health_check: float = -1
206207
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208+
self.check_server_ready = check_server_ready
207209
self.redis_connect_func = redis_connect_func
208210
self._reader: Optional[asyncio.StreamReader] = None
209211
self._writer: Optional[asyncio.StreamWriter] = None
@@ -300,9 +302,11 @@ async def connect_check_health(self, check_health: bool = True):
300302
return
301303
try:
302304
await self.retry.call_with_retry(
303-
lambda: self._connect(), lambda error: self.disconnect()
305+
lambda: self._connect_check_server_ready(),
306+
lambda error: self.disconnect(),
304307
)
305308
except asyncio.CancelledError:
309+
self._close()
306310
raise # in 3.7 and earlier, this is an Exception, not BaseException
307311
except (socket.timeout, asyncio.TimeoutError):
308312
raise TimeoutError("Timeout connecting to server")
@@ -337,6 +341,33 @@ async def connect_check_health(self, check_health: bool = True):
337341
if task and inspect.isawaitable(task):
338342
await task
339343

344+
async def _connect_check_server_ready(self):
345+
await self._connect()
346+
347+
# Doing handshake since connect and send operations work even when Redis is not ready
348+
if self.check_server_ready:
349+
try:
350+
await self.send_command("PING", check_health=False)
351+
352+
if self.socket_timeout is not None:
353+
async with async_timeout(self.socket_timeout):
354+
response = str_if_bytes(await self._reader.read(1024))
355+
else:
356+
response = str_if_bytes(await self._reader.read(1024))
357+
358+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
359+
raise ResponseError(f"Invalid PING response: {response}")
360+
except (
361+
socket.timeout,
362+
asyncio.TimeoutError,
363+
ResponseError,
364+
ConnectionResetError,
365+
) as e:
366+
# `socket_keepalive_options` might contain invalid options
367+
# causing an error. Do not leave the connection open.
368+
self._close()
369+
raise ConnectionError(self._error_message(e))
370+
340371
@abstractmethod
341372
async def _connect(self):
342373
pass
@@ -526,8 +557,7 @@ async def send_packed_command(
526557
self._send_packed_command(command), self.socket_timeout
527558
)
528559
else:
529-
self._writer.writelines(command)
530-
await self._writer.drain()
560+
await self._send_packed_command(command)
531561
except asyncio.TimeoutError:
532562
await self.disconnect(nowait=True)
533563
raise TimeoutError("Timeout writing to socket") from None
@@ -774,7 +804,7 @@ async def _connect(self):
774804
except (OSError, TypeError):
775805
# `socket_keepalive_options` might contain invalid options
776806
# causing an error. Do not leave the connection open.
777-
writer.close()
807+
self._close()
778808
raise
779809

780810
def _host_error(self) -> str:
@@ -933,7 +963,6 @@ async def _connect(self):
933963
reader, writer = await asyncio.open_unix_connection(path=self.path)
934964
self._reader = reader
935965
self._writer = writer
936-
await self.on_connect()
937966

938967
def _host_error(self) -> str:
939968
return self.path

redis/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def __init__(
206206
charset: Optional[str] = None,
207207
errors: Optional[str] = None,
208208
decode_responses: bool = False,
209+
check_server_ready: bool = False,
209210
retry_on_timeout: bool = False,
210211
retry_on_error: Optional[List[Type[Exception]]] = None,
211212
ssl: bool = False,
@@ -246,10 +247,11 @@ def __init__(
246247
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
247248
248249
Args:
249-
250-
single_connection_client:
251-
if `True`, connection pool is not used. In that case `Redis`
252-
instance use is not thread safe.
250+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
251+
connect and send operations work even when Redis server is not ready.
252+
single_connection_client:
253+
if `True`, connection pool is not used. In that case `Redis`
254+
instance use is not thread safe.
253255
"""
254256
if event_dispatcher is None:
255257
self._event_dispatcher = EventDispatcher()
@@ -282,6 +284,7 @@ def __init__(
282284
"encoding": encoding,
283285
"encoding_errors": encoding_errors,
284286
"decode_responses": decode_responses,
287+
"check_server_ready": check_server_ready,
285288
"retry_on_error": retry_on_error,
286289
"retry": copy.deepcopy(retry),
287290
"max_connections": max_connections,

redis/connection.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
encoding: str = "utf-8",
237237
encoding_errors: str = "strict",
238238
decode_responses: bool = False,
239+
check_server_ready: bool = False,
239240
parser_class=DefaultParser,
240241
socket_read_size: int = 65536,
241242
health_check_interval: int = 0,
@@ -302,6 +303,7 @@ def __init__(
302303
self.redis_connect_func = redis_connect_func
303304
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
304305
self.handshake_metadata = None
306+
self.check_server_ready = check_server_ready
305307
self._sock = None
306308
self._socket_read_size = socket_read_size
307309
self.set_parser(parser_class)
@@ -382,15 +384,15 @@ def connect_check_health(self, check_health: bool = True):
382384
if self._sock:
383385
return
384386
try:
385-
sock = self.retry.call_with_retry(
386-
lambda: self._connect(), lambda error: self.disconnect(error)
387+
self.retry.call_with_retry(
388+
lambda: self._connect_check_server_ready(),
389+
lambda error: self.disconnect(error),
387390
)
388391
except socket.timeout:
389392
raise TimeoutError("Timeout connecting to server")
390393
except OSError as e:
391394
raise ConnectionError(self._error_message(e))
392395

393-
self._sock = sock
394396
try:
395397
if self.redis_connect_func is None:
396398
# Use the default on_connect function
@@ -412,8 +414,27 @@ def connect_check_health(self, check_health: bool = True):
412414
if callback:
413415
callback(self)
414416

417+
def _connect_check_server_ready(self):
418+
self._connect()
419+
420+
# Doing handshake since connect and send operations work even when Redis is not ready
421+
if self.check_server_ready:
422+
try:
423+
self.send_command("PING", check_health=False)
424+
425+
response = str_if_bytes(self._sock.recv(1024))
426+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
427+
raise ResponseError(f"Invalid PING response: {response}")
428+
except (ConnectionResetError, ResponseError) as err:
429+
try:
430+
self._sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
431+
except OSError:
432+
pass
433+
self._sock.close()
434+
raise ConnectionError(self._error_message(err))
435+
415436
@abstractmethod
416-
def _connect(self):
437+
def _connect(self) -> None:
417438
pass
418439

419440
@abstractmethod
@@ -752,7 +773,7 @@ def repr_pieces(self):
752773
pieces.append(("client_name", self.client_name))
753774
return pieces
754775

755-
def _connect(self):
776+
def _connect(self) -> None:
756777
"Create a TCP socket connection"
757778
# we want to mimic what socket.create_connection does to support
758779
# ipv4/ipv6, but we want to set options prior to calling
@@ -782,7 +803,8 @@ def _connect(self):
782803

783804
# set the socket_timeout now that we're connected
784805
sock.settimeout(self.socket_timeout)
785-
return sock
806+
self._sock = sock
807+
return
786808

787809
except OSError as _:
788810
err = _
@@ -1093,15 +1115,15 @@ def __init__(
10931115
self.ssl_ciphers = ssl_ciphers
10941116
super().__init__(**kwargs)
10951117

1096-
def _connect(self):
1118+
def _connect(self) -> None:
10971119
"""
10981120
Wrap the socket with SSL support, handling potential errors.
10991121
"""
1100-
sock = super()._connect()
1122+
super()._connect()
11011123
try:
1102-
return self._wrap_socket_with_ssl(sock)
1124+
self._sock = self._wrap_socket_with_ssl(self._sock)
11031125
except (OSError, RedisError):
1104-
sock.close()
1126+
self._sock.close()
11051127
raise
11061128

11071129
def _wrap_socket_with_ssl(self, sock):
@@ -1198,7 +1220,7 @@ def repr_pieces(self):
11981220
pieces.append(("client_name", self.client_name))
11991221
return pieces
12001222

1201-
def _connect(self):
1223+
def _connect(self) -> None:
12021224
"Create a Unix domain socket connection"
12031225
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
12041226
sock.settimeout(self.socket_connect_timeout)
@@ -1213,7 +1235,7 @@ def _connect(self):
12131235
sock.close()
12141236
raise
12151237
sock.settimeout(self.socket_timeout)
1216-
return sock
1238+
self._sock = sock
12171239

12181240
def _host_error(self):
12191241
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ async def test_reading_with_load_balancing_strategies(
716716
Connection,
717717
send_command=mock.DEFAULT,
718718
read_response=mock.DEFAULT,
719-
_connect=mock.DEFAULT,
719+
_connect_check_server_ready=mock.DEFAULT,
720720
can_read_destructive=mock.DEFAULT,
721721
on_connect=mock.DEFAULT,
722722
) as mocks:
@@ -748,7 +748,7 @@ def execute_command_mock_third(self, *args, **options):
748748
execute_command.side_effect = execute_command_mock_first
749749
mocks["send_command"].return_value = True
750750
mocks["read_response"].return_value = "OK"
751-
mocks["_connect"].return_value = True
751+
mocks["_connect_check_server_ready"].return_value = True
752752
mocks["can_read_destructive"].return_value = False
753753
mocks["on_connect"].return_value = True
754754

@@ -3090,13 +3090,19 @@ async def execute_command(self, *args, **kwargs):
30903090

30913091
return _create_client
30923092

3093+
@pytest.mark.parametrize("check_server_ready", [True, False])
30933094
async def test_ssl_connection_without_ssl(
3094-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3095+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
30953096
) -> None:
30963097
with pytest.raises(RedisClusterException) as e:
3097-
await create_client(mocked=False, ssl=False)
3098+
await create_client(
3099+
mocked=False, ssl=False, check_server_ready=check_server_ready
3100+
)
30983101
e = e.value.__cause__
3099-
assert "Connection closed by server" in str(e)
3102+
if check_server_ready:
3103+
assert "Invalid PING response" in str(e)
3104+
else:
3105+
assert "Connection closed by server" in str(e)
31003106

31013107
async def test_ssl_with_invalid_cert(
31023108
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)