|
1 | 1 | import asyncio
|
2 |
| -import contextlib |
3 | 2 | import logging
|
4 |
| -import typing |
| 3 | +from typing import Awaitable, Callable, Hashable, Literal, TypeVar |
5 | 4 |
|
6 |
| -import typing_extensions |
| 5 | +from typing_extensions import ParamSpec |
7 | 6 |
|
8 | 7 | from pgcachewatch import strategies, utils
|
9 | 8 |
|
10 |
| -P = typing_extensions.ParamSpec("P") |
11 |
| -T = typing.TypeVar("T") |
| 9 | +P = ParamSpec("P") |
| 10 | +T = TypeVar("T") |
12 | 11 |
|
13 | 12 |
|
14 | 13 | def cache(
|
15 | 14 | strategy: strategies.Strategy,
|
16 |
| - statistics_callback: typing.Callable[[typing.Literal["hit", "miss"]], None] |
17 |
| - | None = None, |
18 |
| -) -> typing.Callable[ |
19 |
| - [typing.Callable[P, typing.Awaitable[T]]], |
20 |
| - typing.Callable[P, typing.Awaitable[T]], |
21 |
| -]: |
22 |
| - def outer( |
23 |
| - fn: typing.Callable[P, typing.Awaitable[T]], |
24 |
| - ) -> typing.Callable[P, typing.Awaitable[T]]: |
25 |
| - cached = dict[typing.Hashable, asyncio.Future[T]]() |
| 15 | + statistics_callback: Callable[[Literal["hit", "miss"]], None] | None = None, |
| 16 | +) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: |
| 17 | + def outer(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: |
| 18 | + cached = dict[Hashable, asyncio.Future[T]]() |
26 | 19 |
|
27 |
| - async def inner(*args: P.args, **kw: P.kwargs) -> T: |
| 20 | + async def inner(*args: P.args, **kwargs: P.kwargs) -> T: |
28 | 21 | # If db-conn is down, disable cache.
|
29 | 22 | if not strategy.pg_connection_healthy():
|
30 | 23 | logging.critical("Database connection is closed, caching disabled.")
|
31 |
| - return await fn(*args, **kw) |
| 24 | + return await fn(*args, **kwargs) |
32 | 25 |
|
33 | 26 | # Clear cache if we have a event from
|
34 | 27 | # the database the instructs us to clear.
|
35 | 28 | if strategy.clear():
|
36 | 29 | logging.debug("Cache clear")
|
37 | 30 | cached.clear()
|
38 | 31 |
|
39 |
| - # Check for cache hit |
40 |
| - key = utils.make_key(args, kw) |
41 |
| - with contextlib.suppress(KeyError): |
42 |
| - # OBS: Will only await if the cache key hits. |
43 |
| - result = await cached[key] |
| 32 | + key = utils.make_key(args, kwargs) |
| 33 | + |
| 34 | + try: |
| 35 | + waiter = cached[key] |
| 36 | + except KeyError: |
| 37 | + # Cache miss |
| 38 | + ... |
| 39 | + else: |
| 40 | + # Cache hit |
44 | 41 | logging.debug("Cache hit")
|
45 | 42 | if statistics_callback:
|
46 | 43 | statistics_callback("hit")
|
47 |
| - return result |
| 44 | + return await waiter |
48 | 45 |
|
49 |
| - # Below deals with a cache miss. |
50 | 46 | logging.debug("Cache miss")
|
51 | 47 | if statistics_callback:
|
52 | 48 | statistics_callback("miss")
|
53 | 49 |
|
54 |
| - # By using a future as placeholder we avoid |
55 |
| - # cache stampeded. Note that on the "miss" branch/path, controll |
56 |
| - # is never given to the eventloopscheduler before the future |
57 |
| - # is create. |
| 50 | + # Initialize Future to prevent cache stampedes. |
58 | 51 | cached[key] = waiter = asyncio.Future[T]()
|
| 52 | + |
59 | 53 | try:
|
60 |
| - result = await fn(*args, **kw) |
| 54 | + # # Attempt to compute result and set for waiter |
| 55 | + waiter.set_result(await fn(*args, **kwargs)) |
61 | 56 | except Exception as e:
|
62 |
| - cached.pop( |
63 |
| - key, None |
64 |
| - ) # Next try should not result in a repeating exception |
65 |
| - waiter.set_exception( |
66 |
| - e |
67 |
| - ) # Propegate exception to other callers who are waiting. |
68 |
| - raise e from None # Propegate exception to first caller. |
69 |
| - else: |
70 |
| - waiter.set_result(result) |
| 57 | + # Remove key from cache on failure. |
| 58 | + cached.pop(key, None) |
| 59 | + # Propagate exception to all awaiting the future. |
| 60 | + waiter.set_exception(e) |
71 | 61 |
|
72 |
| - return result |
| 62 | + return await waiter |
73 | 63 |
|
74 | 64 | return inner
|
75 | 65 |
|
|
0 commit comments