Skip to content

Commit c3eaf02

Browse files
authored
Automatic sharding for SELECT v2 (#337)
* More comprehensive read sharding support * A few fixes * fq * comment * wildcard
1 parent 02839e4 commit c3eaf02

File tree

3 files changed

+192
-28
lines changed

3 files changed

+192
-28
lines changed

pgcat.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ primary_reads_enabled = true
102102
sharding_function = "pg_bigint_hash"
103103

104104
# Automatically parse this from queries and route queries to the right shard!
105-
automatic_sharding_key = "id"
105+
automatic_sharding_key = "data.id"
106106

107107
# Idle timeout can be overwritten in the pool
108108
idle_timeout = 40000

src/config.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ impl Pool {
374374
None
375375
}
376376

377-
pub fn validate(&self) -> Result<(), Error> {
377+
pub fn validate(&mut self) -> Result<(), Error> {
378378
match self.default_role.as_ref() {
379379
"any" => (),
380380
"primary" => (),
@@ -414,6 +414,25 @@ impl Pool {
414414
}
415415
}
416416

417+
self.automatic_sharding_key = match &self.automatic_sharding_key {
418+
Some(key) => {
419+
// No quotes in the key so we don't have to compare quoted
420+
// to unquoted idents.
421+
let key = key.replace("\"", "");
422+
423+
if key.split(".").count() != 2 {
424+
error!(
425+
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
426+
key, key
427+
);
428+
return Err(Error::BadConfig);
429+
}
430+
431+
Some(key)
432+
}
433+
None => None,
434+
};
435+
417436
Ok(())
418437
}
419438
}

src/query_router.rs

Lines changed: 171 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use log::{debug, error};
55
use once_cell::sync::OnceCell;
66
use regex::{Regex, RegexSet};
77
use sqlparser::ast::Statement::{Query, StartTransaction};
8-
use sqlparser::ast::{BinaryOperator, Expr, SetExpr, Value};
8+
use sqlparser::ast::{
9+
BinaryOperator, Expr, Ident, JoinConstraint, JoinOperator, SetExpr, TableFactor, Value,
10+
};
911
use sqlparser::dialect::PostgreSqlDialect;
1012
use sqlparser::parser::Parser;
1113

@@ -403,20 +405,67 @@ impl QueryRouter {
403405

404406
/// A `selection` is the `WHERE` clause. This parses
405407
/// the clause and extracts the sharding key, if present.
406-
fn selection_parser(&self, expr: &Expr) -> Vec<i64> {
408+
fn selection_parser(&self, expr: &Expr, table_names: &Vec<Vec<Ident>>) -> Vec<i64> {
407409
let mut result = Vec::new();
408410
let mut found = false;
409411

412+
let sharding_key = self
413+
.pool_settings
414+
.automatic_sharding_key
415+
.as_ref()
416+
.unwrap()
417+
.split(".")
418+
.map(|ident| Ident::new(ident))
419+
.collect::<Vec<Ident>>();
420+
421+
// Sharding key must be always fully qualified
422+
assert_eq!(sharding_key.len(), 2);
423+
410424
// This parses `sharding_key = 5`. But it's technically
411425
// legal to write `5 = sharding_key`. I don't judge the people
412426
// who do that, but I think ORMs will still use the first variant,
413427
// so we can leave the second as a TODO.
414428
if let Expr::BinaryOp { left, op, right } = expr {
415429
match &**left {
416-
Expr::BinaryOp { .. } => result.extend(self.selection_parser(left)),
430+
Expr::BinaryOp { .. } => result.extend(self.selection_parser(left, table_names)),
417431
Expr::Identifier(ident) => {
418-
found =
419-
ident.value == *self.pool_settings.automatic_sharding_key.as_ref().unwrap();
432+
// Only if we're dealing with only one table
433+
// and there is no ambiguity
434+
if &ident.value == &sharding_key[1].value {
435+
// Sharding key is unique enough, don't worry about
436+
// table names.
437+
if &sharding_key[0].value == "*" {
438+
found = true;
439+
} else if table_names.len() == 1 {
440+
let table = &table_names[0];
441+
442+
if table.len() == 1 {
443+
// Table is not fully qualified, e.g.
444+
// SELECT * FROM t WHERE sharding_key = 5
445+
// Make sure the table name from the sharding key matches
446+
// the table name from the query.
447+
found = &sharding_key[0].value == &table[0].value;
448+
} else if table.len() == 2 {
449+
// Table name is fully qualified with the schema: e.g.
450+
// SELECT * FROM public.t WHERE sharding_key = 5
451+
// Ignore the schema (TODO: at some point, we want schema support)
452+
// and use the table name only.
453+
found = &sharding_key[0].value == &table[1].value;
454+
} else {
455+
debug!("Got table name with more than two idents, which is not possible");
456+
}
457+
}
458+
}
459+
}
460+
461+
Expr::CompoundIdentifier(idents) => {
462+
// The key is fully qualified in the query,
463+
// it will exist or Postgres will throw an error.
464+
if idents.len() == 2 {
465+
found = &sharding_key[0].value == &idents[0].value
466+
&& &sharding_key[1].value == &idents[1].value;
467+
}
468+
// TODO: key can have schema as well, e.g. public.data.id (len == 3)
420469
}
421470
_ => (),
422471
};
@@ -433,7 +482,7 @@ impl QueryRouter {
433482
};
434483

435484
match &**right {
436-
Expr::BinaryOp { .. } => result.extend(self.selection_parser(right)),
485+
Expr::BinaryOp { .. } => result.extend(self.selection_parser(right, table_names)),
437486
Expr::Value(Value::Number(value, ..)) => {
438487
if found {
439488
match value.parse::<i64>() {
@@ -456,6 +505,7 @@ impl QueryRouter {
456505
/// Try to figure out which shard the query should go to.
457506
fn infer_shard(&self, query: &sqlparser::ast::Query) -> Option<usize> {
458507
let mut shards = BTreeSet::new();
508+
let mut exprs = Vec::new();
459509

460510
match &*query.body {
461511
SetExpr::Query(query) => {
@@ -467,27 +517,75 @@ impl QueryRouter {
467517
};
468518
}
469519

520+
// SELECT * FROM ...
521+
// We understand that pretty well.
470522
SetExpr::Select(select) => {
471-
match &select.selection {
472-
Some(selection) => {
473-
let sharding_keys = self.selection_parser(selection);
523+
// Collect all table names from the query.
524+
let mut table_names = Vec::new();
474525

475-
// TODO: Add support for prepared statements here.
476-
// This should just give us the position of the value in the `B` message.
526+
for table in select.from.iter() {
527+
match &table.relation {
528+
TableFactor::Table { name, .. } => {
529+
table_names.push(name.0.clone());
530+
}
477531

478-
let sharder = Sharder::new(
479-
self.pool_settings.shards,
480-
self.pool_settings.sharding_function,
481-
);
532+
_ => (),
533+
};
482534

483-
for value in sharding_keys {
484-
let shard = sharder.shard(value);
485-
shards.insert(shard);
486-
}
535+
// Get table names from all the joins.
536+
for join in table.joins.iter() {
537+
match &join.relation {
538+
TableFactor::Table { name, .. } => {
539+
table_names.push(name.0.clone());
540+
}
541+
542+
_ => (),
543+
};
544+
545+
// We can filter results based on join conditions, e.g.
546+
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
547+
match &join.join_operator {
548+
JoinOperator::Inner(inner_join) => match &inner_join {
549+
JoinConstraint::On(expr) => {
550+
// Parse the selection criteria later.
551+
exprs.push(expr.clone());
552+
}
553+
554+
_ => (),
555+
},
556+
557+
_ => (),
558+
};
559+
}
560+
}
561+
562+
// Parse the actual "WHERE ..."
563+
match &select.selection {
564+
Some(selection) => {
565+
exprs.push(selection.clone());
487566
}
488567

489568
None => (),
490569
};
570+
571+
// Look for sharding keys in either the join condition
572+
// or the selection.
573+
for expr in exprs.iter() {
574+
let sharding_keys = self.selection_parser(expr, &table_names);
575+
576+
// TODO: Add support for prepared statements here.
577+
// This should just give us the position of the value in the `B` message.
578+
579+
let sharder = Sharder::new(
580+
self.pool_settings.shards,
581+
self.pool_settings.sharding_function,
582+
);
583+
584+
for value in sharding_keys {
585+
let shard = sharder.shard(value);
586+
shards.insert(shard);
587+
}
588+
}
491589
}
492590
_ => (),
493591
};
@@ -825,7 +923,7 @@ mod test {
825923
query_parser_enabled: true,
826924
primary_reads_enabled: false,
827925
sharding_function: ShardingFunction::PgBigintHash,
828-
automatic_sharding_key: Some(String::from("id")),
926+
automatic_sharding_key: Some(String::from("test.id")),
829927
healthcheck_delay: PoolSettings::default().healthcheck_delay,
830928
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
831929
ban_time: PoolSettings::default().ban_time,
@@ -854,11 +952,6 @@ mod test {
854952
let q2 = simple_query("SET SERVER ROLE TO 'default'");
855953
assert!(qr.try_execute_command(&q2) != None);
856954
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
857-
858-
// Here we go :)
859-
let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)");
860-
assert!(qr.infer(&q3));
861-
assert_eq!(qr.shard(), 1);
862955
}
863956

864957
#[test]
@@ -891,7 +984,7 @@ mod test {
891984
query_parser_enabled: true,
892985
primary_reads_enabled: false,
893986
sharding_function: ShardingFunction::PgBigintHash,
894-
automatic_sharding_key: Some(String::from("id")),
987+
automatic_sharding_key: None,
895988
healthcheck_delay: PoolSettings::default().healthcheck_delay,
896989
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
897990
ban_time: PoolSettings::default().ban_time,
@@ -920,4 +1013,56 @@ mod test {
9201013
assert!(qr.try_execute_command(&q2) == None);
9211014
assert_eq!(qr.active_shard, Some(2));
9221015
}
1016+
1017+
#[test]
1018+
fn test_automatic_sharding_key() {
1019+
QueryRouter::setup();
1020+
1021+
let mut qr = QueryRouter::new();
1022+
qr.pool_settings.automatic_sharding_key = Some("data.id".to_string());
1023+
qr.pool_settings.shards = 3;
1024+
1025+
assert!(qr.infer(&simple_query("SELECT * FROM data WHERE id = 5")));
1026+
assert_eq!(qr.shard(), 2);
1027+
1028+
assert!(qr.infer(&simple_query(
1029+
"SELECT one, two, three FROM public.data WHERE id = 6"
1030+
)));
1031+
assert_eq!(qr.shard(), 0);
1032+
1033+
assert!(qr.infer(&simple_query(
1034+
"SELECT * FROM data
1035+
INNER JOIN t2 ON data.id = 5
1036+
AND t2.data_id = data.id
1037+
WHERE data.id = 5"
1038+
)));
1039+
assert_eq!(qr.shard(), 2);
1040+
1041+
// Shard did not move because we couldn't determine the sharding key since it could be ambiguous
1042+
// in the query.
1043+
assert!(qr.infer(&simple_query(
1044+
"SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1045+
)));
1046+
assert_eq!(qr.shard(), 2);
1047+
1048+
assert!(qr.infer(&simple_query(
1049+
r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1050+
)));
1051+
assert_eq!(qr.shard(), 0);
1052+
1053+
assert!(qr.infer(&simple_query(
1054+
r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1055+
)));
1056+
assert_eq!(qr.shard(), 2);
1057+
1058+
// Super unique sharding key
1059+
qr.pool_settings.automatic_sharding_key = Some("*.unique_enough_column_name".to_string());
1060+
assert!(qr.infer(&simple_query(
1061+
"SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1062+
)));
1063+
assert_eq!(qr.shard(), 0);
1064+
1065+
assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5")));
1066+
assert_eq!(qr.shard(), 0);
1067+
}
9231068
}

0 commit comments

Comments
 (0)