Skip to content

Commit 8ca2ca9

Browse files
committed
lock tables before capturing lsn, quick and dirty don't judge
1 parent 315ef48 commit 8ca2ca9

File tree

3 files changed

+103
-35
lines changed

3 files changed

+103
-35
lines changed

src/sql-server-util/src/cdc.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ impl<'a> CdcStream<'a> {
123123
),
124124
SqlServerError,
125125
> {
126+
static SAVEPOINT_NAME: &str = "_mz_snap_";
127+
126128
// Determine what table we need to snapshot.
127129
let instances = self
128130
.capture_instances
@@ -147,10 +149,26 @@ impl<'a> CdcStream<'a> {
147149
let mut txn = self.client.transaction().await?;
148150
// Get the current LSN of the database. This operation assigns a savepoint to our
149151
// current transaction.
150-
let lsn = crate::inspect::lsn_at_savepoint(&mut txn, "_mz_snapshot").await?;
151152

153+
// create a savepoint that we will lock the tables under and collect an LSN
154+
// this allows us to rollback the savepoint while maintaining the outer transaction
155+
txn.create_savepoint(SAVEPOINT_NAME).await?;
156+
// lock all the tables we are planning to snapshot so that we can ensure that
157+
// writes that might be in progress are properly ordered before or after this snapshot
158+
// in addition to the LSN being properly ordered.
159+
// TODO (maz): we should considering a timeout here because we may lock some tables,
160+
// and the next table may be locked for some extended period, resulting in a traffic
161+
// jam.
162+
for (_capture_instance, schema, table) in &tables {
163+
tracing::trace!(%schema, %table, "locking table");
164+
crate::inspect::lock_table(&mut txn, &*schema, &*table).await?;
165+
}
166+
167+
let lsn = crate::inspect::get_lsn(&mut txn).await?;
152168
tracing::info!(?tables, ?lsn, "starting snapshot");
153169

170+
txn.rollback_savepoint(SAVEPOINT_NAME).await?;
171+
154172
// Get the size of each table we're about to snapshot.
155173
//
156174
// TODO(sql_server3): To expose a more "generic" interface it would be nice to
@@ -564,11 +582,11 @@ impl TryFrom<Numeric> for Lsn {
564582

565583
let vlf_id = u32::try_from(decimal_lsn / 10_i128.pow(15))
566584
.map_err(|e| format!("Failed to decode vlf_id for lsn {decimal_lsn}: {e:?}"))?;
567-
decimal_lsn -= vlf_id as i128 * 10_i128.pow(15);
585+
decimal_lsn -= i128::try_from(vlf_id).unwrap() * 10_i128.pow(15);
568586

569587
let block_id = u32::try_from(decimal_lsn / 10_i128.pow(5))
570588
.map_err(|e| format!("Failed to decode block_id for lsn {decimal_lsn}: {e:?}"))?;
571-
decimal_lsn -= block_id as i128 * 10_i128.pow(5);
589+
decimal_lsn -= i128::try_from(block_id).unwrap() * 10_i128.pow(5);
572590

573591
let record_id = u16::try_from(decimal_lsn)
574592
.map_err(|e| format!("Failed to decode record_id for lsn {decimal_lsn}: {e:?}"))?;

src/sql-server-util/src/inspect.rs

Lines changed: 47 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,64 @@ pub async fn increment_lsn(client: &mut Client, lsn: Lsn) -> Result<Lsn, SqlServ
6868
/// - only contains letters, digits, and underscores
6969
/// - no resserved words
7070
/// - 32 char max
71-
pub async fn lsn_at_savepoint(
71+
pub async fn create_savepoint(
7272
txn: &mut Transaction<'_>,
7373
savepoint_name: &str,
74-
) -> Result<Lsn, SqlServerError> {
75-
let current_lsn_query: &str = "
74+
) -> Result<(), SqlServerError> {
75+
// TODO (maz): make sure savepoint name is safe
76+
let _result = txn.client
77+
.simple_query(format!("SAVE TRANSACTION [{savepoint_name}]"))
78+
.await?;
79+
Ok(())
80+
}
81+
82+
pub async fn get_lsn(txn: &mut Transaction<'_>) -> Result<Lsn, SqlServerError> {
83+
static CURRENT_LSN_QUERY: &str = "
7684
SELECT dt.database_transaction_begin_lsn
7785
FROM sys.dm_tran_database_transactions AS dt
7886
JOIN sys.dm_tran_session_transactions AS st
7987
ON dt.transaction_id = st.transaction_id
8088
WHERE st.session_id = @@SPID
8189
";
82-
// TODO (maz): make sure savepoint name is safe
83-
84-
txn.client
85-
.simple_query(format!("SAVE TRANSACTION [{savepoint_name}]"))
86-
.await?;
87-
let result = txn.client.simple_query(current_lsn_query).await?;
90+
let result = txn.client.simple_query(CURRENT_LSN_QUERY).await?;
8891
parse_numeric_lsn(&result)
8992
}
9093

94+
pub async fn lock_table(
95+
txn: &mut Transaction<'_>,
96+
schema: &str,
97+
table: &str,
98+
) -> Result<(), SqlServerError> {
99+
// This query probably seems odd, but there is no LOCK command in MS SQL. Locks are specified
100+
// in SELECT using the WITH keyword. This query does not need to return any rows to lock the table,
101+
// hence the 1=0, which is something short that always evaluates to false in this universe;
102+
let query = format!("SELECT * FROM {schema}.{table} WITH (TABLOCKX) WHERE 1=0;");
103+
let _result = txn.client.query(query, &[]).await?;
104+
Ok(())
105+
}
106+
107+
/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
108+
///
109+
/// Returns an error if the provided slice doesn't have exactly one row.
110+
fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
111+
match row {
112+
[r] => {
113+
let numeric_lsn = r
114+
.try_get::<Numeric, _>(0)?
115+
.ok_or_else(|| SqlServerError::NullLsn)?;
116+
let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
117+
column_name: "lsn".to_string(),
118+
error: msg,
119+
})?;
120+
Ok(lsn)
121+
}
122+
other => Err(SqlServerError::InvalidData {
123+
column_name: "lsn".to_string(),
124+
error: format!("expected 1 column, got {other:?}"),
125+
}),
126+
}
127+
}
128+
91129
/// Parse an [`Lsn`] from the first column of the provided [`tiberius::Row`].
92130
///
93131
/// Returns an error if the provided slice doesn't have exactly one row.
@@ -114,28 +152,6 @@ fn parse_lsn(result: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
114152
}
115153
}
116154

