|
| 1 | +import typing as t |
| 2 | +from functools import wraps |
| 3 | + |
| 4 | +from ellar.cache.interface import ICacheService |
| 5 | +from ellar.core import IExecutionContext |
| 6 | +from ellar.core.params import ExtraEndpointArg |
| 7 | +from ellar.helper import is_async_callable |
| 8 | + |
| 9 | +from .decorators.extra_args import extra_args |
| 10 | +from .routing.params import Context, Provide |
| 11 | + |
| 12 | + |
class CacheDecorator:
    """Caches a route handler's response through the application `ICacheService`.

    Wraps the endpoint function (sync or async) so that, per request, a cache
    key is derived from the execution context, a cached response is returned
    when present, and a fresh response is stored with the configured timeout
    otherwise.  Two extra endpoint arguments (`cache_service`, `route_context`)
    are injected via ``extra_args`` so the wrappers can resolve them from the
    call kwargs at request time.
    """

    __slots__ = (
        "_is_async",
        "_key_prefix",
        "_version",
        "_backend",
        "_func",
        "_timeout",
        "_cache_service_arg",
        "_context_arg",
        "_make_key_callback",
    )

    def __init__(
        self,
        func: t.Callable,
        timeout: int,
        *,
        key_prefix: str = "",
        version: t.Optional[str] = None,
        backend: str = "default",
        make_key_callback: t.Optional[
            t.Callable[[IExecutionContext, str], str]
        ] = None,
    ) -> None:
        """
        :param func: endpoint function to wrap (sync or async).
        :param timeout: cache entry lifetime, in seconds.
        :param key_prefix: prefix appended to the generated cache key.
        :param version: optional cache entry version passed to the backend.
        :param backend: name of the cache backend to use.
        :param make_key_callback: optional override for cache-key generation;
            receives the execution context and the key prefix.
        """
        self._is_async = is_async_callable(func)
        self._key_prefix = key_prefix
        self._version = version
        self._backend = backend
        self._func = func
        self._timeout = timeout

        # Extra endpoint args resolved from kwargs at request time:
        # the cache service (provider injection) and the execution context.
        self._cache_service_arg = ExtraEndpointArg(
            name="cache_service", annotation=ICacheService, default_value=Provide()  # type: ignore
        )
        self._context_arg = ExtraEndpointArg(
            name="route_context", annotation=IExecutionContext, default_value=Context()  # type: ignore
        )
        # Register the extra args on the endpoint so the framework supplies them.
        extra_args(self._cache_service_arg, self._context_arg)(func)
        self._make_key_callback: t.Callable[[IExecutionContext, str], str] = (
            make_key_callback or self.route_cache_make_key
        )

    def _get_key(self, **input_kwargs: t.Any) -> str:
        """Resolve the execution context from kwargs and build the cache key."""
        context: IExecutionContext = self._context_arg.resolve(input_kwargs)
        return self._make_key_callback(context, self._key_prefix or "")

    def get_decorator_wrapper(self) -> t.Callable:
        """Return the caching wrapper matching the endpoint's sync/async nature."""
        if self._is_async:
            return self.get_async_cache_wrapper()
        return self.get_cache_wrapper()

    def route_cache_make_key(
        self, context: IExecutionContext, key_prefix: str = ""
    ) -> str:
        """Default key generator for caching a view: client URL plus prefix."""
        connection = context.switch_to_http_connection()
        return f"{connection.get_client().url}:{key_prefix or 'view'}"

    def get_async_cache_wrapper(self) -> t.Callable:
        """Gets endpoint asynchronous wrapper function"""

        @wraps(self._func)
        async def _async_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
            cache_service: ICacheService = self._cache_service_arg.resolve(kwargs)

            key = self._get_key(**kwargs)

            cached_value = await cache_service.get_async(
                key, self._version, backend=self._backend
            )
            # `None` marks a cache miss; falsy responses (0, "", [], False)
            # are legitimate cached values and must be served.
            # NOTE(review): assumes the backend never stores None — confirm.
            if cached_value is not None:
                return cached_value

            response = await self._func(*args, **kwargs)
            await cache_service.set_async(
                key,
                response,
                timeout=self._timeout,
                version=self._version,
                backend=self._backend,
            )
            return response

        return _async_wrapper

    def get_cache_wrapper(self) -> t.Callable:
        """Gets endpoint synchronous wrapper function"""

        @wraps(self._func)
        def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
            cache_service: ICacheService = self._cache_service_arg.resolve(kwargs)

            key = self._get_key(**kwargs)

            cached_value = cache_service.get(key, self._version, backend=self._backend)
            # `None` marks a cache miss; falsy responses are legitimate
            # cached values and must be served (see async wrapper).
            if cached_value is not None:
                return cached_value

            response = self._func(*args, **kwargs)
            cache_service.set(
                key,
                response,
                timeout=self._timeout,
                version=self._version,
                backend=self._backend,
            )
            return response

        return _wrapper
| 123 | + |
| 124 | + |
def cache(
    timeout: int,
    *,
    key_prefix: str = "",
    version: t.Optional[str] = None,
    backend: str = "default",
    make_key_callback: t.Optional[
        t.Callable[[IExecutionContext, str], str]
    ] = None,
) -> t.Callable:
    """Decorator factory that caches a route handler's response.

    :param timeout: cache entry lifetime, in seconds.
    :param key_prefix: prefix appended to the generated cache key.
    :param version: optional cache entry version passed to the backend.
    :param backend: name of the cache backend to use.
    :param make_key_callback: optional override for cache-key generation;
        receives the execution context and the key prefix.
    :return: a decorator that wraps the endpoint with caching behavior.
    """

    def _wraps(func: t.Callable) -> t.Callable:
        # Delegate all caching mechanics to CacheDecorator; the returned
        # wrapper preserves the endpoint's sync/async nature.
        cache_decorator = CacheDecorator(
            func,
            timeout,
            key_prefix=key_prefix,
            version=version,
            backend=backend,
            make_key_callback=make_key_callback,
        )
        return cache_decorator.get_decorator_wrapper()

    return _wraps
0 commit comments