Skip to content

feat: Implement EventCacheStoreLock::lock() with poison error, and ::lock_unchecked #4285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
24 changes: 12 additions & 12 deletions crates/matrix-sdk-base/src/event_cache/store/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,57 +260,57 @@ macro_rules! event_cache_store_integration_tests_time {
let store = get_event_cache_store().await.unwrap().into_event_cache_store();

let acquired0 = store.try_take_leased_lock(0, "key", "alice").await.unwrap();
assert!(acquired0);
assert_eq!(acquired0, Some(0)); // first lock generation

// Should extend the lease automatically (same holder).
let acquired2 = store.try_take_leased_lock(300, "key", "alice").await.unwrap();
assert!(acquired2);
assert_eq!(acquired2, Some(0)); // same lock generation

// Should extend the lease automatically (same holder + time is ok).
let acquired3 = store.try_take_leased_lock(300, "key", "alice").await.unwrap();
assert!(acquired3);
assert_eq!(acquired3, Some(0)); // same lock generation

// Another attempt at taking the lock should fail, because it's taken.
let acquired4 = store.try_take_leased_lock(300, "key", "bob").await.unwrap();
assert!(!acquired4);
assert!(acquired4.is_none()); // not acquired

// Even if we insist.
let acquired5 = store.try_take_leased_lock(300, "key", "bob").await.unwrap();
assert!(!acquired5);
assert!(acquired5.is_none()); // not acquired

// That's a nice test we got here, go take a little nap.
tokio::time::sleep(Duration::from_millis(50)).await;

// Still too early.
let acquired55 = store.try_take_leased_lock(300, "key", "bob").await.unwrap();
assert!(!acquired55);
assert!(acquired55.is_none()); // not acquired

// Ok you can take another nap then.
tokio::time::sleep(Duration::from_millis(250)).await;

// At some point, we do get the lock.
let acquired6 = store.try_take_leased_lock(0, "key", "bob").await.unwrap();
assert!(acquired6);
assert_eq!(acquired6, Some(1)); // new lock generation!

tokio::time::sleep(Duration::from_millis(1)).await;

// The other gets it almost immediately too.
let acquired7 = store.try_take_leased_lock(0, "key", "alice").await.unwrap();
assert!(acquired7);
assert_eq!(acquired7, Some(2)); // new lock generation!

tokio::time::sleep(Duration::from_millis(1)).await;

// But when we take a longer lease...
// But when we take a longer lease
let acquired8 = store.try_take_leased_lock(300, "key", "bob").await.unwrap();
assert!(acquired8);
assert_eq!(acquired8, Some(3)); // new lock generation!

// It blocks the other user.
let acquired9 = store.try_take_leased_lock(300, "key", "alice").await.unwrap();
assert!(!acquired9);
assert!(acquired9.is_none()); // not acquired

// We can hold onto our lease.
let acquired10 = store.try_take_leased_lock(300, "key", "bob").await.unwrap();
assert!(acquired10);
assert_eq!(acquired10, Some(3)); // same lock generation
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ use std::{collections::HashMap, num::NonZeroUsize, sync::RwLock as StdRwLock, ti

use async_trait::async_trait;
use matrix_sdk_common::{
ring_buffer::RingBuffer, store_locks::memory_store_helper::try_take_leased_lock,
ring_buffer::RingBuffer,
store_locks::{memory_store_helper::try_take_leased_lock, LockGeneration},
};
use ruma::{MxcUri, OwnedMxcUri};

Expand All @@ -30,7 +31,7 @@ use crate::media::{MediaRequestParameters, UniqueKey as _};
#[derive(Debug)]
pub struct MemoryStore {
media: StdRwLock<RingBuffer<(OwnedMxcUri, String /* unique key */, Vec<u8>)>>,
leases: StdRwLock<HashMap<String, (String, Instant)>>,
leases: StdRwLock<HashMap<String, (String, Instant, LockGeneration)>>,
}

// SAFETY: `new_unchecked` is safe because 20 is not zero.
Expand Down Expand Up @@ -62,7 +63,7 @@ impl EventCacheStore for MemoryStore {
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Result<bool, Self::Error> {
) -> Result<Option<LockGeneration>, Self::Error> {
Ok(try_take_leased_lock(&self.leases, lease_duration_ms, key, holder))
}

Expand Down
8 changes: 4 additions & 4 deletions crates/matrix-sdk-base/src/event_cache/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! into the event cache for the actual storage. By default this brings an
//! in-memory store.

use std::{fmt, ops::Deref, str::Utf8Error, sync::Arc};
use std::{fmt, ops::Deref, result::Result as StdResult, str::Utf8Error, sync::Arc};

#[cfg(any(test, feature = "testing"))]
#[macro_use]
Expand Down Expand Up @@ -160,7 +160,7 @@ impl EventCacheStoreError {
}

/// An `EventCacheStore` specific result type.
pub type Result<T, E = EventCacheStoreError> = std::result::Result<T, E>;
pub type Result<T, E = EventCacheStoreError> = StdResult<T, E>;

/// A type that wraps the [`EventCacheStore`] but implements [`BackingStore`] to
/// make it usable inside the cross process lock.
Expand All @@ -177,7 +177,7 @@ impl BackingStore for LockableEventCacheStore {
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> std::result::Result<bool, Self::LockError> {
self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
) -> StdResult<bool, Self::LockError> {
Ok(self.0.try_take_leased_lock(lease_duration_ms, key, holder).await?.is_some())
}
}
6 changes: 3 additions & 3 deletions crates/matrix-sdk-base/src/event_cache/store/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::{fmt, sync::Arc};

use async_trait::async_trait;
use matrix_sdk_common::AsyncTraitDeps;
use matrix_sdk_common::{store_locks::LockGeneration, AsyncTraitDeps};
use ruma::MxcUri;

use super::EventCacheStoreError;
Expand All @@ -35,7 +35,7 @@ pub trait EventCacheStore: AsyncTraitDeps {
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Result<bool, Self::Error>;
) -> Result<Option<LockGeneration>, Self::Error>;

/// Add a media file's content in the media store.
///
Expand Down Expand Up @@ -127,7 +127,7 @@ impl<T: EventCacheStore> EventCacheStore for EraseEventCacheStoreError<T> {
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Result<bool, Self::Error> {
) -> Result<Option<LockGeneration>, Self::Error> {
self.0.try_take_leased_lock(lease_duration_ms, key, holder).await.map_err(Into::into)
}

Expand Down
46 changes: 35 additions & 11 deletions crates/matrix-sdk-common/src/store_locks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ use crate::{
SendOutsideWasm,
};

/// A lock generation is an integer incremented each time it is taken by another
/// holder. This is not used by all cross-process locks.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of the last sentence here? Consider removing it now, since we're likely to forget about this comment and let it rot even in a future where the crypto store lock would use it?

pub type LockGeneration = u64;

/// Describe the first lock generation value (see [`LockGeneration`]).
pub const FIRST_LOCK_GENERATION: LockGeneration = 0;

/// Backing store for a cross-process lock.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
Expand Down Expand Up @@ -351,16 +358,23 @@ mod tests {

use super::{
memory_store_helper::try_take_leased_lock, BackingStore, CrossProcessStoreLock,
CrossProcessStoreLockGuard, LockStoreError, EXTEND_LEASE_EVERY_MS,
CrossProcessStoreLockGuard, LockGeneration, LockStoreError, EXTEND_LEASE_EVERY_MS,
};

type HolderExpirationGeneration = (String, Instant, LockGeneration);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Care to either comment, or use a small struct with named fields, please?


#[derive(Clone, Default)]
struct TestStore {
leases: Arc<RwLock<HashMap<String, (String, Instant)>>>,
leases: Arc<RwLock<HashMap<String, HolderExpirationGeneration>>>,
}

impl TestStore {
fn try_take_leased_lock(&self, lease_duration_ms: u32, key: &str, holder: &str) -> bool {
fn try_take_leased_lock(
&self,
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Option<LockGeneration> {
try_take_leased_lock(&self.leases, lease_duration_ms, key, holder)
}
}
Expand All @@ -380,7 +394,7 @@ mod tests {
key: &str,
holder: &str,
) -> Result<bool, Self::LockError> {
Ok(self.try_take_leased_lock(lease_duration_ms, key, holder))
Ok(self.try_take_leased_lock(lease_duration_ms, key, holder).is_some())
}
}

Expand Down Expand Up @@ -506,36 +520,45 @@ pub mod memory_store_helper {
time::{Duration, Instant},
};

use super::{LockGeneration, FIRST_LOCK_GENERATION};

type HolderExpirationGeneration = (String, Instant, LockGeneration);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you deduplicate this?


/// Try to acquire or to extend the lock.
///
/// Return `Some` if the lock has been acquired (or extended). It contains
/// the generation number.
pub fn try_take_leased_lock(
leases: &RwLock<HashMap<String, (String, Instant)>>,
leases: &RwLock<HashMap<String, HolderExpirationGeneration>>,
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> bool {
) -> Option<LockGeneration> {
let now = Instant::now();
let expiration = now + Duration::from_millis(lease_duration_ms.into());

match leases.write().unwrap().entry(key.to_owned()) {
// There is an existing holder.
Entry::Occupied(mut entry) => {
let (current_holder, current_expiration) = entry.get_mut();
let (current_holder, current_expiration, current_generation) = entry.get_mut();

if current_holder == holder {
// We had the lease before, extend it.
*current_expiration = expiration;

true
Some(*current_generation)
} else {
// We didn't have it.
if *current_expiration < now {
// Steal it!
*current_holder = holder.to_owned();
*current_expiration = expiration;
*current_generation += 1;

true
Some(*current_generation)
} else {
// We tried our best.
false
None
}
}
}
Expand All @@ -545,9 +568,10 @@ pub mod memory_store_helper {
entry.insert((
holder.to_owned(),
Instant::now() + Duration::from_millis(lease_duration_ms.into()),
0,
));

true
Some(FIRST_LOCK_GENERATION)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/matrix-sdk-crypto/src/store/memorystore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{
};

use async_trait::async_trait;
use matrix_sdk_common::store_locks::memory_store_helper::try_take_leased_lock;
use matrix_sdk_common::store_locks::{memory_store_helper::try_take_leased_lock, LockGeneration};
use ruma::{
events::secret::request::SecretName, DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId,
OwnedUserId, RoomId, TransactionId, UserId,
Expand Down Expand Up @@ -90,7 +90,7 @@ pub struct MemoryStore {
key_requests_by_info: StdRwLock<HashMap<String, OwnedTransactionId>>,
direct_withheld_info: StdRwLock<HashMap<OwnedRoomId, HashMap<String, RoomKeyWithheldEvent>>>,
custom_values: StdRwLock<HashMap<String, Vec<u8>>>,
leases: StdRwLock<HashMap<String, (String, Instant)>>,
leases: StdRwLock<HashMap<String, (String, Instant, LockGeneration)>>,
secret_inbox: StdRwLock<HashMap<String, Vec<GossippedSecret>>>,
backup_keys: RwLock<BackupKeys>,
next_batch_token: RwLock<Option<String>>,
Expand Down Expand Up @@ -632,7 +632,7 @@ impl CryptoStore for MemoryStore {
key: &str,
holder: &str,
) -> Result<bool> {
Ok(try_take_leased_lock(&self.leases, lease_duration_ms, key, holder))
Ok(try_take_leased_lock(&self.leases, lease_duration_ms, key, holder).is_some())
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "lease_locks" ADD COLUMN "generation" INTEGER NOT NULL DEFAULT 0;
40 changes: 30 additions & 10 deletions crates/matrix-sdk-sqlite/src/event_cache_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use deadpool_sqlite::{Object as SqliteAsyncConn, Pool as SqlitePool, Runtime};
use matrix_sdk_base::{
event_cache::store::EventCacheStore,
media::{MediaRequestParameters, UniqueKey},
store_locks::LockGeneration,
};
use matrix_sdk_store_encryption::StoreCipher;
use ruma::MilliSecondsSinceUnixEpoch;
Expand All @@ -28,7 +29,7 @@ mod keys {
/// This is used to figure whether the SQLite database requires a migration.
/// Every new SQL migration should imply a bump of this number, and changes in
/// the [`run_migrations`] function.
const DATABASE_VERSION: u8 = 2;
const DATABASE_VERSION: u8 = 3;

/// A SQLite-based event cache store.
#[derive(Clone)]
Expand Down Expand Up @@ -142,6 +143,16 @@ async fn run_migrations(conn: &SqliteAsyncConn, version: u8) -> Result<()> {
.await?;
}

if version < 3 {
conn.with_transaction(|txn| {
txn.execute_batch(include_str!(
"../migrations/event_cache_store/003_lease_locks_with_generation.sql"
))?;
txn.set_db_version(3)
})
.await?;
}

Ok(())
}

Expand All @@ -154,32 +165,41 @@ impl EventCacheStore for SqliteEventCacheStore {
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Result<bool> {
) -> Result<Option<LockGeneration>> {
let key = key.to_owned();
let holder = holder.to_owned();

let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into();
let expiration = now + lease_duration_ms as u64;

let num_touched = self
let generation = self
.acquire()
.await?
.with_transaction(move |txn| {
txn.execute(
txn.query_row(
"INSERT INTO lease_locks (key, holder, expiration)
VALUES (?1, ?2, ?3)
ON CONFLICT (key)
DO
UPDATE SET holder = ?2, expiration = ?3
WHERE holder = ?2
ON CONFLICT (key) DO
UPDATE SET
holder = excluded.holder,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add in a comment where does this excluded entity comes from? More generally, can you add in a comment what this query does, because it's not quite the regular SQL one sees often? :)

expiration = excluded.expiration,
generation =
CASE holder
WHEN excluded.holder THEN generation
ELSE generation + 1
END
WHERE
holder = excluded.holder
OR expiration < ?4
",
RETURNING generation",
(key, holder, expiration, now),
|row| row.get(0),
)
.optional()
})
.await?;

Ok(num_touched == 1)
Ok(generation)
}

async fn add_media_content(
Expand Down
1 change: 1 addition & 0 deletions crates/matrix-sdk/src/event_cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ impl EventCache {
mut room_updates_feed: Receiver<RoomUpdates>,
) {
trace!("Spawning the listen task");

loop {
match room_updates_feed.recv().await {
Ok(updates) => {
Expand Down