feat(sqlite): SqliteEventCacheStore has 1 write connection #5382

Merged
144 changes: 89 additions & 55 deletions crates/matrix-sdk-sqlite/src/event_cache_store.rs
@@ -43,7 +43,10 @@ use ruma::{
OwnedEventId, RoomId,
};
use rusqlite::{params_from_iter, OptionalExtension, ToSql, Transaction, TransactionBehavior};
use tokio::fs;
use tokio::{
fs,
sync::{Mutex, OwnedMutexGuard},
};
use tracing::{debug, error, trace};

use crate::{
@@ -86,7 +89,16 @@ const CHUNK_TYPE_GAP_TYPE_STRING: &str = "G";
#[derive(Clone)]
pub struct SqliteEventCacheStore {
store_cipher: Option<Arc<StoreCipher>>,

/// The pool of connections.
pool: SqlitePool,

/// We distinguish between connections for read operations and for write
/// operations. A single connection is set apart for write operations; all
/// other connections are used for read operations. The lock ensures there is
/// only one owner of the write connection at a time.
write_connection: Arc<Mutex<SqliteAsyncConn>>,

media_service: MediaService,
}
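
As context for the new doc comment above, here is a minimal usage sketch, not part of the diff: `example_usage` is hypothetical and would have to live in the same module, since the `read()` and `write()` helpers introduced further down are private. It illustrates the intended pattern of readers drawing from the pool while writers serialise on the single mutex-guarded connection.

async fn example_usage(store: &SqliteEventCacheStore) -> Result<()> {
    // Read connections come from the pool; several can be held concurrently.
    let reader_a = store.read().await?;
    let reader_b = store.read().await?;

    // The write connection is unique: `write()` hands out an owned mutex
    // guard, so a second caller waits here until the first guard is dropped.
    let writer = store.write().await?;
    drop(writer);
    let writer_again = store.write().await?;

    drop((reader_a, reader_b, writer_again));
    Ok(())
}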

@@ -125,7 +137,7 @@ impl SqliteEventCacheStore {
let pool = config.create_pool(Runtime::Tokio1)?;

let this = Self::open_with_pool(pool, passphrase.as_deref()).await?;
this.pool.get().await?.apply_runtime_config(runtime_config).await?;
this.write().await?.apply_runtime_config(runtime_config).await?;

Ok(this)
}
@@ -151,10 +163,17 @@ impl SqliteEventCacheStore {
let last_media_cleanup_time = conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await?;
media_service.restore(media_retention_policy, last_media_cleanup_time);

Ok(Self { store_cipher, pool, media_service })
Ok(Self {
store_cipher,
pool,
// Use `conn` as our selected write connection.
write_connection: Arc::new(Mutex::new(conn)),
media_service,
})
}

async fn acquire(&self) -> Result<SqliteAsyncConn> {
// Acquire a connection for executing read operations.
async fn read(&self) -> Result<SqliteAsyncConn> {
let connection = self.pool.get().await?;

// Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key
@@ -166,6 +185,19 @@ impl SqliteEventCacheStore {
Ok(connection)
}

// Acquire a connection for executing write operations.
async fn write(&self) -> Result<OwnedMutexGuard<SqliteAsyncConn>> {
let connection = self.write_connection.clone().lock_owned().await;

// Per https://www.sqlite.org/foreignkeys.html#fk_enable, foreign key
// support must be enabled on a per-connection basis. Execute it every
// time we try to get a connection, since we can't guarantee a previous
// connection did enable it before.
connection.execute_batch("PRAGMA foreign_keys = ON;").await?;

Ok(connection)
}

fn map_row_to_chunk(
row: &rusqlite::Row<'_>,
) -> Result<(u64, Option<u64>, Option<u64>, String), rusqlite::Error> {
@@ -425,7 +457,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let expiration = now + lease_duration_ms as u64;

let num_touched = self
.acquire()
.write()
.await?
.with_transaction(move |txn| {
txn.execute(
@@ -457,7 +489,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let linked_chunk_id = linked_chunk_id.to_owned();
let this = self.clone();

with_immediate_transaction(self.acquire().await?, move |txn| {
with_immediate_transaction(self, move |txn| {
for up in updates {
match up {
Update::NewItemsChunk { previous, new, next } => {
@@ -783,7 +815,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let this = self.clone();

let result = self
.acquire()
.read()
.await?
.with_transaction(move |txn| -> Result<_> {
let mut items = Vec::new();
@@ -821,7 +853,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let hashed_linked_chunk_id =
self.encode_key(keys::LINKED_CHUNKS, linked_chunk_id.storage_key());

self.acquire()
self.read()
.await?
.with_transaction(move |txn| -> Result<_> {
// I'm not a DB analyst, so for my own future sanity: this query joins the
@@ -884,7 +916,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let this = self.clone();

self
.acquire()
.read()
.await?
.with_transaction(move |txn| -> Result<_> {
// Find the latest chunk identifier to generate a `ChunkIdentifierGenerator`, and count the number of chunks.
@@ -977,7 +1009,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let this = self.clone();

self
.acquire()
.read()
.await?
.with_transaction(move |txn| -> Result<_> {
// Find the chunk before the chunk identified by `before_chunk_identifier`.
@@ -1018,7 +1050,7 @@ impl EventCacheStore for SqliteEventCacheStore {
}

async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
self.acquire()
self.write()
.await?
.with_transaction(move |txn| {
// Remove all the chunks, and let cascading do its job.
@@ -1047,7 +1079,7 @@ impl EventCacheStore for SqliteEventCacheStore {
self.encode_key(keys::LINKED_CHUNKS, linked_chunk_id.storage_key());
let linked_chunk_id = linked_chunk_id.to_owned();

self.acquire()
self.read()
.await?
.with_transaction(move |txn| -> Result<_> {
txn.chunk_large_query_over(events, None, move |txn, events| {
@@ -1119,7 +1151,7 @@ impl EventCacheStore for SqliteEventCacheStore {

let hashed_room_id = self.encode_key(keys::LINKED_CHUNKS, room_id);

self.acquire()
self.read()
.await?
.with_transaction(move |txn| -> Result<_> {
let Some(event) = txn
@@ -1153,7 +1185,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let filters = filters.map(ToOwned::to_owned);
let this = self.clone();

self.acquire()
self.read()
.await?
.with_transaction(move |txn| -> Result<_> {
let filter_query = if let Some(filters) = compute_filters_string(filters.as_deref())
Expand Down Expand Up @@ -1216,7 +1248,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let event_id = event_id.to_string();
let encoded_event = self.encode_event(&event)?;

self.acquire()
self.write()
.await?
.with_transaction(move |txn| -> Result<_> {
txn.execute(
Expand Down Expand Up @@ -1248,7 +1280,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let new_uri = self.encode_key(keys::MEDIA, to.source.unique_key());
let new_format = self.encode_key(keys::MEDIA, to.format.unique_key());

let conn = self.acquire().await?;
let conn = self.write().await?;
conn.execute(
r#"UPDATE media SET uri = ?, format = ? WHERE uri = ? AND format = ?"#,
(new_uri, new_format, prev_uri, prev_format),
@@ -1266,7 +1298,7 @@ impl EventCacheStore for SqliteEventCacheStore {
let uri = self.encode_key(keys::MEDIA, request.source.unique_key());
let format = self.encode_key(keys::MEDIA, request.format.unique_key());

let conn = self.acquire().await?;
let conn = self.write().await?;
conn.execute("DELETE FROM media WHERE uri = ? AND format = ?", (uri, format)).await?;

Ok(())
@@ -1282,7 +1314,7 @@ impl EventCacheStore for SqliteEventCacheStore {
async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<()> {
let uri = self.encode_key(keys::MEDIA, uri);

let conn = self.acquire().await?;
let conn = self.write().await?;
conn.execute("DELETE FROM media WHERE uri = ?", (uri,)).await?;

Ok(())
@@ -1320,15 +1352,15 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
async fn media_retention_policy_inner(
&self,
) -> Result<Option<MediaRetentionPolicy>, Self::Error> {
let conn = self.acquire().await?;
let conn = self.read().await?;
conn.get_serialized_kv(keys::MEDIA_RETENTION_POLICY).await
}

async fn set_media_retention_policy_inner(
&self,
policy: MediaRetentionPolicy,
) -> Result<(), Self::Error> {
let conn = self.acquire().await?;
let conn = self.write().await?;
conn.set_serialized_kv(keys::MEDIA_RETENTION_POLICY, policy).await?;
Ok(())
}
@@ -1352,7 +1384,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
let format = self.encode_key(keys::MEDIA, request.format.unique_key());
let timestamp = time_to_timestamp(last_access);

let conn = self.acquire().await?;
let conn = self.write().await?;
conn.execute(
"INSERT OR REPLACE INTO media (uri, format, data, last_access, ignore_policy) VALUES (?, ?, ?, ?, ?)",
(uri, format, data, timestamp, ignore_policy),
@@ -1371,7 +1403,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
let format = self.encode_key(keys::MEDIA, request.format.unique_key());
let ignore_policy = ignore_policy.is_yes();

let conn = self.acquire().await?;
let conn = self.write().await?;
conn.execute(
r#"UPDATE media SET ignore_policy = ? WHERE uri = ? AND format = ?"#,
(ignore_policy, uri, format),
@@ -1390,7 +1422,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
let format = self.encode_key(keys::MEDIA, request.format.unique_key());
let timestamp = time_to_timestamp(current_time);

let conn = self.acquire().await?;
let conn = self.write().await?;
let data = conn
.with_transaction::<_, rusqlite::Error, _>(move |txn| {
// Update the last access.
@@ -1421,7 +1453,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
let uri = self.encode_key(keys::MEDIA, uri);
let timestamp = time_to_timestamp(current_time);

let conn = self.acquire().await?;
let conn = self.write().await?;
let data = conn
.with_transaction::<_, rusqlite::Error, _>(move |txn| {
// Update the last access.
@@ -1451,7 +1483,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
return Ok(());
}

let conn = self.acquire().await?;
let conn = self.write().await?;
let removed = conn
.with_transaction::<_, Error, _>(move |txn| {
let mut removed = false;
@@ -1570,7 +1602,7 @@ impl EventCacheStoreMedia for SqliteEventCacheStore {
}

async fn last_media_cleanup_time_inner(&self) -> Result<Option<SystemTime>, Self::Error> {
let conn = self.acquire().await?;
let conn = self.read().await?;
conn.get_serialized_kv(keys::LAST_MEDIA_CLEANUP_TIME).await
}
}
@@ -1583,33 +1615,35 @@ async fn with_immediate_transaction<
T: Send + 'static,
F: FnOnce(&Transaction<'_>) -> Result<T, Error> + Send + 'static,
>(
conn: SqliteAsyncConn,
this: &SqliteEventCacheStore,
f: F,
) -> Result<T, Error> {
conn.interact(move |conn| -> Result<T, Error> {
// Start the transaction in IMMEDIATE mode since all updates may cause writes,
// to avoid read transactions upgrading to write mode and causing
// SQLITE_BUSY errors. See also: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
conn.set_transaction_behavior(TransactionBehavior::Immediate);

let code = || -> Result<T, Error> {
let txn = conn.transaction()?;
let res = f(&txn)?;
txn.commit()?;
Ok(res)
};

let res = code();

// Reset the transaction behavior to use Deferred, after this transaction has
// been run, whether it was successful or not.
conn.set_transaction_behavior(TransactionBehavior::Deferred);

res
})
.await
// SAFETY: same logic as in [`deadpool::managed::Object::with_transaction`].
.unwrap()
this.write()
.await?
Comment on lines +1699 to +1700 (Member, Author):

This change ensures `with_immediate_transaction` takes a write connection.

.interact(move |conn| -> Result<T, Error> {
// Start the transaction in IMMEDIATE mode since all updates may cause writes,
// to avoid read transactions upgrading to write mode and causing
// SQLITE_BUSY errors. See also: https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
conn.set_transaction_behavior(TransactionBehavior::Immediate);

let code = || -> Result<T, Error> {
let txn = conn.transaction()?;
let res = f(&txn)?;
txn.commit()?;
Ok(res)
};

let res = code();

// Reset the transaction behavior to use Deferred, after this transaction has
// been run, whether it was successful or not.
conn.set_transaction_behavior(TransactionBehavior::Deferred);

res
})
.await
// SAFETY: same logic as in [`deadpool::managed::Object::with_transaction`].
.unwrap()
}

fn insert_chunk(
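
For illustration, a hedged sketch of a call site after this change (the `handle_example_update` function and its SQL statement are hypothetical; the real call sites appear earlier in this diff): the caller now passes the store itself, and `with_immediate_transaction` acquires the dedicated write connection before opening an IMMEDIATE transaction, so writers in this process queue on the mutex instead of contending inside SQLite.

async fn handle_example_update(this: &SqliteEventCacheStore) -> Result<(), Error> {
    with_immediate_transaction(this, |txn| {
        // Runs on the dedicated write connection, inside a transaction
        // opened in IMMEDIATE mode.
        txn.execute("DELETE FROM linked_chunks", ())?;
        Ok(())
    })
    .await
}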
@@ -1716,7 +1750,7 @@ mod tests {
async fn get_event_cache_store_content_sorted_by_last_access(
event_cache_store: &SqliteEventCacheStore,
) -> Vec<Vec<u8>> {
let sqlite_db = event_cache_store.acquire().await.expect("accessing sqlite db failed");
let sqlite_db = event_cache_store.read().await.expect("accessing sqlite db failed");
sqlite_db
.prepare("SELECT data FROM media ORDER BY last_access DESC", |mut stmt| {
stmt.query(())?.mapped(|row| row.get(0)).collect()
@@ -2006,7 +2040,7 @@ mod tests {

// Check that cascading worked. Yes, SQLite, I doubt you.
let gaps = store
.acquire()
.read()
.await
.unwrap()
.with_transaction(|txn| -> rusqlite::Result<_> {
@@ -2128,7 +2162,7 @@ mod tests {

// Make sure the position have been updated for the remaining events.
let num_rows: u64 = store
.acquire()
.read()
.await
.unwrap()
.with_transaction(move |txn| {
@@ -2277,7 +2311,7 @@ mod tests {

// Check that cascading worked. Yes, SQLite, I doubt you.
store
.acquire()
.read()
.await
.unwrap()
.with_transaction(|txn| -> rusqlite::Result<_> {
Expand Down
Loading
Loading