Skip to content

Commit e635682

Browse files
authored
Deferred Listener Connection and README Update (#5)
1 parent 1ac83c8 commit e635682

File tree

9 files changed

+115
-77
lines changed

9 files changed

+115
-77
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ on:
99
jobs:
1010
ci:
1111
strategy:
12+
fail-fast: false
1213
matrix:
1314
python-version: ["3.10", "3.11", "3.12"]
1415
postgres-version: ["14", "15", "16"]
@@ -54,3 +55,11 @@ jobs:
5455

5556
- name: Full test
5657
run: pytest -v
58+
59+
check:
60+
name: Check test matrix passed.
61+
needs: ci
62+
runs-on: ubuntu-latest
63+
steps:
64+
- name: Check status
65+
run: echo "All tests passed; ready to merge."

README.md

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,40 @@ pgcachewatch install <tables-to-cache>
2929
Example showing how to use PGCacheWatch for cache invalidation in a FastAPI app
3030

3131
```python
32+
import contextlib
33+
import typing
34+
3235
import asyncpg
3336
from fastapi import FastAPI
3437
from pgcachewatch import decorators, listeners, models, strategies
3538

36-
app = FastAPI()
39+
listener = listeners.PGEventQueue()
40+
3741

38-
async def setup_app(channel: models.PGChannel) -> FastAPI:
42+
@contextlib.asynccontextmanager
43+
async def app_setup_teardown(_: FastAPI) -> typing.AsyncGenerator[None, None]:
3944
conn = await asyncpg.connect()
40-
listener = await listeners.PGEventQueue.create(channel, conn)
45+
await listener.connect(conn, models.PGChannel("ch_pgcachewatch_table_change"))
46+
yield
47+
await conn.close()
48+
4149

42-
@decorators.cache(strategy=strategies.Greedy(listener=listener))
43-
async def cached_query():
44-
# Simulate a database query
45-
return {"data": "query result"}
50+
APP = FastAPI(lifespan=app_setup_teardown)
4651

47-
@app.get("/data")
48-
async def get_data():
49-
return await cached_query()
5052

51-
return app
52-
```
53+
# Only allow for cache refresh after an update
54+
@decorators.cache(
55+
strategy=strategies.Gready(
56+
listener=listener,
57+
predicate=lambda x: x.operation == "update",
58+
)
59+
)
60+
async def cached_query() -> dict[str, str]:
61+
# Simulate a database query
62+
return {"data": "query result"}
63+
64+
65+
@APP.get("/data")
66+
async def get_data() -> dict:
67+
return await cached_query()
68+
```

src/pgcachewatch/listeners.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def _critical_termination_listener(*_: object, **__: object) -> None:
1212
# Must be defined in the global namespace, as asyncpg keeps
1313
# a set of functions to call. Thus this will now happen once as
1414
# all instances will point to the same function.
15-
logging.critical("Connection is closed / terminated!")
15+
logging.critical("Connection is closed / terminated.")
1616

1717

1818
class PGEventQueue(asyncio.Queue[models.Event]):
@@ -23,48 +23,59 @@ class PGEventQueue(asyncio.Queue[models.Event]):
2323

2424
def __init__(
2525
self,
26-
pgchannel: models.PGChannel,
27-
pgconn: asyncpg.Connection,
2826
max_size: int = 0,
2927
max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
30-
_called_by_create: bool = False,
3128
) -> None:
32-
"""
33-
Initializes the PGEventQueue instance. Use the create() classmethod to
34-
instantiate.
35-
"""
36-
if not _called_by_create:
37-
raise RuntimeError(
38-
"Use classmethod create(...) to instantiate PGEventQueue."
39-
)
4029
super().__init__(maxsize=max_size)
41-
self._pg_channel = pgchannel
42-
self._pg_connection = pgconn
30+
self._pg_channel: None | models.PGChannel = None
31+
self._pg_connection: None | asyncpg.Connection = None
4332
self._max_latency = max_latency
4433

45-
@classmethod
46-
async def create(
47-
cls,
48-
pgchannel: models.PGChannel,
49-
pgconn: asyncpg.Connection,
50-
maxsize: int = 0,
51-
max_latency: datetime.timedelta = datetime.timedelta(milliseconds=500),
52-
) -> "PGEventQueue":
53-
"""
54-
Creates and initializes a new PGEventQueue instance, connecting to the specified
55-
PostgreSQL channel. Returns the initialized PGEventQueue instance.
34+
async def connect(
35+
self,
36+
connection: asyncpg.Connection,
37+
channel: models.PGChannel,
38+
) -> None:
5639
"""
57-
me = cls(
58-
pgchannel=pgchannel,
59-
pgconn=pgconn,
60-
max_size=maxsize,
61-
max_latency=max_latency,
62-
_called_by_create=True,
40+
Asynchronously connects the PGEventQueue to a specified
41+
PostgreSQL channel and connection.
42+
43+
This method establishes a listener on a PostgreSQL channel
44+
using the provided connection. It is designed to be called
45+
once per PGEventQueue instance to ensure a one-to-one relationship
46+
between the event queue and a database channel. If an attempt is
47+
made to connect a PGEventQueue instance to more than one channel
48+
or connection, a RuntimeError is raised to enforce this constraint.
49+
50+
Parameters:
51+
- connection: asyncpg.Connection
52+
The asyncpg connection object to be used for listening to database events.
53+
- channel: models.PGChannel
54+
The database channel to listen on for events.
55+
56+
Raises:
57+
- RuntimeError: If the PGEventQueue is already connected to a
58+
channel or connection.
59+
60+
Usage:
61+
```python
62+
await pg_event_queue.connect(
63+
connection=your_asyncpg_connection,
64+
channel=your_pg_channel,
6365
)
64-
me._pg_connection.add_termination_listener(_critical_termination_listener)
65-
await me._pg_connection.add_listener(me._pg_channel, me.parse_and_put) # type: ignore[arg-type]
66+
```
67+
"""
68+
if self._pg_channel or self._pg_connection:
69+
raise RuntimeError(
70+
"PGEventQueue instance is already connected to a channel and/or "
71+
"connection. Only supports one channel and connection per "
72+
"PGEventQueue instance."
73+
)
6674

67-
return me
75+
self._pg_channel = channel
76+
self._pg_connection = connection
77+
self._pg_connection.add_termination_listener(_critical_termination_listener)
78+
await self._pg_connection.add_listener(self._pg_channel, self.parse_and_put) # type: ignore[arg-type]
6879

6980
def parse_and_put(
7081
self,
@@ -87,6 +98,7 @@ def parse_and_put(
8798
except Exception:
8899
logging.exception("Unable to parse `%s`.", payload)
89100
else:
101+
logging.info("Received event: %s on %s", parsed, channel)
90102
try:
91103
self.put_nowait(parsed)
92104
except Exception:

tests/test_decoraters.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
@pytest.mark.parametrize("N", (4, 16, 64, 512))
1111
async def test_gready_cache_decorator(N: int, pgconn: asyncpg.Connection) -> None:
1212
statistics = collections.Counter[str]()
13-
listener = await listeners.PGEventQueue.create(
14-
models.PGChannel("test_cache_decorator"),
15-
pgconn=pgconn,
16-
)
13+
listener = listeners.PGEventQueue()
14+
await listener.connect(pgconn, models.PGChannel("test_cache_decorator"))
1715

1816
@decorators.cache(
1917
strategy=strategies.Gready(listener=listener),

tests/test_fastapi.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ async def fastapitestapp(
1414
) -> fastapi.FastAPI:
1515
app = fastapi.FastAPI()
1616

17-
listener = await listeners.PGEventQueue.create(channel, pgconn)
17+
listener = listeners.PGEventQueue()
18+
await listener.connect(pgconn, channel)
1819

1920
@decorators.cache(strategy=strategies.Gready(listener=listener))
2021
async def slow_db_read() -> dict:

tests/test_integration.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,8 @@ async def test_2_caching(
3737
pgpool: asyncpg.Pool,
3838
) -> None:
3939
statistics = collections.Counter[str]()
40-
listener = await listeners.PGEventQueue.create(
41-
models.PGChannel("test_2_caching"),
42-
pgconn=pgconn,
43-
)
40+
listener = listeners.PGEventQueue()
41+
await listener.connect(pgconn, models.PGChannel("test_2_caching"))
4442

4543
cnt = 0
4644

@@ -64,9 +62,10 @@ async def test_3_cache_invalidation_update(
6462
pgpool: asyncpg.Pool,
6563
) -> None:
6664
statistics = collections.Counter[str]()
67-
listener = await listeners.PGEventQueue.create(
65+
listener = listeners.PGEventQueue()
66+
await listener.connect(
67+
pgconn,
6868
models.PGChannel("ch_pgcachewatch_table_change"),
69-
pgconn=pgconn,
7069
)
7170

7271
@decorators.cache(
@@ -97,9 +96,10 @@ async def test_3_cache_invalidation_insert(
9796
pgpool: asyncpg.Pool,
9897
) -> None:
9998
statistics = collections.Counter[str]()
100-
listener = await listeners.PGEventQueue.create(
99+
listener = listeners.PGEventQueue()
100+
await listener.connect(
101+
pgconn,
101102
models.PGChannel("ch_pgcachewatch_table_change"),
102-
pgconn=pgconn,
103103
)
104104

105105
@decorators.cache(
@@ -131,9 +131,10 @@ async def test_3_cache_invalidation_delete(
131131
pgpool: asyncpg.Pool,
132132
) -> None:
133133
statistics = collections.Counter[str]()
134-
listener = await listeners.PGEventQueue.create(
134+
listener = listeners.PGEventQueue()
135+
await listener.connect(
136+
pgconn,
135137
models.PGChannel("ch_pgcachewatch_table_change"),
136-
pgconn=pgconn,
137138
)
138139

139140
@decorators.cache(

tests/test_listeners.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ async def test_eventqueue_and_pglistner(
1616
pgpool: asyncpg.Pool,
1717
) -> None:
1818
channel = models.PGChannel(f"test_eventqueue_and_pglistner_{N}_{operation}")
19-
eq = await listeners.PGEventQueue.create(channel, pgconn)
19+
listener = listeners.PGEventQueue()
20+
await listener.connect(pgconn, channel)
2021

2122
for _ in range(N):
2223
await utils.emit_event(
@@ -32,7 +33,7 @@ async def test_eventqueue_and_pglistner(
3233
evnets = list[models.Event]()
3334
while True:
3435
try:
35-
evnets.append(eq.get_nowait())
36+
evnets.append(listener.get_nowait())
3637
except asyncio.QueueEmpty:
3738
break
3839

tests/test_strategies.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
@pytest.mark.parametrize("N", (4, 16, 64))
1010
async def test_gready_strategy(N: int, pgconn: asyncpg.Connection) -> None:
1111
channel = models.PGChannel("test_gready_strategy")
12-
listener = await listeners.PGEventQueue.create(channel, pgconn)
12+
13+
listener = listeners.PGEventQueue()
14+
await listener.connect(pgconn, channel)
15+
1316
strategy = strategies.Gready(
1417
listener=listener,
1518
predicate=lambda e: e.operation == "insert",
@@ -47,7 +50,8 @@ async def test_windowed_strategy(
4750
pgconn: asyncpg.Connection,
4851
) -> None:
4952
channel = models.PGChannel("test_windowed_strategy")
50-
listener = await listeners.PGEventQueue.create(channel, pgconn)
53+
listener = listeners.PGEventQueue()
54+
await listener.connect(pgconn, channel)
5155
strategy = strategies.Windowed(
5256
listener=listener, window=["insert", "update", "delete"]
5357
)
@@ -111,7 +115,8 @@ async def test_timed_strategy(
111115
pgconn: asyncpg.Connection,
112116
) -> None:
113117
channel = models.PGChannel("test_timed_strategy")
114-
listener = await listeners.PGEventQueue.create(channel, pgconn)
118+
listener = listeners.PGEventQueue()
119+
await listener.connect(pgconn, channel)
115120
strategy = strategies.Timed(listener=listener, timedelta=dt)
116121

117122
# Bursts spaced out according to min dt req. to trigger a refresh.

tests/test_utils.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ async def test_emit_event(
1717
pgpool: asyncpg.Pool,
1818
) -> None:
1919
channel = "test_emit_event"
20-
listener = await listeners.PGEventQueue.create(
21-
models.PGChannel(channel), pgconn=pgconn
22-
)
20+
listener = listeners.PGEventQueue()
21+
await listener.connect(pgconn, models.PGChannel(channel))
2322
await asyncio.gather(
2423
*[
2524
utils.emit_event(
@@ -47,10 +46,8 @@ async def test_pick_until_deadline_max_iter(
4746
pgconn: asyncpg.Connection,
4847
) -> None:
4948
channel = "test_pick_until_deadline_max_iter"
50-
listener = await listeners.PGEventQueue.create(
51-
models.PGChannel(channel),
52-
pgconn=pgconn,
53-
)
49+
listener = listeners.PGEventQueue()
50+
await listener.connect(pgconn, models.PGChannel(channel))
5451

5552
items = list(range(max_iter * 2))
5653
for item in items:
@@ -87,10 +84,8 @@ async def test_pick_until_deadline_max_time(
8784
pgconn: asyncpg.Connection,
8885
) -> None:
8986
channel = "test_pick_until_deadline_max_time"
90-
listener = await listeners.PGEventQueue.create(
91-
models.PGChannel(channel),
92-
pgconn=pgconn,
93-
)
87+
listener = listeners.PGEventQueue()
88+
await listener.connect(pgconn, models.PGChannel(channel))
9489

9590
x = -1
9691

0 commit comments

Comments
 (0)