diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index 82191726f2..bd6b80ef79 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -5,6 +5,7 @@ use hkdf::Hkdf; use once_cell::sync::OnceCell; use sha2::Sha256; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; /// A mutex-like type utilizing [Postgres advisory locks]. /// @@ -82,6 +83,11 @@ pub struct PgAdvisoryLockGuard<'lock, C: AsMut> { conn: Option, } +pub struct PgAdvisoryLockGuardOwned> { + lock: Arc, + conn: Option, +} + impl PgAdvisoryLock { /// Construct a `PgAdvisoryLock` using the given string as a key. /// @@ -203,22 +209,7 @@ impl PgAdvisoryLock { &self, mut conn: C, ) -> Result> { - match &self.key { - PgAdvisoryLockKey::BigInt(key) => { - crate::query::query("SELECT pg_advisory_lock($1)") - .bind(key) - .execute(conn.as_mut()) - .await?; - } - PgAdvisoryLockKey::IntPair(key1, key2) => { - crate::query::query("SELECT pg_advisory_lock($1, $2)") - .bind(key1) - .bind(key2) - .execute(conn.as_mut()) - .await?; - } - } - + self.execute_acquire(conn.as_mut()).await?; Ok(PgAdvisoryLockGuard::new(self, conn)) } @@ -246,26 +237,68 @@ impl PgAdvisoryLock { &self, mut conn: C, ) -> Result, C>> { - let locked: bool = match &self.key { + let locked = self.execute_try_acquire(conn.as_mut()).await?; + if locked { + Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn))) + } else { + Ok(Either::Right(conn)) + } + } + + pub async fn acquire_owned>( + self: Arc, + mut conn: C, + ) -> Result> { + self.execute_acquire(conn.as_mut()).await?; + Ok(PgAdvisoryLockGuardOwned::new(self, conn)) + } + + pub async fn try_acquire_owned>( + self: Arc, + mut conn: C, + ) -> Result, C>> { + let locked = self.execute_try_acquire(conn.as_mut()).await?; + if locked { + Ok(Either::Left(PgAdvisoryLockGuardOwned::new(self, conn))) + } else { + Ok(Either::Right(conn)) + } + } + + async fn execute_acquire(&self, conn: &mut PgConnection) -> Result<(), sqlx_core::Error> { + match &self.key { + PgAdvisoryLockKey::BigInt(key) => { + crate::query::query("SELECT pg_advisory_lock($1)") + .bind(key) + .execute(conn.as_mut()) + .await?; + } + PgAdvisoryLockKey::IntPair(key1, key2) => { + crate::query::query("SELECT pg_advisory_lock($1, $2)") + .bind(key1) + .bind(key2) + .execute(conn.as_mut()) + .await?; + } + } + Ok(()) + } + + async fn execute_try_acquire(&self, conn: &mut PgConnection) -> Result { + match &self.key { PgAdvisoryLockKey::BigInt(key) => { crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1)") .bind(key) .fetch_one(conn.as_mut()) - .await? + .await } PgAdvisoryLockKey::IntPair(key1, key2) => { crate::query_scalar::query_scalar("SELECT pg_try_advisory_lock($1, $2)") .bind(key1) .bind(key2) .fetch_one(conn.as_mut()) - .await? + .await } - }; - - if locked { - Ok(Either::Left(PgAdvisoryLockGuard::new(self, conn))) - } else { - Ok(Either::Right(conn)) } } @@ -419,3 +452,68 @@ impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { } } } + +impl> PgAdvisoryLockGuardOwned { + fn new(lock: Arc, conn: C) -> Self { + Self { + lock, + conn: Some(conn), + } + } + + pub fn leak(mut self) -> C { + self.conn.take().expect(NONE_ERR) + } + + pub async fn release_now(mut self) -> Result { + let (conn, released) = self + .lock + .force_release(self.conn.take().expect(NONE_ERR)) + .await?; + + if !released { + tracing::warn!( + lock = ?self.lock.key, + "PgAdvisoryLockGuard: advisory lock was not held by the contained connection", + ); + } + + Ok(conn) + } +} + +impl> Drop for PgAdvisoryLockGuardOwned { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + conn.as_mut() + .queue_simple_query(self.lock.get_release_query()); + } + } +} + +impl + AsMut> Deref for PgAdvisoryLockGuardOwned { + type Target = PgConnection; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} +impl + AsRef> DerefMut for PgAdvisoryLockGuardOwned { + fn deref_mut(&mut self) -> &mut Self::Target { + self.as_mut() + } +} + +impl + AsRef> AsRef + for PgAdvisoryLockGuardOwned +{ + fn as_ref(&self) -> &PgConnection { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +impl> AsMut for PgAdvisoryLockGuardOwned { + fn as_mut(&mut self) -> &mut PgConnection { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 7edb5a7a8c..907f3477a1 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1833,6 +1833,72 @@ async fn test_advisory_locks() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn test_advisory_locks_with_owned_guards() -> anyhow::Result<()> { + let pool = PgPoolOptions::new() + .max_connections(2) + .connect(&dotenvy::var("DATABASE_URL")?) + .await?; + + let lock1 = Arc::new(PgAdvisoryLock::new("sqlx-postgres-tests-1")); + let lock2 = Arc::new(PgAdvisoryLock::new("sqlx-postgres-tests-2")); + + let conn1 = pool.acquire().await?; + let mut conn1_lock1 = lock1.clone().acquire_owned(conn1).await?; + + // try acquiring a recursive lock through a mutable reference then dropping + drop(lock1.clone().acquire_owned(&mut conn1_lock1).await?); + + let conn2 = pool.acquire().await?; + let conn2_lock2 = lock2.clone().acquire_owned(conn2).await?; + + sqlx_core::rt::spawn({ + let lock1 = lock1.clone(); + let lock2 = lock2.clone(); + + async move { + let conn2_lock2 = lock1 + .clone() + .try_acquire_owned(conn2_lock2) + .await? + .right_or_else(|_| { + panic!( + "acquired lock but wasn't supposed to! Key: {:?}", + lock1.key() + ) + }); + + let (conn2, released) = lock2.force_release(conn2_lock2).await?; + assert!(released); + + // acquire both locks but let the pool release them + let conn2_lock1 = lock1.acquire_owned(conn2).await?; + let _conn2_lock1and2 = lock2.acquire_owned(conn2_lock1).await?; + + anyhow::Ok(()) + } + }); + + // acquire lock2 on conn1, we leak the lock1 guard so we can manually release it before lock2 + let conn1_lock1and2 = lock2.clone().acquire_owned(conn1_lock1.leak()).await?; + + // release lock1 while holding lock2 + let (conn1_lock2, released) = lock1.force_release(conn1_lock1and2).await?; + assert!(released); + + let conn1 = conn1_lock2.release_now().await?; + + // acquire both locks to be sure they were released + { + let conn1_lock1 = lock1.acquire_owned(conn1).await?; + let _conn1_lock1and2 = lock2.acquire_owned(conn1_lock1).await?; + } + + pool.close().await; + + Ok(()) +} + #[sqlx_macros::test] async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> { let mut conn = new::().await?;