Skip to content

Commit a784883

Browse files
authored
Allow to set shard and set sharding key without quotes (#43)
* Allow to set shard and set sharding key without quotes * cover it * dont look for these in the middle of another query * friendly regex * its own response to set shard key
1 parent 5972b6f commit a784883

File tree

2 files changed

+78
-12
lines changed

2 files changed

+78
-12
lines changed

src/client.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,17 @@ impl Client {
229229
}
230230
}
231231

232-
Some((Command::SetShard, _)) | Some((Command::SetShardingKey, _)) => {
232+
Some((Command::SetShard, _)) => {
233233
custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?;
234234
continue;
235235
}
236236

237+
Some((Command::SetShardingKey, _)) => {
238+
custom_protocol_response_ok(&mut self.write, &format!("SET SHARDING KEY"))
239+
.await?;
240+
continue;
241+
}
242+
237243
Some((Command::SetServerRole, _)) => {
238244
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
239245
continue;

src/query_router.rs

Lines changed: 71 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@ use crate::sharding::{Sharder, ShardingFunction};
55
use bytes::{Buf, BytesMut};
66
use log::{debug, error};
77
use once_cell::sync::OnceCell;
8-
use regex::RegexSet;
8+
use regex::{Regex, RegexSet};
99
use sqlparser::ast::Statement::{Query, StartTransaction};
1010
use sqlparser::dialect::PostgreSqlDialect;
1111
use sqlparser::parser::Parser;
1212

1313
const CUSTOM_SQL_REGEXES: [&str; 5] = [
14-
r"(?i)SET SHARDING KEY TO '[0-9]+'",
15-
r"(?i)SET SHARD TO '[0-9]+'",
16-
r"(?i)SHOW SHARD",
17-
r"(?i)SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)'",
18-
r"(?i)SHOW SERVER ROLE",
14+
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
15+
r"(?i)^ *SET SHARD TO '?([0-9]+)'? *;? *$",
16+
r"(?i)^ *SHOW SHARD *;? *$",
17+
r"(?i)^ *SET SERVER ROLE TO '(PRIMARY|REPLICA|ANY|AUTO|DEFAULT)' *;? *$",
18+
r"(?i)^ *SHOW SERVER ROLE *;? *$",
1919
];
2020

2121
#[derive(PartialEq, Debug)]
@@ -27,8 +27,12 @@ pub enum Command {
2727
ShowServerRole,
2828
}
2929

30+
// Quick test
3031
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
3132

33+
// Capture value
34+
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
35+
3236
pub struct QueryRouter {
3337
// By default, queries go here, unless we have better information
3438
// about what the client wants.
@@ -63,6 +67,21 @@ impl QueryRouter {
6367
}
6468
};
6569

70+
let list: Vec<_> = CUSTOM_SQL_REGEXES
71+
.iter()
72+
.map(|rgx| Regex::new(rgx).unwrap())
73+
.collect();
74+
75+
// Impossible
76+
if list.len() != set.len() {
77+
return false;
78+
}
79+
80+
match CUSTOM_SQL_REGEX_LIST.set(list) {
81+
Ok(_) => true,
82+
Err(_) => return false,
83+
};
84+
6685
match CUSTOM_SQL_REGEX_SET.set(set) {
6786
Ok(_) => true,
6887
Err(_) => false,
@@ -113,6 +132,11 @@ impl QueryRouter {
113132
None => return None,
114133
};
115134

135+
let regex_list = match CUSTOM_SQL_REGEX_LIST.get() {
136+
Some(regex_list) => regex_list,
137+
None => return None,
138+
};
139+
116140
let matches: Vec<_> = regex_set.matches(&query).into_iter().collect();
117141

118142
if matches.len() != 1 {
@@ -130,7 +154,19 @@ impl QueryRouter {
130154

131155
let mut value = match command {
132156
Command::SetShardingKey | Command::SetShard | Command::SetServerRole => {
133-
query.split("'").collect::<Vec<&str>>()[1].to_string()
157+
// Capture value. I know this re-runs the regex engine, but I haven't
158+
// figured out a better way just yet. I think I can write a single Regex
159+
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
160+
//
161+
// I think this is faster than running the Regex engine 5 times, so
162+
// this is a strong maybe for me so far.
163+
match regex_list[matches[0]].captures(&query) {
164+
Some(captures) => match captures.get(1) {
165+
Some(value) => value.as_str().to_string(),
166+
None => return None,
167+
},
168+
None => return None,
169+
}
134170
}
135171

136172
Command::ShowShard => self.shard().to_string(),
@@ -411,14 +447,38 @@ mod test {
411447
"set server role to 'any'",
412448
"set server role to 'auto'",
413449
"show server role",
450+
// No quotes
451+
"SET SHARDING KEY TO 11235",
452+
"SET SHARD TO 15",
453+
// Spaces and semicolon
454+
" SET SHARDING KEY TO 11235 ; ",
455+
" SET SHARD TO 15; ",
456+
" SET SHARDING KEY TO 11235 ;",
457+
" SET SERVER ROLE TO 'primary'; ",
458+
" SET SERVER ROLE TO 'primary' ; ",
459+
" SET SERVER ROLE TO 'primary' ;",
460+
];
461+
462+
// Which regexes it'll match to in the list
463+
let matches = [
464+
0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 2, 3, 3, 3, 3, 4, 0, 1, 0, 1, 0, 3, 3, 3,
414465
];
415466

467+
let list = CUSTOM_SQL_REGEX_LIST.get().unwrap();
416468
let set = CUSTOM_SQL_REGEX_SET.get().unwrap();
417469

418-
for test in &tests {
419-
let matches: Vec<_> = set.matches(test).into_iter().collect();
470+
for (i, test) in tests.iter().enumerate() {
471+
assert!(list[matches[i]].is_match(test));
472+
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
473+
}
474+
475+
let bad = [
476+
"SELECT * FROM table",
477+
"SELECT * FROM table WHERE value = 'set sharding key to 5'", // Don't capture things in the middle of the query
478+
];
420479

421-
assert_eq!(matches.len(), 1);
480+
for query in &bad {
481+
assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0);
422482
}
423483
}
424484

@@ -428,7 +488,7 @@ mod test {
428488
let mut qr = QueryRouter::new();
429489

430490
// SetShardingKey
431-
let query = simple_query("SET SHARDING KEY TO '13'");
491+
let query = simple_query("SET SHARDING KEY TO 13");
432492
assert_eq!(
433493
qr.try_execute_command(query),
434494
Some((Command::SetShardingKey, String::from("1")))

0 commit comments

Comments
 (0)