Skip to content

Commit 5b49846

Browse files
authored
Improve decorater error handling (#6)
1 parent 99d1f58 commit 5b49846

File tree

2 files changed

+81
-41
lines changed

2 files changed

+81
-41
lines changed

src/pgcachewatch/decorators.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,65 @@
11
import asyncio
2-
import contextlib
32
import logging
4-
import typing
3+
from typing import Awaitable, Callable, Hashable, Literal, TypeVar
54

6-
import typing_extensions
5+
from typing_extensions import ParamSpec
76

87
from pgcachewatch import strategies, utils
98

10-
P = typing_extensions.ParamSpec("P")
11-
T = typing.TypeVar("T")
9+
P = ParamSpec("P")
10+
T = TypeVar("T")
1211

1312

1413
def cache(
1514
strategy: strategies.Strategy,
16-
statistics_callback: typing.Callable[[typing.Literal["hit", "miss"]], None]
17-
| None = None,
18-
) -> typing.Callable[
19-
[typing.Callable[P, typing.Awaitable[T]]],
20-
typing.Callable[P, typing.Awaitable[T]],
21-
]:
22-
def outer(
23-
fn: typing.Callable[P, typing.Awaitable[T]],
24-
) -> typing.Callable[P, typing.Awaitable[T]]:
25-
cached = dict[typing.Hashable, asyncio.Future[T]]()
15+
statistics_callback: Callable[[Literal["hit", "miss"]], None] | None = None,
16+
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
17+
def outer(fn: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
18+
cached = dict[Hashable, asyncio.Future[T]]()
2619

27-
async def inner(*args: P.args, **kw: P.kwargs) -> T:
20+
async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
2821
# If db-conn is down, disable cache.
2922
if not strategy.pg_connection_healthy():
3023
logging.critical("Database connection is closed, caching disabled.")
31-
return await fn(*args, **kw)
24+
return await fn(*args, **kwargs)
3225

3326
# Clear cache if we have a event from
3427
# the database the instructs us to clear.
3528
if strategy.clear():
3629
logging.debug("Cache clear")
3730
cached.clear()
3831

39-
# Check for cache hit
40-
key = utils.make_key(args, kw)
41-
with contextlib.suppress(KeyError):
42-
# OBS: Will only await if the cache key hits.
43-
result = await cached[key]
32+
key = utils.make_key(args, kwargs)
33+
34+
try:
35+
waiter = cached[key]
36+
except KeyError:
37+
# Cache miss
38+
...
39+
else:
40+
# Cache hit
4441
logging.debug("Cache hit")
4542
if statistics_callback:
4643
statistics_callback("hit")
47-
return result
44+
return await waiter
4845

49-
# Below deals with a cache miss.
5046
logging.debug("Cache miss")
5147
if statistics_callback:
5248
statistics_callback("miss")
5349

54-
# By using a future as placeholder we avoid
55-
# cache stampeded. Note that on the "miss" branch/path, controll
56-
# is never given to the eventloopscheduler before the future
57-
# is create.
50+
# Initialize Future to prevent cache stampedes.
5851
cached[key] = waiter = asyncio.Future[T]()
52+
5953
try:
60-
result = await fn(*args, **kw)
54+
# # Attempt to compute result and set for waiter
55+
waiter.set_result(await fn(*args, **kwargs))
6156
except Exception as e:
62-
cached.pop(
63-
key, None
64-
) # Next try should not result in a repeating exception
65-
waiter.set_exception(
66-
e
67-
) # Propegate exception to other callers who are waiting.
68-
raise e from None # Propegate exception to first caller.
69-
else:
70-
waiter.set_result(result)
57+
# Remove key from cache on failure.
58+
cached.pop(key, None)
59+
# Propagate exception to all awaiting the future.
60+
waiter.set_exception(e)
7161

72-
return result
62+
return await waiter
7363

7464
return inner
7565

tests/test_decoraters.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import asyncio
22
import collections
33
import datetime
4+
from typing import NoReturn
45

56
import asyncpg
67
import pytest
78
from pgcachewatch import decorators, listeners, models, strategies
89

910

10-
@pytest.mark.parametrize("N", (4, 16, 64, 512))
11+
@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
1112
async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> None:
1213
statistics = collections.Counter[str]()
1314
listener = listeners.PGEventQueue()
@@ -20,6 +21,55 @@ async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> Non
2021
async def now() -> datetime.datetime:
2122
return datetime.datetime.now()
2223

23-
await asyncio.gather(*[now() for _ in range(N)])
24+
nows = set(await asyncio.gather(*[now() for _ in range(N)]))
25+
assert len(nows) == 1
26+
2427
assert statistics["hit"] == N - 1
2528
assert statistics["miss"] == 1
29+
30+
31+
@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
32+
async def test_gready_cache_decorator_connection_closed(
33+
N: int,
34+
pgconn: asyncpg.Connection,
35+
) -> None:
36+
listener = listeners.PGEventQueue()
37+
await listener.connect(
38+
pgconn,
39+
models.PGChannel("test_gready_cache_decorator_connection_closed"),
40+
)
41+
await pgconn.close()
42+
43+
@decorators.cache(strategy=strategies.Gready(listener=listener))
44+
async def now() -> datetime.datetime:
45+
return datetime.datetime.now()
46+
47+
nows = await asyncio.gather(*[now() for _ in range(N)])
48+
assert len(set(nows)) == N
49+
50+
51+
@pytest.mark.parametrize("N", (1, 2, 4, 16, 64))
52+
async def test_gready_cache_decorator_exceptions(
53+
N: int,
54+
pgconn: asyncpg.Connection,
55+
) -> None:
56+
listener = listeners.PGEventQueue()
57+
await listener.connect(
58+
pgconn,
59+
models.PGChannel("test_gready_cache_decorator_exceptions"),
60+
)
61+
62+
@decorators.cache(strategy=strategies.Gready(listener=listener))
63+
async def raise_runtime_error() -> NoReturn:
64+
raise RuntimeError
65+
66+
for _ in range(N):
67+
with pytest.raises(RuntimeError):
68+
await raise_runtime_error()
69+
70+
exceptions = await asyncio.gather(
71+
*[raise_runtime_error() for _ in range(N)],
72+
return_exceptions=True,
73+
)
74+
assert len(exceptions) == N
75+
assert all(isinstance(exc, RuntimeError) for exc in exceptions)

0 commit comments

Comments
 (0)