1
- import asyncio
2
- import pickle
1
+ import random
3
2
import typing as t
4
3
from abc import ABC
5
4
5
+ from ellar .helper .event_loop import get_or_create_eventloop
6
+
6
7
try :
7
8
from redis .asyncio import Redis # type: ignore
8
9
from redis .asyncio .connection import ConnectionPool # type: ignore
9
10
except ImportError as e : # pragma: no cover
10
11
raise RuntimeError (
11
12
"To use `RedisCacheBackend`, you have to install 'redis' package e.g. `pip install redis`"
12
13
) from e
13
- from ..interface import IBaseCacheBackendAsync
14
- from ..make_key_decorator import make_key_decorator
15
- from ..model import BaseCacheBackend
14
+ from ...interface import IBaseCacheBackendAsync
15
+ from ...make_key_decorator import make_key_decorator
16
+ from ...model import BaseCacheBackend
17
+ from .serializer import IRedisSerializer , RedisSerializer
16
18
17
19
18
20
class RedisCacheBackendSync (IBaseCacheBackendAsync , ABC ):
19
21
def _async_executor (self , func : t .Awaitable ) -> t .Any :
20
- return asyncio . get_event_loop ().run_until_complete (func )
22
+ return get_or_create_eventloop ().run_until_complete (func )
21
23
22
24
def get (self , key : str , version : str = None ) -> t .Any :
23
25
return self ._async_executor (self .get_async (key , version = version ))
@@ -48,36 +50,58 @@ def touch(
48
50
49
51
50
52
class RedisCacheBackend (RedisCacheBackendSync , BaseCacheBackend ):
51
- """Redis-based cache backend."""
53
+ """Redis-based cache backend.
54
+
55
+ Redis Server Construct example::
56
+ backend = RedisCacheBackend(servers=['redis://[[username]:[password]]@localhost:6379/0'])
57
+ OR
58
+ backend = RedisCacheBackend(servers=['redis://[[username]:[password]]@localhost:6379/0'])
59
+ OR
60
+ backend = RedisCacheBackend(servers=['rediss://[[username]:[password]]@localhost:6379/0'])
61
+ OR
62
+ backend = RedisCacheBackend(servers=['unix://[username@]/path/to/socket.sock?db=0[&password=password]'])
52
63
53
- pickle_protocol : t . Any = pickle . HIGHEST_PROTOCOL
64
+ """
54
65
55
66
def __init__ (
56
67
self ,
57
- url : str = "localhost" ,
58
- db : int = None ,
59
- port : int = None ,
60
- username : str = None ,
61
- password : str = None ,
68
+ servers : t .List [str ],
62
69
options : t .Dict = None ,
63
- serializer : t .Callable = pickle .dumps ,
64
- deserializer : t .Callable = pickle .loads ,
70
+ serializer : IRedisSerializer = None ,
65
71
** kwargs : t .Any
66
72
) -> None :
67
73
super ().__init__ (** kwargs )
68
74
69
- self ._cache_client_init : Redis = None
75
+ self ._pools : t .Dict [int , ConnectionPool ] = {}
76
+ self ._servers = servers
70
77
_default_options = options or {}
71
78
self ._options = {
72
- "url" : url ,
73
- "db" : db ,
74
- "port" : port ,
75
- "username" : username ,
76
- "password" : password ,
77
79
** _default_options ,
78
80
}
79
- self ._serializer = serializer
80
- self ._deserializer = deserializer
81
+ self ._serializer = serializer or RedisSerializer ()
82
+
83
+ def _get_connection_pool_index (self , write : bool ) -> int :
84
+ # Write to the first server. Read from other servers if there are more,
85
+ # otherwise read from the first server.
86
+ if write or len (self ._servers ) == 1 :
87
+ return 0
88
+ return random .randint (1 , len (self ._servers ) - 1 )
89
+
90
+ def _get_connection_pool (self , write : bool ) -> ConnectionPool :
91
+ index = self ._get_connection_pool_index (write )
92
+ if index not in self ._pools :
93
+ self ._pools [index ] = ConnectionPool .from_url (
94
+ self ._servers [index ],
95
+ ** self ._options ,
96
+ )
97
+ return self ._pools [index ]
98
+
99
+ def _get_client (self , * , write : bool = False ) -> Redis :
100
+ # key is used so that the method signature remains the same and custom
101
+ # cache client can be implemented which might require the key to select
102
+ # the server, e.g. sharding.
103
+ pool = self ._get_connection_pool (write )
104
+ return Redis (connection_pool = pool )
81
105
82
106
def get_backend_timeout (
83
107
self , timeout : t .Union [float , int ] = None
@@ -88,21 +112,12 @@ def get_backend_timeout(
88
112
# Non-positive values will cause the key to be deleted.
89
113
return None if timeout is None else max (0 , int (timeout ))
90
114
91
- @property
92
- def _cache_client (self ) -> Redis :
93
- """
94
- Implement transparent thread-safe access to a memcached client.
95
- """
96
- if self ._cache_client_init is None :
97
- pool = ConnectionPool .from_url (** self ._options )
98
- self ._redis_int = Redis (connection_pool = pool )
99
- return self ._cache_client_init
100
-
101
115
@make_key_decorator
102
116
async def get_async (self , key : str , version : str = None ) -> t .Any :
103
- value = await self ._cache_client .get (key )
117
+ client = self ._get_client ()
118
+ value = await client .get (key )
104
119
if value :
105
- return self ._deserializer (value )
120
+ return self ._serializer . load (value )
106
121
return None
107
122
108
123
@make_key_decorator
@@ -113,27 +128,26 @@ async def set_async(
113
128
timeout : t .Union [float , int ] = None ,
114
129
version : str = None ,
115
130
) -> bool :
116
- value = self ._serializer (value , self .pickle_protocol )
131
+ client = self ._get_client ()
132
+ value = self ._serializer .dumps (value )
117
133
if timeout == 0 :
118
- await self . _cache_client .delete (key )
134
+ await client .delete (key )
119
135
120
- return bool (
121
- await self ._cache_client .set (
122
- key , value , ex = self .get_backend_timeout (timeout )
123
- )
124
- )
136
+ return bool (await client .set (key , value , ex = self .get_backend_timeout (timeout )))
125
137
126
138
@make_key_decorator
127
139
async def delete_async (self , key : str , version : str = None ) -> bool :
128
- result = await self ._cache_client .delete (key )
140
+ client = self ._get_client ()
141
+ result = await client .delete (key )
129
142
return bool (result )
130
143
131
144
@make_key_decorator
132
145
async def touch_async (
133
146
self , key : str , timeout : t .Union [float , int ] = None , version : str = None
134
147
) -> bool :
148
+ client = self ._get_client ()
135
149
if timeout is None :
136
- res = await self . _cache_client .persist (key )
150
+ res = await client .persist (key )
137
151
return bool (res )
138
- res = await self . _cache_client .expire (key , self .get_backend_timeout (timeout ))
152
+ res = await client .expire (key , self .get_backend_timeout (timeout ))
139
153
return bool (res )
0 commit comments