From 29300bd601c092db57b313e4d44a0d8c40273e1a Mon Sep 17 00:00:00 2001 From: Andy Breuhan Date: Wed, 21 May 2025 12:47:29 +0200 Subject: [PATCH] feat(migrate): add run_through_version method for versioned migrations --- sqlx-core/src/migrate/migrator.rs | 101 ++++++++++++++++++++++++++++++ tests/mysql/migrate.rs | 11 ++++ tests/postgres/migrate.rs | 11 ++++ tests/sqlite/migrate.rs | 11 ++++ 4 files changed, 134 insertions(+) diff --git a/sqlx-core/src/migrate/migrator.rs b/sqlx-core/src/migrate/migrator.rs index 3209ba6e45..d5874c76a6 100644 --- a/sqlx-core/src/migrate/migrator.rs +++ b/sqlx-core/src/migrate/migrator.rs @@ -110,6 +110,36 @@ impl Migrator { self.iter().any(|m| m.version == version) } + /// Returns the highest version number of all migrations + /// in the Migrator. This corresponds to the latest version + pub fn latest_version(&self) -> i64 { + self.iter() + .max_by(|x, y| x.version.cmp(&y.version)) + .map(|migration| migration.version) + .unwrap_or(0) + } + + /// Get the latest version of the applied migrations. + pub async fn latest_applied_version<'a, A>( + &self, + migrator: A, + ) -> Result, MigrateError> + where + A: Acquire<'a>, + ::Target: Migrate, + { + let mut conn = migrator.acquire().await?; + + conn.ensure_migrations_table().await?; + let applied_migrations = conn.list_applied_migrations().await?; + let latest_version = applied_migrations + .iter() + .max_by(|x, y| x.version.cmp(&y.version)) + .map(|migration| migration.version); + + Ok(latest_version) + } + /// Run any pending migrations against the database; and, validate previously applied migrations /// against the current migration source to detect accidental changes in previously-applied migrations. /// @@ -255,4 +285,75 @@ impl Migrator { Ok(()) } + + /// Run up migrations against the database until a specific version. + /// + /// # Examples + /// + /// ```rust,no_run + /// # use sqlx::migrate::MigrateError; + /// # fn main() -> Result<(), MigrateError> { + /// # sqlx::__rt::test_block_on(async move { + /// use sqlx::migrate::Migrator; + /// use sqlx::sqlite::SqlitePoolOptions; + /// + /// let m = Migrator::new(std::path::Path::new("./migrations")).await?; + /// let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?; + /// m.run_through_version(&pool, 4).await + /// # }) + /// # } + /// ``` + pub async fn run_through_version<'a, A>( + &self, + migrator: A, + target: i64, + ) -> Result<(), MigrateError> + where + A: Acquire<'a>, + ::Target: Migrate, + { + let mut conn = migrator.acquire().await?; + + // lock the database for exclusive access by the migrator + if self.locking { + conn.lock().await?; + } + + // creates [_migrations] table only if needed + // eventually this will likely migrate previous versions of the table + conn.ensure_migrations_table().await?; + + let version = conn.dirty_version().await?; + if let Some(version) = version { + return Err(MigrateError::Dirty(version)); + } + + let applied_migrations = conn.list_applied_migrations().await?; + let applied_migrations: HashMap<_, _> = applied_migrations + .into_iter() + .map(|m| (m.version, m)) + .collect(); + + for migration in self + .iter() + .filter(|m| m.migration_type.is_up_migration()) + .filter(|m| m.version <= target) + { + if let Some(applied_migration) = applied_migrations.get(&migration.version) { + if migration.checksum != applied_migration.checksum { + return Err(MigrateError::VersionMismatch(migration.version)); + } + } else { + conn.apply(migration).await?; + } + } + + // unlock the migrator to allow other migrators to run + // but do nothing as we already migrated + if self.locking { + conn.unlock().await?; + } + + Ok(()) + } } diff --git a/tests/mysql/migrate.rs b/tests/mysql/migrate.rs index 97caa38005..dbf00c6eea 100644 --- a/tests/mysql/migrate.rs +++ b/tests/mysql/migrate.rs @@ -33,9 +33,20 @@ async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { let migrator = Migrator::new(Path::new("tests/mysql/migrations_reversible")).await?; + // run only until first reversible migration + migrator + .run_through_version(&mut conn, 20220721124650) + .await?; + + let latest_version = migrator.latest_version(); + assert_eq!(latest_version, 20220721125033); + // run migration migrator.run(&mut conn).await?; + let latest_applied_version = migrator.latest_applied_version(&mut conn).await?.unwrap(); + assert_eq!(latest_applied_version, latest_version); + // check outcome let res: i64 = conn .fetch_one("SELECT some_payload FROM migrations_reversible_test") diff --git a/tests/postgres/migrate.rs b/tests/postgres/migrate.rs index 636dffe860..9d695341c7 100644 --- a/tests/postgres/migrate.rs +++ b/tests/postgres/migrate.rs @@ -33,9 +33,20 @@ async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { let migrator = Migrator::new(Path::new("tests/postgres/migrations_reversible")).await?; + // run only until first reversible migration + migrator + .run_through_version(&mut conn, 20220721124650) + .await?; + + let latest_version = migrator.latest_version(); + assert_eq!(latest_version, 20220721125033); + // run migration migrator.run(&mut conn).await?; + let latest_applied_version = migrator.latest_applied_version(&mut conn).await?.unwrap(); + assert_eq!(latest_applied_version, latest_version); + // check outcome let res: i64 = conn .fetch_one("SELECT some_payload FROM migrations_reversible_test") diff --git a/tests/sqlite/migrate.rs b/tests/sqlite/migrate.rs index 19e8690f9a..6f38c98a3c 100644 --- a/tests/sqlite/migrate.rs +++ b/tests/sqlite/migrate.rs @@ -33,9 +33,20 @@ async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { let migrator = Migrator::new(Path::new("tests/sqlite/migrations_reversible")).await?; + // run only until first reversible migration + migrator + .run_through_version(&mut conn, 20220721124650) + .await?; + + let latest_version = migrator.latest_version(); + assert_eq!(latest_version, 20220721125033); + // run migration migrator.run(&mut conn).await?; + let latest_applied_version = migrator.latest_applied_version(&mut conn).await?.unwrap(); + assert_eq!(latest_applied_version, latest_version); + // check outcome let res: i64 = conn .fetch_one("SELECT some_payload FROM migrations_reversible_test")