Skip to content

Commit ba8b692

Browse files
committed
lock tables before capturing lsn, quick and dirty don't judge
1 parent 315ef48 commit ba8b692

File tree

3 files changed

+97
-34
lines changed

3 files changed

+97
-34
lines changed

src/sql-server-util/src/cdc.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ impl<'a> CdcStream<'a> {
123123
),
124124
SqlServerError,
125125
> {
126+
static SAVEPOINT_NAME: &str = "_mz_snap_";
127+
126128
// Determine what table we need to snapshot.
127129
let instances = self
128130
.capture_instances
@@ -147,10 +149,26 @@ impl<'a> CdcStream<'a> {
147149
let mut txn = self.client.transaction().await?;
148150
// Get the current LSN of the database. This operation assigns a savepoint to our
149151
// current transaction.
150-
let lsn = crate::inspect::lsn_at_savepoint(&mut txn, "_mz_snapshot").await?;
151152

153+
// create a savepoint that we will lock the tables under and collect an LSN
154+
// this allows us to rollback the savepoint while maintaining the outer transaction
155+
txn.create_savepoint(SAVEPOINT_NAME).await?;
156+
// lock all the tables we are planning to snapshot so that we can ensure that
157+
// writes that might be in progress are properly ordered before or after this snapshot
158+
// in addition to the LSN being properly ordered.
159+
// TODO (maz): we should considering a timeout here because we may lock some tables,
160+
// and the next table may be locked for some extended period, resulting in a traffic
161+
// jam.
162+
for (_capture_instance, schema, table) in &tables {
163+
tracing::trace!(%schema, %table, "locking table");
164+
crate::inspect::lock_table(&mut txn, &*schema, &*table).await?;
165+
}
166+
167+
let lsn = crate::inspect::get_lsn(&mut txn).await?;
152168
tracing::info!(?tables, ?lsn, "starting snapshot");
153169

170+
txn.rollback_savepoint(SAVEPOINT_NAME).await?;
171+
154172
// Get the size of each table we're about to snapshot.
155173
//
156174
// TODO(sql_server3): To expose a more "generic" interface it would be nice to
@@ -564,11 +582,11 @@ impl TryFrom<Numeric> for Lsn {
564582

565583
let vlf_id = u32::try_from(decimal_lsn / 10_i128.pow(15))
566584
.map_err(|e| format!("Failed to decode vlf_id for lsn {decimal_lsn}: {e:?}"))?;
567-
decimal_lsn -= vlf_id as i128 * 10_i128.pow(15);
585+
decimal_lsn -= i128::try_from(vlf_id).unwrap() * 10_i128.pow(15);
568586

569587
let block_id = u32::try_from(decimal_lsn / 10_i128.pow(5))
570588
.map_err(|e| format!("Failed to decode block_id for lsn {decimal_lsn}: {e:?}"))?;
571-
decimal_lsn -= block_id as i128 * 10_i128.pow(5);
589+
decimal_lsn -= i128::try_from(block_id).unwrap() * 10_i128.pow(5);
572590

573591
let record_id = u16::try_from(decimal_lsn)
574592
.map_err(|e| format!("Failed to decode record_id for lsn {decimal_lsn}: {e:?}"))?;

src/sql-server-util/src/inspect.rs

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,65 @@ pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServ
6868
/// - only contains letters, digits, and underscores
6969
/// - no resserved words
7070
/// - 32 char max
71-
pub async fn lsn_at_savepoint(
71+
pub async fn create_savepoint(
7272
txn: &mut Transaction<'_>,
7373
savepoint_name: &str,
74-
) -> Result<Lsn, SqlServerError> {
75-
let current_lsn_query: &str = "
74+
) -> Result<(), SqlServerError> {
75+
// TODO (maz): make sure savepoint name is safe
76+
let _result = txn
77+
.client
78+
.simple_query(format!("SAVE TRANSACTION [{savepoint_name}]"))
79+
.await?;
80+
Ok(())
81+
}
82+
83+
pub async fn get_lsn(txn: &mut Transaction<'_>) -> Result<Lsn, SqlServerError> {
84+
static CURRENT_LSN_QUERY: &str = "
7685
SELECT dt.database_transaction_begin_lsn
7786
FROM sys.dm_tran_database_transactions AS dt
7887
JOIN sys.dm_tran_session_transactions AS st
7988
ON dt.transaction_id = st.transaction_id
8089
WHERE st.session_id = @@SPID
8190
";
82-
// TODO (maz): make sure savepoint name is safe
83-
84-
txn.client
85-
.simple_query(format!("SAVE TRANSACTION [{savepoint_name}]"))
86-
.await?;
87-
let result = txn.client.simple_query(current_lsn_query).await?;
91+
let result = txn.client.simple_query(CURRENT_LSN_QUERY).await?;
8892
parse_numeric_lsn(&result)
8993
}
9094

95+
pub async fn lock_table(
96+
txn: &mut Transaction<'_>,
97+
schema: &str,
98+
table: &str,
99+
) -> Result<(), SqlServerError> {
100+
// This query probably seems odd, but there is no LOCK command in MS SQL. Locks are specified
101+
// in SELECT using the WITH keyword. This query does not need to return any rows to lock the table,
102+
// hence the 1=0, which is something short that always evaluates to false in this universe;
103+
let query = format!("SELECT * FROM {schema}.{table} WITH (TABLOCKX) WHERE 1=0;");
104+
let _result = txn.client.query(query, &[]).await?;
105+
Ok(())
106+
}
107+
108+
/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
109+
///
110+
/// Returns an error if the provided slice doesn't have exactly one row.
111+
fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
112+
match row {
113+
[r] => {
114+
let numeric_lsn = r
115+
.try_get::<Numeric, _>(0)?
116+
.ok_or_else(|| SqlServerError::NullLsn)?;
117+
let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
118+
column_name: "lsn".to_string(),
119+
error: msg,
120+
})?;
121+
Ok(lsn)
122+
}
123+
other => Err(SqlServerError::InvalidData {
124+
column_name: "lsn".to_string(),
125+
error: format!("expected 1 column, got {other:?}"),
126+
}),
127+
}
128+
}
129+
91130
/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
92131
///
93132
/// Returns an error if the provided slice doesn't have exactly one row.
@@ -114,28 +153,6 @@ fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
114153
}
115154
}
116155

