Skip to content

Commit 8c83496

Browse files
authored
Merge branch 'master' into create-new-stream-commands
2 parents 0608e5f + 4c9512b commit 8c83496

File tree

10 files changed

+144
-81
lines changed

10 files changed

+144
-81
lines changed

redis/asyncio/connection.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,18 @@ async def connect(self):
295295
"""Connects to the Redis server if not already connected"""
296296
await self.connect_check_health(check_health=True)
297297

298-
async def connect_check_health(self, check_health: bool = True):
298+
async def connect_check_health(
299+
self, check_health: bool = True, retry_socket_connect: bool = True
300+
):
299301
if self.is_connected:
300302
return
301303
try:
302-
await self.retry.call_with_retry(
303-
lambda: self._connect(), lambda error: self.disconnect()
304-
)
304+
if retry_socket_connect:
305+
await self.retry.call_with_retry(
306+
lambda: self._connect(), lambda error: self.disconnect()
307+
)
308+
else:
309+
await self._connect()
305310
except asyncio.CancelledError:
306311
raise # in 3.7 and earlier, this is an Exception, not BaseException
307312
except (socket.timeout, asyncio.TimeoutError):
@@ -1037,6 +1042,7 @@ class ConnectionPool:
10371042
By default, TCP connections are created unless ``connection_class``
10381043
is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for
10391044
unix sockets.
1045+
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
10401046
10411047
Any additional keyword arguments are passed to the constructor of
10421048
``connection_class``.

redis/asyncio/retry.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,16 @@
22
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar
33

44
from redis.exceptions import ConnectionError, RedisError, TimeoutError
5-
6-
if TYPE_CHECKING:
7-
from redis.backoff import AbstractBackoff
8-
5+
from redis.retry import AbstractRetry
96

107
T = TypeVar("T")
118

9+
if TYPE_CHECKING:
10+
from redis.backoff import AbstractBackoff
1211

13-
class Retry:
14-
"""Retry a specific number of times after a failure"""
1512

16-
__slots__ = "_backoff", "_retries", "_supported_errors"
13+
class Retry(AbstractRetry[RedisError]):
14+
__hash__ = AbstractRetry.__hash__
1715

1816
def __init__(
1917
self,
@@ -24,36 +22,17 @@ def __init__(
2422
TimeoutError,
2523
),
2624
):
27-
"""
28-
Initialize a `Retry` object with a `Backoff` object
29-
that retries a maximum of `retries` times.
30-
`retries` can be negative to retry forever.
31-
You can specify the types of supported errors which trigger
32-
a retry with the `supported_errors` parameter.
33-
"""
34-
self._backoff = backoff
35-
self._retries = retries
36-
self._supported_errors = supported_errors
25+
super().__init__(backoff, retries, supported_errors)
3726

38-
def update_supported_errors(self, specified_errors: list):
39-
"""
40-
Updates the supported errors with the specified error types
41-
"""
42-
self._supported_errors = tuple(
43-
set(self._supported_errors + tuple(specified_errors))
44-
)
45-
46-
def get_retries(self) -> int:
47-
"""
48-
Get the number of retries.
49-
"""
50-
return self._retries
27+
def __eq__(self, other: Any) -> bool:
28+
if not isinstance(other, Retry):
29+
return NotImplemented
5130

52-
def update_retries(self, value: int) -> None:
53-
"""
54-
Set the number of retries.
55-
"""
56-
self._retries = value
31+
return (
32+
self._backoff == other._backoff
33+
and self._retries == other._retries
34+
and set(self._supported_errors) == set(other._supported_errors)
35+
)
5736

5837
async def call_with_retry(
5938
self, do: Callable[[], Awaitable[T]], fail: Callable[[RedisError], Any]

redis/asyncio/sentinel.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
SSLConnection,
1212
)
1313
from redis.commands import AsyncSentinelCommands
14-
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
15-
from redis.utils import str_if_bytes
14+
from redis.exceptions import (
15+
ConnectionError,
16+
ReadOnlyError,
17+
ResponseError,
18+
TimeoutError,
19+
)
1620

1721

1822
class MasterNotFoundError(ConnectionError):
@@ -37,11 +41,10 @@ def __repr__(self):
3741

3842
async def connect_to(self, address):
3943
self.host, self.port = address
40-
await super().connect()
41-
if self.connection_pool.check_connection:
42-
await self.send_command("PING")
43-
if str_if_bytes(await self.read_response()) != "PONG":
44-
raise ConnectionError("PING failed")
44+
await self.connect_check_health(
45+
check_health=self.connection_pool.check_connection,
46+
retry_socket_connect=False,
47+
)
4548

