Skip to content

Commit 87224d4

Browse files
committed
add update! macro
1 parent 46da264 commit 87224d4

File tree

4 files changed

+41
-16
lines changed

4 files changed

+41
-16
lines changed

turbosql-impl/src/lib.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ const MIGRATIONS_FILENAME: &str = "migrations.toml";
2626
#[cfg(feature = "test")]
2727
const MIGRATIONS_FILENAME: &str = "test.migrations.toml";
2828

29+
mod delete;
2930
mod insert;
3031
mod update;
31-
mod delete;
3232

3333
#[derive(Debug, Clone)]
3434
struct Table {
@@ -80,6 +80,11 @@ struct ExecuteTokens {
8080
tokens: proc_macro2::TokenStream,
8181
}
8282

83+
#[derive(Debug)]
84+
struct UpdateTokens {
85+
tokens: proc_macro2::TokenStream,
86+
}
87+
8388
#[derive(Clone, Debug)]
8489
struct SingleColumn {
8590
table: Ident,
@@ -230,8 +235,9 @@ fn _extract_stmt_members(stmt: &Statement, span: &Span) -> MembersAndCasters {
230235
enum ParseStatementType {
231236
Execute,
232237
Select,
238+
Update,
233239
}
234-
use ParseStatementType::{Execute, Select};
240+
use ParseStatementType::{Execute, Select, Update};
235241

236242
#[derive(Debug)]
237243
struct StatementInfo {
@@ -414,22 +420,25 @@ fn do_parse_tokens(
414420
None
415421
};
416422

417-
let (sql, params, sql_and_parameters_tokens) = parse_interpolated_sql(input)?;
423+
let (mut sql, params, sql_and_parameters_tokens) = parse_interpolated_sql(input)?;
418424

419425
// Try validating SQL as-is
420426

421-
let stmt_info = sql.as_ref().and_then(|s| validate_sql(s).ok());
427+
let mut stmt_info = sql.as_ref().and_then(|s| validate_sql(s).ok());
422428

423-
// Try adding SELECT if it didn't validate
429+
// Try adding SELECT or UPDATE if it didn't validate
424430

425-
let (sql, stmt_info) = match (sql, stmt_info) {
426-
(Some(sql), None) => {
427-
let sql_with_select = format!("SELECT {}", sql);
428-
let stmt_info = validate_sql(&sql_with_select).ok();
429-
(Some(if stmt_info.is_some() { sql_with_select } else { sql }), stmt_info)
431+
if let (ty @ (Select | Update), Some(orig_sql), None) = (&statement_type, &sql, &stmt_info) {
432+
let sql_modified = match ty {
433+
Select => format!("SELECT {}", orig_sql),
434+
Update => format!("UPDATE {}", orig_sql),
435+
_ => unreachable!(),
436+
};
437+
if let Ok(stmt_info_modified) = validate_sql(&sql_modified) {
438+
sql = Some(sql_modified);
439+
stmt_info = Some(stmt_info_modified);
430440
}
431-
t => t,
432-
};
441+
}
433442

434443
if is_rust_analyzer() {
435444
return Ok(if let Some(ty) = result_type {
@@ -547,7 +556,7 @@ fn do_parse_tokens(
547556
// if we return no columns, this should be an execute
548557

549558
if stmt_info.column_names.is_empty() {
550-
if !matches!(statement_type, Execute) {
559+
if matches!(statement_type, Select) {
551560
abort_call_site!("No rows returned from SQL, use execute! instead.");
552561
}
553562

@@ -660,6 +669,12 @@ impl Parse for ExecuteTokens {
660669
}
661670
}
662671

672+
impl Parse for UpdateTokens {
673+
fn parse(input: ParseStream) -> syn::Result<Self> {
674+
Ok(UpdateTokens { tokens: do_parse_tokens(input, Update)? })
675+
}
676+
}
677+
663678
/// Executes a SQL statement. On success, returns the number of rows that were changed or inserted or deleted.
664679
#[proc_macro]
665680
#[proc_macro_error]
@@ -676,6 +691,14 @@ pub fn select(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
676691
proc_macro::TokenStream::from(tokens)
677692
}
678693

694+
/// Executes a SQL statement with optionally automatic `UPDATE` clause. On success, returns the number of rows that were changed.
695+
#[proc_macro]
696+
#[proc_macro_error]
697+
pub fn update(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
698+
let UpdateTokens { tokens } = parse_macro_input!(input);
699+
proc_macro::TokenStream::from(tokens)
700+
}
701+
679702
/// Derive this on a `struct` to create a corresponding SQLite table and `Turbosql` trait methods.
680703
#[proc_macro_derive(Turbosql, attributes(turbosql))]
681704
#[proc_macro_error]

turbosql/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ compile_error!("turbosql must be tested with '--features test -- --test-threads=
88
include!("lib_inner.rs");
99

1010
#[cfg(target_arch = "wasm32")]
11-
pub use turbosql_impl::{execute, select, Turbosql};
11+
pub use turbosql_impl::{execute, select, update, Turbosql};
1212

1313
#[cfg(target_arch = "wasm32")]
1414
pub fn now_ms() -> i64 {

turbosql/src/lib_inner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub use rusqlite::{
2121
pub use serde::Serialize;
2222
#[doc(hidden)]
2323
pub use serde_json;
24-
pub use turbosql_impl::{execute, select, Turbosql};
24+
pub use turbosql_impl::{execute, select, update, Turbosql};
2525

2626
/// Wrapper for `Vec<u8>` that may one day impl `Read`, `Write` and `Seek` traits.
2727
pub type Blob = Vec<u8>;

turbosql/tests/integration_test.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ compile_error!("turbosql must be tested with '--features test -- --test-threads=
77
#[cfg(not(test))]
88
compile_error!("integration_tests.rs must be run in test mode");
99

10-
use turbosql::{execute, select, Blob, Turbosql};
10+
use turbosql::{execute, select, update, Blob, Turbosql};
1111

1212
#[derive(Turbosql, Default, Debug, PartialEq, Clone)]
1313
struct PersonIntegrationTest {
@@ -301,6 +301,8 @@ fn integration_test() {
301301

302302
execute!("INSERT INTO personintegrationtest(field_u8, field_i8) VALUES (" 1, 2 ")").unwrap();
303303

304+
update!("personintegrationtest SET field_u8 = " 0).unwrap();
305+
304306
// DELETE
305307

306308
assert!(execute!("DELETE FROM personintegrationtest").is_ok());

0 commit comments

Comments
 (0)