@@ -5,7 +5,9 @@ use log::{debug, error};
5
5
use once_cell:: sync:: OnceCell ;
6
6
use regex:: { Regex , RegexSet } ;
7
7
use sqlparser:: ast:: Statement :: { Query , StartTransaction } ;
8
- use sqlparser:: ast:: { BinaryOperator , Expr , SetExpr , Value } ;
8
+ use sqlparser:: ast:: {
9
+ BinaryOperator , Expr , Ident , JoinConstraint , JoinOperator , SetExpr , TableFactor , Value ,
10
+ } ;
9
11
use sqlparser:: dialect:: PostgreSqlDialect ;
10
12
use sqlparser:: parser:: Parser ;
11
13
@@ -403,20 +405,67 @@ impl QueryRouter {
403
405
404
406
/// A `selection` is the `WHERE` clause. This parses
405
407
/// the clause and extracts the sharding key, if present.
406
- fn selection_parser ( & self , expr : & Expr ) -> Vec < i64 > {
408
+ fn selection_parser ( & self , expr : & Expr , table_names : & Vec < Vec < Ident > > ) -> Vec < i64 > {
407
409
let mut result = Vec :: new ( ) ;
408
410
let mut found = false ;
409
411
412
+ let sharding_key = self
413
+ . pool_settings
414
+ . automatic_sharding_key
415
+ . as_ref ( )
416
+ . unwrap ( )
417
+ . split ( "." )
418
+ . map ( |ident| Ident :: new ( ident) )
419
+ . collect :: < Vec < Ident > > ( ) ;
420
+
421
+ // Sharding key must be always fully qualified
422
+ assert_eq ! ( sharding_key. len( ) , 2 ) ;
423
+
410
424
// This parses `sharding_key = 5`. But it's technically
411
425
// legal to write `5 = sharding_key`. I don't judge the people
412
426
// who do that, but I think ORMs will still use the first variant,
413
427
// so we can leave the second as a TODO.
414
428
if let Expr :: BinaryOp { left, op, right } = expr {
415
429
match & * * left {
416
- Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( left) ) ,
430
+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( left, table_names ) ) ,
417
431
Expr :: Identifier ( ident) => {
418
- found =
419
- ident. value == * self . pool_settings . automatic_sharding_key . as_ref ( ) . unwrap ( ) ;
432
+ // Only if we're dealing with only one table
433
+ // and there is no ambiguity
434
+ if & ident. value == & sharding_key[ 1 ] . value {
435
+ // Sharding key is unique enough, don't worry about
436
+ // table names.
437
+ if & sharding_key[ 0 ] . value == "*" {
438
+ found = true ;
439
+ } else if table_names. len ( ) == 1 {
440
+ let table = & table_names[ 0 ] ;
441
+
442
+ if table. len ( ) == 1 {
443
+ // Table is not fully qualified, e.g.
444
+ // SELECT * FROM t WHERE sharding_key = 5
445
+ // Make sure the table name from the sharding key matches
446
+ // the table name from the query.
447
+ found = & sharding_key[ 0 ] . value == & table[ 0 ] . value ;
448
+ } else if table. len ( ) == 2 {
449
+ // Table name is fully qualified with the schema: e.g.
450
+ // SELECT * FROM public.t WHERE sharding_key = 5
451
+ // Ignore the schema (TODO: at some point, we want schema support)
452
+ // and use the table name only.
453
+ found = & sharding_key[ 0 ] . value == & table[ 1 ] . value ;
454
+ } else {
455
+ debug ! ( "Got table name with more than two idents, which is not possible" ) ;
456
+ }
457
+ }
458
+ }
459
+ }
460
+
461
+ Expr :: CompoundIdentifier ( idents) => {
462
+ // The key is fully qualified in the query,
463
+ // it will exist or Postgres will throw an error.
464
+ if idents. len ( ) == 2 {
465
+ found = & sharding_key[ 0 ] . value == & idents[ 0 ] . value
466
+ && & sharding_key[ 1 ] . value == & idents[ 1 ] . value ;
467
+ }
468
+ // TODO: key can have schema as well, e.g. public.data.id (len == 3)
420
469
}
421
470
_ => ( ) ,
422
471
} ;
@@ -433,7 +482,7 @@ impl QueryRouter {
433
482
} ;
434
483
435
484
match & * * right {
436
- Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( right) ) ,
485
+ Expr :: BinaryOp { .. } => result. extend ( self . selection_parser ( right, table_names ) ) ,
437
486
Expr :: Value ( Value :: Number ( value, ..) ) => {
438
487
if found {
439
488
match value. parse :: < i64 > ( ) {
@@ -456,6 +505,7 @@ impl QueryRouter {
456
505
/// Try to figure out which shard the query should go to.
457
506
fn infer_shard ( & self , query : & sqlparser:: ast:: Query ) -> Option < usize > {
458
507
let mut shards = BTreeSet :: new ( ) ;
508
+ let mut exprs = Vec :: new ( ) ;
459
509
460
510
match & * query. body {
461
511
SetExpr :: Query ( query) => {
@@ -467,27 +517,75 @@ impl QueryRouter {
467
517
} ;
468
518
}
469
519
520
+ // SELECT * FROM ...
521
+ // We understand that pretty well.
470
522
SetExpr :: Select ( select) => {
471
- match & select. selection {
472
- Some ( selection) => {
473
- let sharding_keys = self . selection_parser ( selection) ;
523
+ // Collect all table names from the query.
524
+ let mut table_names = Vec :: new ( ) ;
474
525
475
- // TODO: Add support for prepared statements here.
476
- // This should just give us the position of the value in the `B` message.
526
+ for table in select. from . iter ( ) {
527
+ match & table. relation {
528
+ TableFactor :: Table { name, .. } => {
529
+ table_names. push ( name. 0 . clone ( ) ) ;
530
+ }
477
531
478
- let sharder = Sharder :: new (
479
- self . pool_settings . shards ,
480
- self . pool_settings . sharding_function ,
481
- ) ;
532
+ _ => ( ) ,
533
+ } ;
482
534
483
- for value in sharding_keys {
484
- let shard = sharder. shard ( value) ;
485
- shards. insert ( shard) ;
486
- }
535
+ // Get table names from all the joins.
536
+ for join in table. joins . iter ( ) {
537
+ match & join. relation {
538
+ TableFactor :: Table { name, .. } => {
539
+ table_names. push ( name. 0 . clone ( ) ) ;
540
+ }
541
+
542
+ _ => ( ) ,
543
+ } ;
544
+
545
+ // We can filter results based on join conditions, e.g.
546
+ // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
547
+ match & join. join_operator {
548
+ JoinOperator :: Inner ( inner_join) => match & inner_join {
549
+ JoinConstraint :: On ( expr) => {
550
+ // Parse the selection criteria later.
551
+ exprs. push ( expr. clone ( ) ) ;
552
+ }
553
+
554
+ _ => ( ) ,
555
+ } ,
556
+
557
+ _ => ( ) ,
558
+ } ;
559
+ }
560
+ }
561
+
562
+ // Parse the actual "FROM ..."
563
+ match & select. selection {
564
+ Some ( selection) => {
565
+ exprs. push ( selection. clone ( ) ) ;
487
566
}
488
567
489
568
None => ( ) ,
490
569
} ;
570
+
571
+ // Look for sharding keys in either the join condition
572
+ // or the selection.
573
+ for expr in exprs. iter ( ) {
574
+ let sharding_keys = self . selection_parser ( expr, & table_names) ;
575
+
576
+ // TODO: Add support for prepared statements here.
577
+ // This should just give us the position of the value in the `B` message.
578
+
579
+ let sharder = Sharder :: new (
580
+ self . pool_settings . shards ,
581
+ self . pool_settings . sharding_function ,
582
+ ) ;
583
+
584
+ for value in sharding_keys {
585
+ let shard = sharder. shard ( value) ;
586
+ shards. insert ( shard) ;
587
+ }
588
+ }
491
589
}
492
590
_ => ( ) ,
493
591
} ;
@@ -825,7 +923,7 @@ mod test {
825
923
query_parser_enabled : true ,
826
924
primary_reads_enabled : false ,
827
925
sharding_function : ShardingFunction :: PgBigintHash ,
828
- automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
926
+ automatic_sharding_key : Some ( String :: from ( "test. id" ) ) ,
829
927
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
830
928
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
831
929
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -854,11 +952,6 @@ mod test {
854
952
let q2 = simple_query ( "SET SERVER ROLE TO 'default'" ) ;
855
953
assert ! ( qr. try_execute_command( & q2) != None ) ;
856
954
assert_eq ! ( qr. active_role. unwrap( ) , pool_settings. default_role) ;
857
-
858
- // Here we go :)
859
- let q3 = simple_query ( "SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)" ) ;
860
- assert ! ( qr. infer( & q3) ) ;
861
- assert_eq ! ( qr. shard( ) , 1 ) ;
862
955
}
863
956
864
957
#[ test]
@@ -891,7 +984,7 @@ mod test {
891
984
query_parser_enabled : true ,
892
985
primary_reads_enabled : false ,
893
986
sharding_function : ShardingFunction :: PgBigintHash ,
894
- automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
987
+ automatic_sharding_key : None ,
895
988
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
896
989
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
897
990
ban_time : PoolSettings :: default ( ) . ban_time ,
@@ -920,4 +1013,56 @@ mod test {
920
1013
assert ! ( qr. try_execute_command( & q2) == None ) ;
921
1014
assert_eq ! ( qr. active_shard, Some ( 2 ) ) ;
922
1015
}
1016
+
1017
+ #[ test]
1018
+ fn test_automatic_sharding_key ( ) {
1019
+ QueryRouter :: setup ( ) ;
1020
+
1021
+ let mut qr = QueryRouter :: new ( ) ;
1022
+ qr. pool_settings . automatic_sharding_key = Some ( "data.id" . to_string ( ) ) ;
1023
+ qr. pool_settings . shards = 3 ;
1024
+
1025
+ assert ! ( qr. infer( & simple_query( "SELECT * FROM data WHERE id = 5" ) ) ) ;
1026
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1027
+
1028
+ assert ! ( qr. infer( & simple_query(
1029
+ "SELECT one, two, three FROM public.data WHERE id = 6"
1030
+ ) ) ) ;
1031
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1032
+
1033
+ assert ! ( qr. infer( & simple_query(
1034
+ "SELECT * FROM data
1035
+ INNER JOIN t2 ON data.id = 5
1036
+ AND t2.data_id = data.id
1037
+ WHERE data.id = 5"
1038
+ ) ) ) ;
1039
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1040
+
1041
+ // Shard did not move because we couldn't determine the sharding key since it could be ambiguous
1042
+ // in the query.
1043
+ assert ! ( qr. infer( & simple_query(
1044
+ "SELECT * FROM t2 INNER JOIN data ON id = 6 AND data.id = t2.data_id"
1045
+ ) ) ) ;
1046
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1047
+
1048
+ assert ! ( qr. infer( & simple_query(
1049
+ r#"SELECT * FROM "public"."data" WHERE "id" = 6"#
1050
+ ) ) ) ;
1051
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1052
+
1053
+ assert ! ( qr. infer( & simple_query(
1054
+ r#"SELECT * FROM "public"."data" WHERE "data"."id" = 5"#
1055
+ ) ) ) ;
1056
+ assert_eq ! ( qr. shard( ) , 2 ) ;
1057
+
1058
+ // Super unique sharding key
1059
+ qr. pool_settings . automatic_sharding_key = Some ( "*.unique_enough_column_name" . to_string ( ) ) ;
1060
+ assert ! ( qr. infer( & simple_query(
1061
+ "SELECT * FROM table_x WHERE unique_enough_column_name = 6"
1062
+ ) ) ) ;
1063
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1064
+
1065
+ assert ! ( qr. infer( & simple_query( "SELECT * FROM table_y WHERE another_key = 5" ) ) ) ;
1066
+ assert_eq ! ( qr. shard( ) , 0 ) ;
1067
+ }
923
1068
}
0 commit comments