4649
async def _connect_retry(self):
4750
if self._reader:

redis/backoff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def __hash__(self) -> int:
170170
return hash((self._base, self._cap))
171171

172172
def __eq__(self, other) -> bool:
173-
if not isinstance(other, EqualJitterBackoff):
173+
if not isinstance(other, ExponentialWithJitterBackoff):
174174
return NotImplemented
175175

176176
return self._base == other._base and self._cap == other._cap

redis/connection.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,13 +379,18 @@ def connect(self):
379379
"Connects to the Redis server if not already connected"
380380
self.connect_check_health(check_health=True)
381381

382-
def connect_check_health(self, check_health: bool = True):
382+
def connect_check_health(
383+
self, check_health: bool = True, retry_socket_connect: bool = True
384+
):
383385
if self._sock:
384386
return
385387
try:
386-
sock = self.retry.call_with_retry(
387-
lambda: self._connect(), lambda error: self.disconnect(error)
388-
)
388+
if retry_socket_connect:
389+
sock = self.retry.call_with_retry(
390+
lambda: self._connect(), lambda error: self.disconnect(error)
391+
)
392+
else:
393+
sock = self._connect()
389394
except socket.timeout:
390395
raise TimeoutError("Timeout connecting to server")
391396
except OSError as e:
@@ -1316,6 +1321,7 @@ class ConnectionPool:
13161321
By default, TCP connections are created unless ``connection_class``
13171322
is specified. Use class:`.UnixDomainSocketConnection` for
13181323
unix sockets.
1324+
:py:class:`~redis.SSLConnection` can be used for SSL enabled connections.
13191325
13201326
Any additional keyword arguments are passed to the constructor of
13211327
``connection_class``.

redis/retry.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1+
import abc
12
import socket
23
from time import sleep
3-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Tuple, Type, TypeVar
4+
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar
45

56
from redis.exceptions import ConnectionError, TimeoutError
67

78
T = TypeVar("T")
9+
E = TypeVar("E", bound=Exception, covariant=True)
810

911
if TYPE_CHECKING:
1012
from redis.backoff import AbstractBackoff
1113

1214

13-
class Retry:
15+
class AbstractRetry(Generic[E], abc.ABC):
1416
"""Retry a specific number of times after a failure"""
1517

18+
_supported_errors: Tuple[Type[E], ...]
19+
1620
def __init__(
1721
self,
1822
backoff: "AbstractBackoff",
1923
retries: int,
20-
supported_errors: Tuple[Type[Exception], ...] = (
21-
ConnectionError,
22-
TimeoutError,
23-
socket.timeout,
24-
),
24+
supported_errors: Tuple[Type[E], ...],
2525
):
2626
"""
2727
Initialize a `Retry` object with a `Backoff` object
@@ -34,22 +34,14 @@ def __init__(
3434
self._retries = retries
3535
self._supported_errors = supported_errors
3636

37+
@abc.abstractmethod
3738
def __eq__(self, other: Any) -> bool:
38-
if not isinstance(other, Retry):
39-
return NotImplemented
40-
41-
return (
42-
self._backoff == other._backoff
43-
and self._retries == other._retries
44-
and set(self._supported_errors) == set(other._supported_errors)
45-
)
39+
return NotImplemented
4640

4741
def __hash__(self) -> int:
4842
return hash((self._backoff, self._retries, frozenset(self._supported_errors)))
4943

50-
def update_supported_errors(
51-
self, specified_errors: Iterable[Type[Exception]]
52-
) -> None:
44+
def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None:
5345
"""
5446
Updates the supported errors with the specified error types
5547
"""
@@ -69,6 +61,32 @@ def update_retries(self, value: int) -> None:
6961
"""
7062
self._retries = value
7163

