Skip to content

Commit 32370b6

Browse files
committed
SQL Alchemy based persistent cache implementation. Using it with sqlite backend is the easiest option for local workflows.
Why SQLite: performance-wise it is a terrible choice and any sane deployment should use something like memcached or redis. And SQL (e.g. postgres) is just overkill. However sqlite is a zero-conf option that does not require any daemon running. So it is very easy to use for local workflows when we don't care that much about performance. PiperOrigin-RevId: 829383146
1 parent 873909e commit 32370b6

File tree

2 files changed

+361
-0
lines changed

2 files changed

+361
-0
lines changed

genai_processors/sql_cache.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""SQLite-based cache for processors using SQLAlchemy.
16+
17+
This is a persistent cache suitable for small Parts, for example metadata
18+
extracted using constrained decoding. It might not scale for Parts containing
19+
large amounts of data e.g. video frames.
20+
21+
Cache is very handy for writing long-running agents or during development as it
22+
provides a lightweight way to resume execution from the point of failure. By
23+
wrapping all LLM calls and other heavy logic in a cache, while keeping the rest
24+
of the code idempotent, restarting the agent from the beginning will promptly
25+
catch up to the place where it has previously failed. During development one can
26+
force the changed code to be rerun by altering key_prefix e.g. by appending code
27+
version to it.
28+
"""
29+
30+
import asyncio
31+
from collections.abc import Callable
32+
import contextlib
33+
import datetime
34+
import json
35+
from typing import AsyncIterator
36+
37+
from genai_processors import cache
38+
from genai_processors import cache_base
39+
from genai_processors import content_api
40+
import sqlalchemy
41+
from sqlalchemy.ext.asyncio import AsyncSession
42+
from sqlalchemy.ext.asyncio import create_async_engine
43+
import sqlalchemy.orm
44+
from typing_extensions import override
45+
46+
47+
ProcessorContent = content_api.ProcessorContent
48+
49+
_Base = sqlalchemy.orm.declarative_base()
50+
51+
52+
class _ContentCacheEntry(_Base):
  """SQLAlchemy model for the cache table."""

  __tablename__ = 'content_cache'

  # Cache key produced by the configured hash function (optionally prefixed).
  key = sqlalchemy.Column(sqlalchemy.String, primary_key=True)
  # Serialized ProcessorContent: a JSON-encoded list of part dicts, as bytes
  # (see _serialize_content / _deserialize_content).
  value = sqlalchemy.Column(sqlalchemy.LargeBinary)
  # Absolute expiry timestamp; NULL when the cache was created with no TTL.
  # Indexed so expired rows can be swept efficiently.
  expires_at = sqlalchemy.Column(sqlalchemy.DateTime(timezone=True), index=True)
60+
61+
62+
@contextlib.asynccontextmanager
async def sql_cache(
    db_url: str,
    ttl_hours: float | None = 12,
    hash_fn: (
        Callable[[content_api.ProcessorContentTypes], str | None] | None
    ) = None,
) -> AsyncIterator['SqlCache']:
  """Context manager that creates an SqlCache instance.

  Args:
    db_url: SQLAlchemy database URL.
    ttl_hours: Time-to-live for cache items in hours. If None, the cache items
      never expire.
    hash_fn: Function to convert a content_api.ProcessorContentTypes query into
      a string key. If None, `cache.default_processor_content_hash` is used.

  Yields:
    A SqlCache instance.
  """
  engine = create_async_engine(db_url)
  try:
    # Create the cache table on first use; a no-op if it already exists.
    async with engine.begin() as conn:
      await conn.run_sync(_Base.metadata.create_all)

    async with AsyncSession(engine) as session:
      yield SqlCache(
          session=session,
          ttl_hours=ttl_hours,
          hash_fn=hash_fn or cache.default_processor_content_hash,
          lock=asyncio.Lock(),
      )
  finally:
    # Release the engine's connection pool. Without this, connections opened
    # by the engine linger until garbage collection.
    await engine.dispose()
94+
95+
96+
class SqlCache(cache_base.CacheBase):
  """An SQLAlchemy based persistent content cache."""

  def __init__(
      self,
      *,
      session: AsyncSession,
      ttl_hours: float | None,
      hash_fn: Callable[[content_api.ProcessorContentTypes], str | None],
      lock: asyncio.Lock,
  ):
    """Prefer using sql_cache() factory to construct the cache.

    Args:
      session: Async SQLAlchemy session used for all cache operations.
      ttl_hours: Time-to-live for cache items in hours; None disables expiry.
      hash_fn: Converts a query into a string key. May return None to signal
        that a query is not cacheable.
      lock: Lock serializing access to the shared session, which is not safe
        for concurrent use. Shared with instances from with_key_prefix().
    """
    self._hash_fn = hash_fn
    self._ttl = (
        datetime.timedelta(hours=ttl_hours) if ttl_hours is not None else None
    )
    self._session = session
    self._lock = lock

  @property
  @override
  def hash_fn(
      self,
  ) -> Callable[[content_api.ProcessorContentTypes], str | None]:
    return self._hash_fn

  @override
  def with_key_prefix(self, prefix: str) -> 'SqlCache':
    """Creates a new SqlCache instance with a key prefix.

    Args:
      prefix: String to prepend to generated string keys.

    Returns:
      A new SqlCache instance with the given prefix.
    """
    # This creates a new instance but shares the same session and lock.
    return SqlCache(
        session=self._session,
        ttl_hours=self._ttl.total_seconds() / 3600
        if self._ttl is not None
        else None,
        hash_fn=cache.prefixed_hash_fn(self._hash_fn, prefix),
        lock=self._lock,
    )

  @override
  async def lookup(
      self,
      query: content_api.ProcessorContentTypes | None = None,
      *,
      key: str | None = None,
  ) -> content_api.ProcessorContent | cache_base.CacheMissT:
    query_key = key if key is not None else self._hash_fn(query)
    if query_key is None:
      # hash_fn declined to produce a key: the query is not cacheable.
      # Mirrors the None guard in remove().
      return cache_base.CacheMiss

    async with self._lock:
      item = await self._session.get(_ContentCacheEntry, query_key)
      if item is None:
        return cache_base.CacheMiss

      if self._ttl is not None:
        # Ensure item.expires_at is offset-aware for comparison.
        # Assuming stored times are UTC, add UTC timezone info if none.
        expires_at = item.expires_at
        if expires_at.tzinfo is None:
          expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)

        if expires_at < datetime.datetime.now(datetime.timezone.utc):
          await self._remove_by_string_key(query_key)
          return cache_base.CacheMiss

      try:
        return _deserialize_content(item.value)
      except Exception:  # pylint: disable=broad-exception-caught
        # Corrupted or incompatible payload: drop the entry, report a miss.
        await self._remove_by_string_key(query_key)
        return cache_base.CacheMiss

  @override
  async def put(
      self,
      *,
      query: content_api.ProcessorContentTypes | None = None,
      key: str | None = None,
      value: content_api.ProcessorContentTypes,
  ) -> None:
    query_key = key if key is not None else self._hash_fn(query)
    if query_key is None:
      # Not cacheable; mirror lookup()'s behavior and do nothing.
      return

    data_to_cache_bytes = _serialize_content(
        content_api.ProcessorContent(value)
    )

    expires_at = None
    if self._ttl is not None:
      expires_at = datetime.datetime.now(datetime.timezone.utc) + self._ttl

    item = _ContentCacheEntry()
    item.key = query_key
    item.value = data_to_cache_bytes
    item.expires_at = expires_at
    async with self._lock:
      # merge() upserts: with a plain add(), overwriting an existing key
      # would raise an integrity / identity-map conflict error on flush.
      await self._session.merge(item)
      await self._cleanup_expired()
      await self._session.commit()

  async def _remove_by_string_key(self, string_key: str) -> None:
    """Internal helper to remove by the actual string key.

    Must be called while holding self._lock.
    """
    item = await self._session.get(_ContentCacheEntry, string_key)
    if item:
      await self._session.delete(item)
      await self._session.commit()

  @override
  async def remove(self, query: content_api.ProcessorContentTypes) -> None:
    query_key = self._hash_fn(query)
    if query_key is None:
      return
    async with self._lock:
      await self._remove_by_string_key(query_key)

  async def _cleanup_expired(self) -> None:
    """Removes expired items from the cache."""
    if self._ttl is None:
      return
    now = datetime.datetime.now(datetime.timezone.utc)
    expired_items = await self._session.execute(
        sqlalchemy.select(_ContentCacheEntry).where(
            _ContentCacheEntry.expires_at < now
        )
    )
    for item in expired_items.scalars():
      await self._session.delete(item)
227+
228+
229+
def _serialize_content(value: ProcessorContent) -> bytes:
  """Encodes a ProcessorContent as UTF-8 JSON bytes (list of part dicts)."""
  part_dicts = [part.to_dict() for part in value.all_parts]
  return json.dumps(part_dicts).encode('utf-8')
234+
235+
236+
def _deserialize_content(data_bytes: bytes) -> ProcessorContent:
  """Rebuilds a ProcessorContent from UTF-8 JSON bytes (list of part dicts)."""
  part_dicts = json.loads(data_bytes.decode('utf-8'))
  parts = [content_api.ProcessorPart.from_dict(data=pd) for pd in part_dicts]
  return ProcessorContent(parts)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import asyncio
2+
import datetime
3+
import os
4+
import tempfile
5+
import unittest
6+
from unittest import mock
7+
8+
from absl.testing import absltest
9+
from absl.testing import parameterized
10+
from genai_processors import cache_base
11+
from genai_processors import content_api
12+
from genai_processors import sql_cache
13+
import sqlalchemy
14+
15+
# Test Content: fixed query/value fixtures shared across the test cases below.
TEST_QUERY = content_api.ProcessorContent('test query')
TEST_VALUE = content_api.ProcessorContent('test value')
TEST_QUERY_2 = content_api.ProcessorContent('test query 2')
TEST_VALUE_2 = content_api.ProcessorContent('test value 2')
20+
21+
22+
class SqlCacheTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
  """Tests for SqlCache against a temporary on-disk SQLite database."""

  def setUp(self):
    super().setUp()
    # Each test gets its own sqlite file so tests cannot see each other's
    # cached entries; the file is deleted when the test finishes.
    with tempfile.NamedTemporaryFile(suffix='.sqlite', delete=False) as tmp:
      self.db_url = f'sqlite+aiosqlite:///{tmp.name}'
    self.addCleanup(os.remove, tmp.name)

  async def test_cache_put_and_lookup(self):
    """Tests basic put and lookup functionality."""
    async with sql_cache.sql_cache(self.db_url) as cache:
      # SqlCache.put takes keyword-only arguments; positional calls raise
      # TypeError.
      await cache.put(query=TEST_QUERY, value=TEST_VALUE)
      result = await cache.lookup(TEST_QUERY)
      self.assertEqual(result, TEST_VALUE)

  async def test_cache_miss(self):
    """Tests cache miss for a non-existent key."""
    async with sql_cache.sql_cache(self.db_url) as cache:
      result = await cache.lookup(TEST_QUERY)
      self.assertIs(result, cache_base.CacheMiss)

  async def test_cache_ttl(self):
    """Tests that cache entries expire after TTL."""
    async with sql_cache.sql_cache(
        self.db_url, ttl_hours=0.0001 / 60 / 60
    ) as cache:  # Very short TTL
      await cache.put(query=TEST_QUERY, value=TEST_VALUE)
      await asyncio.sleep(0.0002)  # Wait for TTL to expire
      result = await cache.lookup(TEST_QUERY)
      self.assertIs(result, cache_base.CacheMiss)

  async def test_cache_remove(self):
    """Tests removing an item from the cache."""
    async with sql_cache.sql_cache(self.db_url) as cache:
      await cache.put(query=TEST_QUERY, value=TEST_VALUE)
      await cache.remove(TEST_QUERY)
      result = await cache.lookup(TEST_QUERY)
      self.assertIs(result, cache_base.CacheMiss)

  async def test_with_key_prefix(self):
    """Tests that with_key_prefix creates a namespace."""
    async with sql_cache.sql_cache(self.db_url) as cache1:
      cache2 = cache1.with_key_prefix('prefix:')

      await cache1.put(query=TEST_QUERY, value=TEST_VALUE)
      await cache2.put(query=TEST_QUERY, value=TEST_VALUE_2)

      result1 = await cache1.lookup(TEST_QUERY)
      result2 = await cache2.lookup(TEST_QUERY)

      self.assertEqual(result1, TEST_VALUE)
      self.assertEqual(result2, TEST_VALUE_2)
      self.assertNotEqual(result1, result2)

  async def test_different_content_types(self):
    """Tests caching with different content types."""
    async with sql_cache.sql_cache(self.db_url) as cache:
      query1 = content_api.ProcessorContent('text query')
      value1 = content_api.ProcessorContent(['list ', 'of ', 'strings'])
      query2 = content_api.ProcessorContent(
          [content_api.ProcessorPart(b'imagedata', mimetype='image/png')]
      )
      value2 = content_api.ProcessorContent({'a': 1, 'b': True})

      await cache.put(query=query1, value=value1)
      await cache.put(query=query2, value=value2)

      self.assertEqual(await cache.lookup(query1), value1)
      self.assertEqual(await cache.lookup(query2), value2)

  async def test_cleanup_expired(self):
    """Tests that the _cleanup_expired method removes old entries."""
    async with sql_cache.sql_cache(self.db_url, ttl_hours=0.0001) as cache:
      await cache.put(query=TEST_QUERY, value=TEST_VALUE)
      # Mock datetime to control time: pretend an hour has passed so the
      # entry is past its TTL when _cleanup_expired computes "now".
      with mock.patch('genai_processors.sql_cache.datetime') as mock_datetime:
        mock_datetime.datetime.now.return_value = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(hours=1)
        mock_datetime.timedelta.side_effect = datetime.timedelta
        mock_datetime.timezone.utc = datetime.timezone.utc
        await cache._cleanup_expired()

      # Check that the entry is gone by querying the table directly.
      async def _lookup():
        stmt = sqlalchemy.select(sql_cache._ContentCacheEntry).where(
            sql_cache._ContentCacheEntry.key == cache._hash_fn(TEST_QUERY)
        )
        result = await cache._session.execute(stmt)
        return result.fetchone()

      result = await _lookup()
      self.assertIsNone(result)
116+
117+
# Script entry point: run the absl test runner when executed directly.
if __name__ == '__main__':
  absltest.main()

0 commit comments

Comments
 (0)