117-
/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
118-
///
119-
/// Returns an error if the provided slice doesn't have exactly one row.
120-
fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
121-
match row {
122-
[r] => {
123-
let numeric_lsn = r
124-
.try_get::<Numeric, _>(0)?
125-
.ok_or_else(|| SqlServerError::NullLsn)?;
126-
let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
127-
column_name: "lsn".to_string(),
128-
error: msg,
129-
})?;
130-
Ok(lsn)
131-
}
132-
other => Err(SqlServerError::InvalidData {
133-
column_name: "lsn".to_string(),
134-
error: format!("expected 1 column, got {other:?}"),
135-
}),
136-
}
137-
}
138-
139156
/// Queries the specified capture instance and returns all changes from
140157
/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
141158
///

src/sql-server-util/src/lib.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ pub type RowStream<'a> =
336336
pub struct Transaction<'a> {
337337
client: &'a mut Client,
338338
closed: bool,
339+
nested_xact_names: Vec<String>,
339340
}
340341

341342
impl<'a> Transaction<'a> {
@@ -352,10 +353,37 @@ impl<'a> Transaction<'a> {
352353
Ok(Transaction {
353354
client,
354355
closed: false,
356+
nested_xact_names: Default::default(),
355357
})
356358
}
357359
}
358360

361+
/// Creates a savepoint with a transaction that can be committed or rolled back
362+
/// without affecting the out transaction.
363+
pub async fn create_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
364+
let stmt = format!("SAVE TRANSACTION [{savepoint_name}]");
365+
let _result = self.client.simple_query(stmt).await?;
366+
self.nested_xact_names.push(savepoint_name.to_string());
367+
Ok(())
368+
}
369+
370+
pub async fn rollback_savepoint(&mut self, savepoint_name: &str) -> Result<(), SqlServerError> {
371+
let last_xact_name = self.nested_xact_names.pop();
372+
if last_xact_name
373+
.as_ref()
374+
.is_none_or(|last_xact_name| *last_xact_name != savepoint_name)
375+
{
376+
panic!(
377+
"Attempt to rollback savepoint {savepoint_name} doesn't match last savepoint {:?}",
378+
last_xact_name
379+
);
380+
}
381+
let stmt = format!("ROLLBACK TRANSACTION [{savepoint_name}]");
382+
let _result = self.client.simple_query(stmt).await?;
383+
self.nested_xact_names.push(savepoint_name.to_string());
384+
Ok(())
385+
}
386+
359387
/// See [`Client::execute`].
360388
pub async fn execute<'q>(
361389
&mut self,

0 commit comments

Comments
 (0)