64+
65+
class Retry(AbstractRetry[Exception]):
66+
__hash__ = AbstractRetry.__hash__
67+
68+
def __init__(
69+
self,
70+
backoff: "AbstractBackoff",
71+
retries: int,
72+
supported_errors: Tuple[Type[Exception], ...] = (
73+
ConnectionError,
74+
TimeoutError,
75+
socket.timeout,
76+
),
77+
):
78+
super().__init__(backoff, retries, supported_errors)
79+
80+
def __eq__(self, other: Any) -> bool:
81+
if not isinstance(other, Retry):
82+
return NotImplemented
83+
84+
return (
85+
self._backoff == other._backoff
86+
and self._retries == other._retries
87+
and set(self._supported_errors) == set(other._supported_errors)
88+
)
89+
7290
def call_with_retry(
7391
self,
7492
do: Callable[[], T],

redis/sentinel.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
from redis.client import Redis
66
from redis.commands import SentinelCommands
77
from redis.connection import Connection, ConnectionPool, SSLConnection
8-
from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
9-
from redis.utils import str_if_bytes
8+
from redis.exceptions import (
9+
ConnectionError,
10+
ReadOnlyError,
11+
ResponseError,
12+
TimeoutError,
13+
)
1014

1115

1216
class MasterNotFoundError(ConnectionError):
@@ -35,11 +39,11 @@ def __repr__(self):
3539

3640
def connect_to(self, address):
3741
self.host, self.port = address
38-
super().connect()
39-
if self.connection_pool.check_connection:
40-
self.send_command("PING")
41-
if str_if_bytes(self.read_response()) != "PONG":
42-
raise ConnectionError("PING failed")
42+
43+
self.connect_check_health(
44+
check_health=self.connection_pool.check_connection,
45+
retry_socket_connect=False,
46+
)
4347

4448
def _connect_retry(self):
4549
if self._sock:

tests/test_asyncio/test_sentinel_managed_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ async def mock_connect():
3333
conn._connect.side_effect = mock_connect
3434
await conn.connect()
3535
assert conn._connect.call_count == 3
36+
assert connection_pool.get_master_address.call_count == 3
3637
await conn.disconnect()

tests/test_retry.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from unittest.mock import patch
22

33
import pytest
4+
from redis.asyncio.retry import Retry as AsyncRetry
45
from redis.backoff import (
56
AbstractBackoff,
67
ConstantBackoff,
@@ -89,6 +90,7 @@ def test_retry_on_error_retry(self, Class, retries):
8990
assert c.retry._retries == retries
9091

9192

93+
@pytest.mark.parametrize("retry_class", [Retry, AsyncRetry])
9294
@pytest.mark.parametrize(
9395
"args",
9496
[
@@ -108,8 +110,8 @@ def test_retry_on_error_retry(self, Class, retries):
108110
for backoff in ((Backoff(), 2), (Backoff(25), 5), (Backoff(25, 5), 5))
109111
],
110112
)
111-
def test_retry_eq_and_hashable(args):
112-
assert Retry(*args) == Retry(*args)
113+
def test_retry_eq_and_hashable(retry_class, args):
114+
assert retry_class(*args) == retry_class(*args)
113115

114116
# create another retry object with different parameters
115117
copy = list(args)
@@ -118,9 +120,19 @@ def test_retry_eq_and_hashable(args):
118120
else:
119121
copy[0] = ConstantBackoff(9000)
120122

121-
assert Retry(*args) != Retry(*copy)
122-
assert Retry(*copy) != Retry(*args)
123-
assert len({Retry(*args), Retry(*args), Retry(*copy), Retry(*copy)}) == 2
123+
assert retry_class(*args) != retry_class(*copy)
124+
assert retry_class(*copy) != retry_class(*args)
125+
assert (
126+
len(
127+
{
128+
retry_class(*args),
129+
retry_class(*args),
130+
retry_class(*copy),
131+
retry_class(*copy),
132+
}
133+
)
134+
== 2
135+
)
124136

125137

126138
class TestRetry:
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import socket
2+
3+
from redis.retry import Retry
4+
from redis.sentinel import SentinelManagedConnection
5+
from redis.backoff import NoBackoff
6+
from unittest import mock
7+
8+
9+
def test_connect_retry_on_timeout_error(master_host):
10+
"""Test that the _connect function is retried in case of a timeout"""
11+
connection_pool = mock.Mock()
12+
connection_pool.get_master_address = mock.Mock(
13+
return_value=(master_host[0], master_host[1])
14+
)
15+
conn = SentinelManagedConnection(
16+
retry_on_timeout=True,
17+
retry=Retry(NoBackoff(), 3),
18+
connection_pool=connection_pool,
19+
)
20+
origin_connect = conn._connect
21+
conn._connect = mock.Mock()
22+
23+
def mock_connect():
24+
# connect only on the last retry
25+
if conn._connect.call_count <= 2:
26+
raise socket.timeout
27+
else:
28+
return origin_connect()
29+
30+
conn._connect.side_effect = mock_connect
31+
conn.connect()
32+
assert conn._connect.call_count == 3
33+
assert connection_pool.get_master_address.call_count == 3
34+
conn.disconnect()

0 commit comments

Comments
 (0)