Skip to content

Commit 4c8a398

Browse files
authored
Refactor query routing into its own module (#22)
* Refactor query routing into its own module * comments; tests; dead code * error message * safer startup * hm * dont have to be public * wow * fix ci * ok * nl * no more silent errors
1 parent 7b0ceef commit 4c8a398

11 files changed

+364
-144
lines changed

src/client.rs

Lines changed: 28 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
/// We are pretending to be the server in this scenario,
33
/// and this module implements that.
44
use bytes::{Buf, BufMut, BytesMut};
5-
use once_cell::sync::OnceCell;
6-
use regex::Regex;
75
use tokio::io::{AsyncReadExt, BufReader};
86
use tokio::net::{
97
tcp::{OwnedReadHalf, OwnedWriteHalf},
@@ -17,16 +15,10 @@ use crate::constants::*;
1715
use crate::errors::Error;
1816
use crate::messages::*;
1917
use crate::pool::{ClientServerMap, ConnectionPool};
18+
use crate::query_router::QueryRouter;
2019
use crate::server::Server;
21-
use crate::sharding::Sharder;
2220
use crate::stats::Reporter;
2321

24-
pub const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';";
25-
pub const ROLE_REGEX: &str = r"SET SERVER ROLE TO '(PRIMARY|REPLICA)';";
26-
27-
pub static SHARDING_REGEX_RE: OnceCell<Regex> = OnceCell::new();
28-
pub static ROLE_REGEX_RE: OnceCell<Regex> = OnceCell::new();
29-
3022
/// The client state. One of these is created per client.
3123
pub struct Client {
3224
// The reads are buffered (8K by default).
@@ -199,15 +191,11 @@ impl Client {
199191
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
200192
}
201193

202-
// Active shard we're talking to.
203-
// The lifetime of this depends on the pool mode:
204-
// - if in session mode, this lives until the client disconnects,
205-
// - if in transaction mode, this lives for the duration of one transaction.
206-
let mut shard: Option<usize> = None;
207-
208-
// Active database role we want to talk to, e.g. primary or replica.
209-
let mut role: Option<Role> = self.default_server_role;
194+
let mut query_router = QueryRouter::new(self.default_server_role, pool.shards());
210195

196+
// Our custom protocol loop.
197+
// We expect the client to either start a transaction with regular queries
198+
// or issue commands for our sharding and server selection protocols.
211199
loop {
212200
// Read a complete message from the client, which normally would be
213201
// either a `Q` (query) or `P` (prepare, extended protocol).
@@ -218,32 +206,31 @@ impl Client {
218206

219207
// Parse for special select shard command.
220208
// SET SHARDING KEY TO 'bigint';
221-
match self.select_shard(message.clone(), pool.shards()) {
222-
Some(s) => {
223-
custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?;
224-
shard = Some(s);
225-
continue;
226-
}
227-
None => (),
228-
};
209+
if query_router.select_shard(message.clone()) {
210+
custom_protocol_response_ok(
211+
&mut self.write,
212+
&format!("SET SHARD TO {}", query_router.shard()),
213+
)
214+
.await?;
215+
continue;
216+
}
229217

230218
// Parse for special server role selection command.
231219
// SET SERVER ROLE TO '(primary|replica)';
232-
match self.select_role(message.clone()) {
233-
Some(r) => {
234-
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
235-
role = Some(r);
236-
continue;
237-
}
238-
None => (),
239-
};
220+
if query_router.select_role(message.clone()) {
221+
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
222+
continue;
223+
}
240224

241-
// Grab a server from the pool.
242-
let connection = match pool.get(shard, role).await {
225+
// Grab a server from the pool: the client issued a regular query.
226+
let connection = match pool.get(query_router.shard(), query_router.role()).await {
243227
Ok(conn) => conn,
244228
Err(err) => {
245229
println!(">> Could not get connection from pool: {:?}", err);
246-
return Err(err);
230+
error_response(&mut self.write, "could not get connection from the pool")
231+
.await?;
232+
query_router.reset();
233+
continue;
247234
}
248235
};
249236

@@ -264,11 +251,8 @@ impl Client {
264251
Err(err) => {
265252
// Client disconnected without warning.
266253
if server.in_transaction() {
267-
// TODO: this is what PgBouncer does
268-
// which leads to connection thrashing.
269-
//
270-
// I think we could issue a ROLLBACK here instead.
271-
// server.mark_bad();
254+
// Client left dirty server. Clean up and proceed
255+
// without thrashing this connection.
272256
server.query("ROLLBACK; DISCARD ALL;").await?;
273257
}
274258

@@ -328,8 +312,7 @@ impl Client {
328312
// Report this client as idle.
329313
self.stats.client_idle();
330314

331-
shard = None;
332-
role = self.default_server_role;
315+
query_router.reset();
333316

334317
break;
335318
}
@@ -414,8 +397,7 @@ impl Client {
414397
if self.transaction_mode {
415398
self.stats.client_idle();
416399

417-
shard = None;
418-
role = self.default_server_role;
400+
query_router.reset();
419401

420402
break;
421403
}
@@ -450,8 +432,7 @@ impl Client {
450432
self.stats.transaction();
451433

452434
if self.transaction_mode {
453-
shard = None;
454-
role = self.default_server_role;
435+
query_router.reset();
455436

456437
break;
457438
}
@@ -476,77 +457,4 @@ impl Client {
476457
let mut guard = self.client_server_map.lock().unwrap();
477458
guard.remove(&(self.process_id, self.secret_key));
478459
}
479-
480-
/// Determine if the query is part of our special syntax, extract
481-
/// the shard key, and return the shard to query based on Postgres'
482-
/// PARTITION BY HASH function.
483-
fn select_shard(&self, mut buf: BytesMut, shards: usize) -> Option<usize> {
484-
let code = buf.get_u8() as char;
485-
486-
// Only supporting simple protocol here, so
487-
// one would have to execute something like this:
488-
// psql -c "SET SHARDING KEY TO '1234'"
489-
// after sanitizing the value manually, which can be just done with an
490-
// int parser, e.g. `let key = "1234".parse::<i64>().unwrap()`.
491-
match code {
492-
'Q' => (),
493-
_ => return None,
494-
};
495-
496-
let len = buf.get_i32();
497-
let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); // Don't read the ternminating null
498-
499-
let rgx = match SHARDING_REGEX_RE.get() {
500-
Some(r) => r,
501-
None => return None,
502-
};
503-
504-
if rgx.is_match(&query) {
505-
let shard = query.split("'").collect::<Vec<&str>>()[1];
506-
507-
match shard.parse::<i64>() {
508-
Ok(shard) => {
509-
let sharder = Sharder::new(shards);
510-
Some(sharder.pg_bigint_hash(shard))
511-
}
512-
513-
Err(_) => None,
514-
}
515-
} else {
516-
None
517-
}
518-
}
519-
520-
// Pick a primary or a replica from the pool.
521-
fn select_role(&self, mut buf: BytesMut) -> Option<Role> {
522-
let code = buf.get_u8() as char;
523-
524-
// Same story as select_shard() above.
525-
match code {
526-
'Q' => (),
527-
_ => return None,
528-
};
529-
530-
let len = buf.get_i32();
531-
let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase();
532-
533-
let rgx = match ROLE_REGEX_RE.get() {
534-
Some(r) => r,
535-
None => return None,
536-
};
537-
538-
// Copy / paste from above. If we get one more of these use cases,
539-
// it'll be time to abstract :).
540-
if rgx.is_match(&query) {
541-
let role = query.split("'").collect::<Vec<&str>>()[1];
542-
543-
match role {
544-
"PRIMARY" => Some(Role::Primary),
545-
"REPLICA" => Some(Role::Replica),
546-
_ => return None,
547-
}
548-
} else {
549-
None
550-
}
551-
}
552460
}

src/main.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ extern crate statsd;
2525
extern crate tokio;
2626
extern crate toml;
2727

28-
use regex::Regex;
2928
use tokio::net::TcpListener;
3029
use tokio::signal;
3130

@@ -39,6 +38,7 @@ mod constants;
3938
mod errors;
4039
mod messages;
4140
mod pool;
41+
mod query_router;
4242
mod server;
4343
mod sharding;
4444
mod stats;
@@ -54,12 +54,11 @@ use stats::{Collector, Reporter};
5454
async fn main() {
5555
println!("> Welcome to PgCat! Meow.");
5656

57-
client::SHARDING_REGEX_RE
58-
.set(Regex::new(client::SHARDING_REGEX).unwrap())
59-
.unwrap();
60-
client::ROLE_REGEX_RE
61-
.set(Regex::new(client::ROLE_REGEX).unwrap())
62-
.unwrap();
57+
// Prepare regexes
58+
if !query_router::QueryRouter::setup() {
59+
println!("> Could not setup query router.");
60+
return;
61+
}
6362

6463
let config = match config::parse("pgcat.toml").await {
6564
Ok(config) => config,

src/messages.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,50 @@ pub async fn custom_protocol_response_ok(
185185
write_all_half(stream, res).await
186186
}
187187

188+
/// Send a custom error message to the client.
189+
/// Tell the client we are ready for the next query and no rollback is necessary.
190+
/// Docs on error codes: https://www.postgresql.org/docs/12/errcodes-appendix.html
191+
pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Result<(), Error> {
192+
let mut error = BytesMut::new();
193+
194+
// Error level
195+
error.put_u8(b'S');
196+
error.put_slice(&b"FATAL\0"[..]);
197+
198+
// Error level (non-translatable)
199+
error.put_u8(b'V');
200+
error.put_slice(&b"FATAL\0"[..]);
201+
202+
// Error code: not sure how much this matters.
203+
error.put_u8(b'C');
204+
error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A.
205+
206+
// The short error message.
207+
error.put_u8(b'M');
208+
error.put_slice(&format!("{}\0", message).as_bytes());
209+
210+
// No more fields follow.
211+
error.put_u8(0);
212+
213+
// Ready for query, no rollback needed (I = idle).
214+
let mut ready_for_query = BytesMut::new();
215+
216+
ready_for_query.put_u8(b'Z');
217+
ready_for_query.put_i32(5);
218+
ready_for_query.put_u8(b'I');
219+
220+
// Compose the two message reply.
221+
let mut res = BytesMut::with_capacity(error.len() + ready_for_query.len() + 5);
222+
223+
res.put_u8(b'E');
224+
res.put_i32(error.len() as i32 + 4);
225+
226+
res.put(error);
227+
res.put(ready_for_query);
228+
229+
Ok(write_all_half(stream, res).await?)
230+
}
231+
188232
/// Write all data in the buffer to the TcpStream.
189233
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
190234
match stream.write_all(&buf).await {

src/pool.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ impl ConnectionPool {
121121

122122
for shard in 0..self.shards() {
123123
for _ in 0..self.replicas(shard) {
124-
let connection = match self.get(Some(shard), None).await {
124+
let connection = match self.get(shard, None).await {
125125
Ok(conn) => conn,
126126
Err(err) => {
127127
println!("> Shard {} down or misconfigured.", shard);
@@ -149,18 +149,13 @@ impl ConnectionPool {
149149
/// Get a connection from the pool.
150150
pub async fn get(
151151
&mut self,
152-
shard: Option<usize>,
152+
shard: usize,
153153
role: Option<Role>,
154154
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
155155
// Set this to false to gain ~3-4% speed.
156156
let with_health_check = true;
157157
let now = Instant::now();
158158

159-
let shard = match shard {
160-
Some(shard) => shard,
161-
None => 0, // TODO: pick a shard at random
162-
};
163-
164159
// We are waiting for a server now.
165160
self.stats.client_waiting();
166161

@@ -208,11 +203,8 @@ impl ConnectionPool {
208203
// as per request.
209204
match role {
210205
Some(role) => {
211-
// If the client wants a specific role,
212-
// we'll do our best to pick it, but if we only
213-
// have one server in the cluster, it's probably only a primary
214-
// (or only a replica), so the client will just get what we have.
215-
if address.role != role && addresses.len() > 1 {
206+
// Find the specific role the client wants in the pool.
207+
if address.role != role {
216208
continue;
217209
}
218210
}

0 commit comments

Comments
 (0)