117-
/// Parse an [`Lsn`] in Decimal(25,0) format of the provided [`tiberius::Row`].
118-
///
119-
/// Returns an error if the provided slice doesn't have exactly one row.
120-
fn parse_numeric_lsn(row: &[tiberius::Row]) -> Result<Lsn, SqlServerError> {
121-
match row {
122-
[r] => {
123-
let numeric_lsn = r
124-
.try_get::<Numeric, _>(0)?
125-
.ok_or_else(|| SqlServerError::NullLsn)?;
126-
let lsn = Lsn::try_from(numeric_lsn).map_err(|msg| SqlServerError::InvalidData {
127-
column_name: "lsn".to_string(),
128-
error: msg,
129-
})?;
130-
Ok(lsn)
131-
}
132-
other => Err(SqlServerError::InvalidData {
133-
column_name: "lsn".to_string(),
134-
error: format!("expected 1 column, got {other:?}"),
135-
}),
136-
}
137-
}
138-
139155
/// Queries the specified capture instance and returns all changes from
140156
/// `[start_lsn, end_lsn)`, ordered by `start_lsn` in an ascending fashion.
141157
///

src/sql-server-util/src/lib.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
2020
use mz_ore::result::ResultExt;
2121
use mz_repr::ScalarType;
2222
use smallvec::{SmallVec, smallvec};
23-
use tiberius::ToSql;
23+
use tiberius::{Query, ToSql};
2424
use tokio::net::TcpStream;
2525
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
2626
use tokio::sync::oneshot;
@@ -336,6 +336,7 @@ pub type RowStream<'a> =
336336
pub struct Transaction<'a> {
337337
client: &'a mut Client,
338338
closed: bool,
339+
nested_xact_names: Vec<String>,
339340
}
340341

341342
impl<'a> Transaction<'a> {
@@ -352,10 +353,43 @@ impl<'a> Transaction<'a> {
352353
Ok(Transaction {
353354
client,
354355
closed: false,
356+
nested_xact_names: Default::default(),
355357
})
356358
}
357359
}
358360

361+
/// Creates a savepoint with a transaction that can be committed or rolled back
362+
/// without affecting the out transaction.
363+
pub async fn create_savepoint<'q>(
364+
&mut self,
365+
savepoint_name: &str,
366+
) -> Result<(), SqlServerError> {
367+
let stmt = format!("SAVE TRANSACTION [{savepoint_name}]");
368+
let _result = self.client.simple_query(stmt).await?;
369+
self.nested_xact_names.push(savepoint_name.to_string());
370+
Ok(())
371+
}
372+
373+
pub async fn rollback_savepoint<'q>(
374+
&mut self,
375+
savepoint_name: &str,
376+
) -> Result<(), SqlServerError> {
377+
let last_xact_name = self.nested_xact_names.pop();
378+
if last_xact_name
379+
.as_ref()
380+
.is_none_or(|last_xact_name| *last_xact_name != savepoint_name)
381+
{
382+
panic!(
383+
"Attempt to rollback savepoint {savepoint_name} doesn't match last savepoint {:?}",
384+
last_xact_name
385+
);
386+
}
387+
let stmt = format!("ROLLBACK TRANSACTION [{savepoint_name}]");
388+
let _result = self.client.simple_query(stmt).await?;
389+
self.nested_xact_names.push(savepoint_name.to_string());
390+
Ok(())
391+
}
392+
359393
/// See [`Client::execute`].
360394
pub async fn execute<'q>(
361395
&mut self,

0 commit comments

Comments
 (0)