@@ -5,17 +5,17 @@ use crate::sharding::{Sharder, ShardingFunction};
5
5
use bytes:: { Buf , BytesMut } ;
6
6
use log:: { debug, error} ;
7
7
use once_cell:: sync:: OnceCell ;
8
- use regex:: RegexSet ;
8
+ use regex:: { Regex , RegexSet } ;
9
9
use sqlparser:: ast:: Statement :: { Query , StartTransaction } ;
10
10
use sqlparser:: dialect:: PostgreSqlDialect ;
11
11
use sqlparser:: parser:: Parser ;
12
12
13
13
const CUSTOM_SQL_REGEXES : [ & str ; 5 ] = [
14
- r"(?i)SET SHARDING KEY TO '[0-9]+' " ,
15
- r"(?i)SET SHARD TO '[0-9]+' " ,
16
- r"(?i)SHOW SHARD" ,
17
- r"(?i)SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)'" ,
18
- r"(?i)SHOW SERVER ROLE" ,
14
+ r"(?i)^ * SET SHARDING KEY TO '?( [0-9]+)'? *;? *$ " ,
15
+ r"(?i)^ * SET SHARD TO '?( [0-9]+)'? *;? *$ " ,
16
+ r"(?i)^ * SHOW SHARD *;? *$ " ,
17
+ r"(?i)^ * SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$ " ,
18
+ r"(?i)^ * SHOW SERVER ROLE *;? *$ " ,
19
19
] ;
20
20
21
21
#[ derive( PartialEq , Debug ) ]
@@ -27,8 +27,12 @@ pub enum Command {
27
27
ShowServerRole ,
28
28
}
29
29
30
+ // Quick test
30
31
static CUSTOM_SQL_REGEX_SET : OnceCell < RegexSet > = OnceCell :: new ( ) ;
31
32
33
+ // Capture value
34
+ static CUSTOM_SQL_REGEX_LIST : OnceCell < Vec < Regex > > = OnceCell :: new ( ) ;
35
+
32
36
pub struct QueryRouter {
33
37
// By default, queries go here, unless we have better information
34
38
// about what the client wants.
@@ -63,6 +67,21 @@ impl QueryRouter {
63
67
}
64
68
} ;
65
69
70
+ let list: Vec < _ > = CUSTOM_SQL_REGEXES
71
+ . iter ( )
72
+ . map ( |rgx| Regex :: new ( rgx) . unwrap ( ) )
73
+ . collect ( ) ;
74
+
75
+ // Impossible
76
+ if list. len ( ) != set. len ( ) {
77
+ return false ;
78
+ }
79
+
80
+ match CUSTOM_SQL_REGEX_LIST . set ( list) {
81
+ Ok ( _) => true ,
82
+ Err ( _) => return false ,
83
+ } ;
84
+
66
85
match CUSTOM_SQL_REGEX_SET . set ( set) {
67
86
Ok ( _) => true ,
68
87
Err ( _) => false ,
@@ -113,6 +132,11 @@ impl QueryRouter {
113
132
None => return None ,
114
133
} ;
115
134
135
+ let regex_list = match CUSTOM_SQL_REGEX_LIST . get ( ) {
136
+ Some ( regex_list) => regex_list,
137
+ None => return None ,
138
+ } ;
139
+
116
140
let matches: Vec < _ > = regex_set. matches ( & query) . into_iter ( ) . collect ( ) ;
117
141
118
142
if matches. len ( ) != 1 {
@@ -130,7 +154,19 @@ impl QueryRouter {
130
154
131
155
let mut value = match command {
132
156
Command :: SetShardingKey | Command :: SetShard | Command :: SetServerRole => {
133
- query. split ( "'" ) . collect :: < Vec < & str > > ( ) [ 1 ] . to_string ( )
157
+ // Capture value. I know this re-runs the regex engine, but I haven't
158
+ // figured out a better way just yet. I think I can write a single Regex
159
+ // that matches all 5 custom SQL patterns, but maybe that's not very legible?
160
+ //
161
+ // I think this is faster than running the Regex engine 5 times, so
162
+ // this is a strong maybe for me so far.
163
+ match regex_list[ matches[ 0 ] ] . captures ( & query) {
164
+ Some ( captures) => match captures. get ( 1 ) {
165
+ Some ( value) => value. as_str ( ) . to_string ( ) ,
166
+ None => return None ,
167
+ } ,
168
+ None => return None ,
169
+ }
134
170
}
135
171
136
172
Command :: ShowShard => self . shard ( ) . to_string ( ) ,
@@ -411,14 +447,38 @@ mod test {
411
447
"set server role to 'any'" ,
412
448
"set server role to 'auto'" ,
413
449
"show server role" ,
450
+ // No quotes
451
+ "SET SHARDING KEY TO 11235" ,
452
+ "SET SHARD TO 15" ,
453
+ // Spaces and semicolon
454
+ " SET SHARDING KEY TO 11235 ; " ,
455
+ " SET SHARD TO 15; " ,
456
+ " SET SHARDING KEY TO 11235 ;" ,
457
+ " SET SERVER ROLE TO 'primary'; " ,
458
+ " SET SERVER ROLE TO 'primary' ; " ,
459
+ " SET SERVER ROLE TO 'primary' ;" ,
460
+ ] ;
461
+
462
+ // Which regexes it'll match to in the list
463
+ let matches = [
464
+ 0 , 1 , 2 , 3 , 3 , 3 , 3 , 4 , 0 , 1 , 2 , 3 , 3 , 3 , 3 , 4 , 0 , 1 , 0 , 1 , 0 , 3 , 3 , 3 ,
414
465
] ;
415
466
467
+ let list = CUSTOM_SQL_REGEX_LIST . get ( ) . unwrap ( ) ;
416
468
let set = CUSTOM_SQL_REGEX_SET . get ( ) . unwrap ( ) ;
417
469
418
- for test in & tests {
419
- let matches: Vec < _ > = set. matches ( test) . into_iter ( ) . collect ( ) ;
470
+ for ( i, test) in tests. iter ( ) . enumerate ( ) {
471
+ assert ! ( list[ matches[ i] ] . is_match( test) ) ;
472
+ assert_eq ! ( set. matches( test) . into_iter( ) . collect:: <Vec <_>>( ) . len( ) , 1 ) ;
473
+ }
474
+
475
+ let bad = [
476
+ "SELECT * FROM table" ,
477
+ "SELECT * FROM table WHERE value = 'set sharding key to 5'" , // Don't capture things in the middle of the query
478
+ ] ;
420
479
421
- assert_eq ! ( matches. len( ) , 1 ) ;
480
+ for query in & bad {
481
+ assert_eq ! ( set. matches( query) . into_iter( ) . collect:: <Vec <_>>( ) . len( ) , 0 ) ;
422
482
}
423
483
}
424
484
@@ -428,7 +488,7 @@ mod test {
428
488
let mut qr = QueryRouter :: new ( ) ;
429
489
430
490
// SetShardingKey
431
- let query = simple_query ( "SET SHARDING KEY TO '13' " ) ;
491
+ let query = simple_query ( "SET SHARDING KEY TO 13 " ) ;
432
492
assert_eq ! (
433
493
qr. try_execute_command( query) ,
434
494
Some ( ( Command :: SetShardingKey , String :: from( "1" ) ) )
0 commit comments