@@ -26,9 +26,9 @@ const MIGRATIONS_FILENAME: &str = "migrations.toml";
26
26
#[ cfg( feature = "test" ) ]
27
27
const MIGRATIONS_FILENAME : & str = "test.migrations.toml" ;
28
28
29
+ mod delete;
29
30
mod insert;
30
31
mod update;
31
- mod delete;
32
32
33
33
#[ derive( Debug , Clone ) ]
34
34
struct Table {
@@ -80,6 +80,11 @@ struct ExecuteTokens {
80
80
tokens : proc_macro2:: TokenStream ,
81
81
}
82
82
83
+ #[ derive( Debug ) ]
84
+ struct UpdateTokens {
85
+ tokens : proc_macro2:: TokenStream ,
86
+ }
87
+
83
88
#[ derive( Clone , Debug ) ]
84
89
struct SingleColumn {
85
90
table : Ident ,
@@ -230,8 +235,9 @@ fn _extract_stmt_members(stmt: &Statement, span: &Span) -> MembersAndCasters {
230
235
enum ParseStatementType {
231
236
Execute ,
232
237
Select ,
238
+ Update ,
233
239
}
234
- use ParseStatementType :: { Execute , Select } ;
240
+ use ParseStatementType :: { Execute , Select , Update } ;
235
241
236
242
#[ derive( Debug ) ]
237
243
struct StatementInfo {
@@ -414,22 +420,25 @@ fn do_parse_tokens(
414
420
None
415
421
} ;
416
422
417
- let ( sql, params, sql_and_parameters_tokens) = parse_interpolated_sql ( input) ?;
423
+ let ( mut sql, params, sql_and_parameters_tokens) = parse_interpolated_sql ( input) ?;
418
424
419
425
// Try validating SQL as-is
420
426
421
- let stmt_info = sql. as_ref ( ) . and_then ( |s| validate_sql ( s) . ok ( ) ) ;
427
+ let mut stmt_info = sql. as_ref ( ) . and_then ( |s| validate_sql ( s) . ok ( ) ) ;
422
428
423
- // Try adding SELECT if it didn't validate
429
+ // Try adding SELECT or UPDATE if it didn't validate
424
430
425
- let ( sql, stmt_info) = match ( sql, stmt_info) {
426
- ( Some ( sql) , None ) => {
427
- let sql_with_select = format ! ( "SELECT {}" , sql) ;
428
- let stmt_info = validate_sql ( & sql_with_select) . ok ( ) ;
429
- ( Some ( if stmt_info. is_some ( ) { sql_with_select } else { sql } ) , stmt_info)
431
+ if let ( ty @ ( Select | Update ) , Some ( orig_sql) , None ) = ( & statement_type, & sql, & stmt_info) {
432
+ let sql_modified = match ty {
433
+ Select => format ! ( "SELECT {}" , orig_sql) ,
434
+ Update => format ! ( "UPDATE {}" , orig_sql) ,
435
+ _ => unreachable ! ( ) ,
436
+ } ;
437
+ if let Ok ( stmt_info_modified) = validate_sql ( & sql_modified) {
438
+ sql = Some ( sql_modified) ;
439
+ stmt_info = Some ( stmt_info_modified) ;
430
440
}
431
- t => t,
432
- } ;
441
+ }
433
442
434
443
if is_rust_analyzer ( ) {
435
444
return Ok ( if let Some ( ty) = result_type {
@@ -547,7 +556,7 @@ fn do_parse_tokens(
547
556
// if we return no columns, this should be an execute
548
557
549
558
if stmt_info. column_names . is_empty ( ) {
550
- if ! matches ! ( statement_type, Execute ) {
559
+ if matches ! ( statement_type, Select ) {
551
560
abort_call_site ! ( "No rows returned from SQL, use execute! instead." ) ;
552
561
}
553
562
@@ -660,6 +669,12 @@ impl Parse for ExecuteTokens {
660
669
}
661
670
}
662
671
672
+ impl Parse for UpdateTokens {
673
+ fn parse ( input : ParseStream ) -> syn:: Result < Self > {
674
+ Ok ( UpdateTokens { tokens : do_parse_tokens ( input, Update ) ? } )
675
+ }
676
+ }
677
+
663
678
/// Executes a SQL statement. On success, returns the number of rows that were changed or inserted or deleted.
664
679
#[ proc_macro]
665
680
#[ proc_macro_error]
@@ -676,6 +691,14 @@ pub fn select(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
676
691
proc_macro:: TokenStream :: from ( tokens)
677
692
}
678
693
694
+ /// Executes a SQL statement with optionally automatic `UPDATE` clause. On success, returns the number of rows that were changed.
695
+ #[ proc_macro]
696
+ #[ proc_macro_error]
697
+ pub fn update ( input : proc_macro:: TokenStream ) -> proc_macro:: TokenStream {
698
+ let UpdateTokens { tokens } = parse_macro_input ! ( input) ;
699
+ proc_macro:: TokenStream :: from ( tokens)
700
+ }
701
+
679
702
/// Derive this on a `struct` to create a corresponding SQLite table and `Turbosql` trait methods.
680
703
#[ proc_macro_derive( Turbosql , attributes( turbosql) ) ]
681
704
#[ proc_macro_error]
0 commit comments