@@ -14,6 +14,7 @@ use crate::messages::BytesMutReader;
14
14
use crate :: pool:: PoolSettings ;
15
15
use crate :: sharding:: Sharder ;
16
16
17
+ use std:: cmp;
17
18
use std:: collections:: BTreeSet ;
18
19
use std:: io:: Cursor ;
19
20
@@ -114,7 +115,52 @@ impl QueryRouter {
114
115
115
116
let code = message_cursor. get_u8 ( ) as char ;
116
117
117
- // Only simple protocol supported for commands.
118
+ // Check for any sharding regex matches in any queries
119
+ match code as char {
120
+ // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
121
+ 'P' | 'Q' => {
122
+ if self . pool_settings . shard_id_regex . is_some ( )
123
+ || self . pool_settings . sharding_key_regex . is_some ( )
124
+ {
125
+ // Check only the first block of bytes configured by the pool settings
126
+ let len = message_cursor. get_i32 ( ) as usize ;
127
+ let seg = cmp:: min ( len - 5 , self . pool_settings . regex_search_limit ) ;
128
+ let initial_segment = String :: from_utf8_lossy ( & message_buffer[ 0 ..seg] ) ;
129
+
130
+ // Check for a shard_id included in the query
131
+ if let Some ( shard_id_regex) = & self . pool_settings . shard_id_regex {
132
+ let shard_id = shard_id_regex. captures ( & initial_segment) . and_then ( |cap| {
133
+ cap. get ( 1 ) . and_then ( |id| id. as_str ( ) . parse :: < usize > ( ) . ok ( ) )
134
+ } ) ;
135
+ if let Some ( shard_id) = shard_id {
136
+ debug ! ( "Setting shard to {:?}" , shard_id) ;
137
+ self . set_shard ( shard_id) ;
138
+ // Skip other command processing since a sharding command was found
139
+ return None ;
140
+ }
141
+ }
142
+
143
+ // Check for a sharding_key included in the query
144
+ if let Some ( sharding_key_regex) = & self . pool_settings . sharding_key_regex {
145
+ let sharding_key =
146
+ sharding_key_regex
147
+ . captures ( & initial_segment)
148
+ . and_then ( |cap| {
149
+ cap. get ( 1 ) . and_then ( |id| id. as_str ( ) . parse :: < i64 > ( ) . ok ( ) )
150
+ } ) ;
151
+ if let Some ( sharding_key) = sharding_key {
152
+ debug ! ( "Setting sharding_key to {:?}" , sharding_key) ;
153
+ self . set_sharding_key ( sharding_key) ;
154
+ // Skip other command processing since a sharding command was found
155
+ return None ;
156
+ }
157
+ }
158
+ }
159
+ }
160
+ _ => { }
161
+ }
162
+
163
+ // Only simple protocol supported for commands processed below
118
164
if code != 'Q' {
119
165
return None ;
120
166
}
@@ -192,13 +238,11 @@ impl QueryRouter {
192
238
193
239
match command {
194
240
Command :: SetShardingKey => {
195
- let sharder = Sharder :: new (
196
- self . pool_settings . shards ,
197
- self . pool_settings . sharding_function ,
198
- ) ;
199
- let shard = sharder. shard ( value. parse :: < i64 > ( ) . unwrap ( ) ) ;
200
- self . active_shard = Some ( shard) ;
201
- value = shard. to_string ( ) ;
241
+ // TODO: some error handling here
242
+ value = self
243
+ . set_sharding_key ( value. parse :: < i64 > ( ) . unwrap ( ) )
244
+ . unwrap ( )
245
+ . to_string ( ) ;
202
246
}
203
247
204
248
Command :: SetShard => {
@@ -465,6 +509,16 @@ impl QueryRouter {
465
509
}
466
510
}
467
511
512
+ fn set_sharding_key ( & mut self , sharding_key : i64 ) -> Option < usize > {
513
+ let sharder = Sharder :: new (
514
+ self . pool_settings . shards ,
515
+ self . pool_settings . sharding_function ,
516
+ ) ;
517
+ let shard = sharder. shard ( sharding_key) ;
518
+ self . set_shard ( shard) ;
519
+ self . active_shard
520
+ }
521
+
468
522
/// Get the current desired server role we should be talking to.
469
523
pub fn role ( & self ) -> Option < Role > {
470
524
self . active_role
@@ -775,6 +829,9 @@ mod test {
775
829
healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
776
830
healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
777
831
ban_time : PoolSettings :: default ( ) . ban_time ,
832
+ sharding_key_regex : None ,
833
+ shard_id_regex : None ,
834
+ regex_search_limit : 1000 ,
778
835
} ;
779
836
let mut qr = QueryRouter :: new ( ) ;
780
837
assert_eq ! ( qr. active_role, None ) ;
@@ -820,4 +877,47 @@ mod test {
820
877
) ) ) ;
821
878
assert_eq ! ( qr. role( ) , Role :: Primary ) ;
822
879
}
880
+
881
+ #[ test]
882
+ fn test_regex_shard_parsing ( ) {
883
+ QueryRouter :: setup ( ) ;
884
+
885
+ let pool_settings = PoolSettings {
886
+ pool_mode : PoolMode :: Transaction ,
887
+ load_balancing_mode : crate :: config:: LoadBalancingMode :: Random ,
888
+ shards : 5 ,
889
+ user : crate :: config:: User :: default ( ) ,
890
+ default_role : Some ( Role :: Replica ) ,
891
+ query_parser_enabled : true ,
892
+ primary_reads_enabled : false ,
893
+ sharding_function : ShardingFunction :: PgBigintHash ,
894
+ automatic_sharding_key : Some ( String :: from ( "id" ) ) ,
895
+ healthcheck_delay : PoolSettings :: default ( ) . healthcheck_delay ,
896
+ healthcheck_timeout : PoolSettings :: default ( ) . healthcheck_timeout ,
897
+ ban_time : PoolSettings :: default ( ) . ban_time ,
898
+ sharding_key_regex : Some ( Regex :: new ( r"/\* sharding_key: (\d+) \*/" ) . unwrap ( ) ) ,
899
+ shard_id_regex : Some ( Regex :: new ( r"/\* shard_id: (\d+) \*/" ) . unwrap ( ) ) ,
900
+ regex_search_limit : 1000 ,
901
+ } ;
902
+ let mut qr = QueryRouter :: new ( ) ;
903
+ qr. update_pool_settings ( pool_settings. clone ( ) ) ;
904
+
905
+ // Shard should start out unset
906
+ assert_eq ! ( qr. active_shard, None ) ;
907
+
908
+ // Make sure setting it works
909
+ let q1 = simple_query ( "/* shard_id: 1 */ select 1 from foo;" ) ;
910
+ assert ! ( qr. try_execute_command( & q1) == None ) ;
911
+ assert_eq ! ( qr. active_shard, Some ( 1 ) ) ;
912
+
913
+ // And make sure changing it works
914
+ let q2 = simple_query ( "/* shard_id: 0 */ select 1 from foo;" ) ;
915
+ assert ! ( qr. try_execute_command( & q2) == None ) ;
916
+ assert_eq ! ( qr. active_shard, Some ( 0 ) ) ;
917
+
918
+ // Validate setting by shard with expected shard copied from sharding.rs tests
919
+ let q2 = simple_query ( "/* sharding_key: 6 */ select 1 from foo;" ) ;
920
+ assert ! ( qr. try_execute_command( & q2) == None ) ;
921
+ assert_eq ! ( qr. active_shard, Some ( 2 ) ) ;
922
+ }
823
923
}
0 commit comments