Skip to content

Commit 2cc6a09

Browse files
authored
Add Manual host banning to PgCat (#340)
Sometimes we want an admin to be able to ban a host for some time, to route traffic away from that host for reasons like partial outages, replication lag, and scheduled maintenance. We can achieve this today using a configuration update, but a quicker approach is to send a control command to PgCat that bans the replica for a specified duration. This command does not change the current banning rules: primaries cannot be banned, and when all replicas are banned, all replicas are unbanned.
1 parent 8a0da10 commit 2cc6a09

File tree

5 files changed

+300
-13
lines changed

5 files changed

+300
-13
lines changed

src/admin.rs

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
use crate::config::Role;
2+
use crate::pool::BanReason;
13
/// Admin database.
24
use bytes::{Buf, BufMut, BytesMut};
35
use log::{error, info, trace};
46
use nix::sys::signal::{self, Signal};
57
use nix::unistd::Pid;
68
use std::collections::HashMap;
9+
use std::time::{SystemTime, UNIX_EPOCH};
710
use tokio::time::Instant;
811

912
use crate::config::{get_config, reload_config, VERSION};
@@ -53,6 +56,14 @@ where
5356
let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();
5457

5558
match query_parts[0].to_ascii_uppercase().as_str() {
59+
"BAN" => {
60+
trace!("BAN");
61+
ban(stream, query_parts).await
62+
}
63+
"UNBAN" => {
64+
trace!("UNBAN");
65+
unban(stream, query_parts).await
66+
}
5667
"RELOAD" => {
5768
trace!("RELOAD");
5869
reload(stream, client_server_map).await
@@ -74,6 +85,10 @@ where
7485
shutdown(stream).await
7586
}
7687
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
88+
"BANS" => {
89+
trace!("SHOW BANS");
90+
show_bans(stream).await
91+
}
7792
"CONFIG" => {
7893
trace!("SHOW CONFIG");
7994
show_config(stream).await
@@ -350,6 +365,163 @@ where
350365
custom_protocol_response_ok(stream, "SET").await
351366
}
352367

368+
/// Bans a host from being used
369+
async fn ban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
370+
where
371+
T: tokio::io::AsyncWrite + std::marker::Unpin,
372+
{
373+
let host = match tokens.get(1) {
374+
Some(host) => host,
375+
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
376+
};
377+
378+
let duration_seconds = match tokens.get(2) {
379+
Some(duration_seconds) => match duration_seconds.parse::<i64>() {
380+
Ok(duration_seconds) => duration_seconds,
381+
Err(_) => {
382+
return error_response(stream, "duration_seconds must be an integer").await;
383+
}
384+
},
385+
None => return error_response(stream, "usage: BAN hostname duration_seconds").await,
386+
};
387+
388+
if duration_seconds <= 0 {
389+
return error_response(stream, "duration_seconds must be >= 0").await;
390+
}
391+
392+
let columns = vec![
393+
("db", DataType::Text),
394+
("user", DataType::Text),
395+
("role", DataType::Text),
396+
("host", DataType::Text),
397+
];
398+
let mut res = BytesMut::new();
399+
res.put(row_description(&columns));
400+
401+
for (id, pool) in get_all_pools().iter() {
402+
for address in pool.get_addresses_from_host(host) {
403+
if !pool.is_banned(&address) {
404+
pool.ban(&address, BanReason::AdminBan(duration_seconds), -1);
405+
res.put(data_row(&vec![
406+
id.db.clone(),
407+
id.user.clone(),
408+
address.role.to_string(),
409+
address.host,
410+
]));
411+
}
412+
}
413+
}
414+
415+
res.put(command_complete("BAN"));
416+
417+
// ReadyForQuery
418+
res.put_u8(b'Z');
419+
res.put_i32(5);
420+
res.put_u8(b'I');
421+
422+
write_all_half(stream, &res).await
423+
}
424+
425+
/// Clear a host for use
426+
async fn unban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error>
427+
where
428+
T: tokio::io::AsyncWrite + std::marker::Unpin,
429+
{
430+
let host = match tokens.get(1) {
431+
Some(host) => host,
432+
None => return error_response(stream, "UNBAN command requires a hostname to unban").await,
433+
};
434+
435+
let columns = vec![
436+
("db", DataType::Text),
437+
("user", DataType::Text),
438+
("role", DataType::Text),
439+
("host", DataType::Text),
440+
];
441+
let mut res = BytesMut::new();
442+
res.put(row_description(&columns));
443+
444+
for (id, pool) in get_all_pools().iter() {
445+
for address in pool.get_addresses_from_host(host) {
446+
if pool.is_banned(&address) {
447+
pool.unban(&address);
448+
res.put(data_row(&vec![
449+
id.db.clone(),
450+
id.user.clone(),
451+
address.role.to_string(),
452+
address.host,
453+
]));
454+
}
455+
}
456+
}
457+
458+
res.put(command_complete("UNBAN"));
459+
460+
// ReadyForQuery
461+
res.put_u8(b'Z');
462+
res.put_i32(5);
463+
res.put_u8(b'I');
464+
465+
write_all_half(stream, &res).await
466+
}
467+
468+
/// Shows all the bans
469+
async fn show_bans<T>(stream: &mut T) -> Result<(), Error>
470+
where
471+
T: tokio::io::AsyncWrite + std::marker::Unpin,
472+
{
473+
let columns = vec![
474+
("db", DataType::Text),
475+
("user", DataType::Text),
476+
("role", DataType::Text),
477+
("host", DataType::Text),
478+
("reason", DataType::Text),
479+
("ban_time", DataType::Text),
480+
("ban_duration_seconds", DataType::Text),
481+
("ban_remaining_seconds", DataType::Text),
482+
];
483+
let mut res = BytesMut::new();
484+
res.put(row_description(&columns));
485+
486+
// The block should be pretty quick so we cache the time outside
487+
let now = SystemTime::now()
488+
.duration_since(UNIX_EPOCH)
489+
.expect("Time went backwards")
490+
.as_secs() as i64;
491+
492+
for (id, pool) in get_all_pools().iter() {
493+
for (address, (ban_reason, ban_time)) in pool.get_bans().iter() {
494+
let ban_duration = match ban_reason {
495+
BanReason::AdminBan(duration) => *duration,
496+
_ => pool.settings.ban_time,
497+
};
498+
let remaining = ban_duration - (now - ban_time.timestamp());
499+
if remaining <= 0 {
500+
continue;
501+
}
502+
res.put(data_row(&vec![
503+
id.db.clone(),
504+
id.user.clone(),
505+
address.role.to_string(),
506+
address.host.clone(),
507+
format!("{:?}", ban_reason),
508+
ban_time.to_string(),
509+
ban_duration.to_string(),
510+
remaining.to_string(),
511+
]));
512+
}
513+
}
514+
515+
res.put(command_complete("SHOW BANS"));
516+
517+
// ReadyForQuery
518+
res.put_u8(b'Z');
519+
res.put_i32(5);
520+
res.put_u8(b'I');
521+
522+
write_all_half(stream, &res).await
523+
}
524+
353525
/// Reload the configuration file without restarting the process.
354526
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
355527
where

src/client.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
use crate::errors::Error;
2+
use crate::pool::BanReason;
13
/// Handle clients by pretending to be a PostgreSQL server.
24
use bytes::{Buf, BufMut, BytesMut};
35
use log::{debug, error, info, trace, warn};
6+
47
use std::collections::HashMap;
58
use std::time::Instant;
69
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
@@ -11,7 +14,7 @@ use tokio::sync::mpsc::Sender;
1114
use crate::admin::{generate_server_info_for_admin, handle_admin};
1215
use crate::config::{get_config, Address, PoolMode};
1316
use crate::constants::*;
14-
use crate::errors::Error;
17+
1518
use crate::messages::*;
1619
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
1720
use crate::query_router::{Command, QueryRouter};
@@ -1111,7 +1114,7 @@ where
11111114
match server.send(message).await {
11121115
Ok(_) => Ok(()),
11131116
Err(err) => {
1114-
pool.ban(address, self.process_id);
1117+
pool.ban(address, BanReason::MessageSendFailed, self.process_id);
11151118
Err(err)
11161119
}
11171120
}
@@ -1133,7 +1136,7 @@ where
11331136
Ok(result) => match result {
11341137
Ok(message) => Ok(message),
11351138
Err(err) => {
1136-
pool.ban(address, self.process_id);
1139+
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
11371140
error_response_terminal(
11381141
&mut self.write,
11391142
&format!("error receiving data from server: {:?}", err),
@@ -1148,7 +1151,7 @@ where
11481151
address, pool.settings.user.username
11491152
);
11501153
server.mark_bad();
1151-
pool.ban(address, self.process_id);
1154+
pool.ban(address, BanReason::StatementTimeout, self.process_id);
11521155
error_response_terminal(&mut self.write, "pool statement timeout").await?;
11531156
Err(Error::StatementTimeout)
11541157
}
@@ -1157,7 +1160,7 @@ where
11571160
match server.recv().await {
11581161
Ok(message) => Ok(message),
11591162
Err(err) => {
1160-
pool.ban(address, self.process_id);
1163+
pool.ban(address, BanReason::MessageReceiveFailed, self.process_id);
11611164
error_response_terminal(
11621165
&mut self.write,
11631166
&format!("error receiving data from server: {:?}", err),

src/errors.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub enum Error {
66
SocketError(String),
77
ClientBadStartup,
88
ProtocolSyncError(String),
9+
BadQuery(String),
910
ServerError,
1011
BadConfig,
1112
AllServersDown,

src/pool.rs

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub type SecretKey = i32;
2929
pub type ServerHost = String;
3030
pub type ServerPort = u16;
3131

32-
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
32+
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
3333
pub type ClientServerMap =
3434
Arc<Mutex<HashMap<(ProcessId, SecretKey), (ProcessId, SecretKey, ServerHost, ServerPort)>>>;
3535
pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
@@ -38,6 +38,17 @@ pub type PoolMap = HashMap<PoolIdentifier, ConnectionPool>;
3838
/// The pool is recreated dynamically when the config is reloaded.
3939
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
4040

41+
/// Reasons for banning a server.
///
/// The reason is stored in the ban list next to the time of the ban. For
/// [`BanReason::AdminBan`] it also carries the ban duration; every other
/// reason uses the pool's configured `ban_time`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum BanReason {
    /// The server failed its health check.
    FailedHealthCheck,
    /// Sending a client message to the server failed.
    MessageSendFailed,
    /// Receiving a message from the server failed.
    MessageReceiveFailed,
    /// A connection to the server could not be checked out of the pool.
    FailedCheckout,
    /// A statement running on the server exceeded the pool statement timeout.
    StatementTimeout,
    /// An admin banned the server manually; the payload is the ban duration
    /// in seconds.
    AdminBan(i64),
}
51+
4152
/// An identifier for a PgCat pool,
4253
/// a database visible to clients.
4354
#[derive(Hash, Debug, Clone, PartialEq, Eq)]
@@ -489,7 +500,7 @@ impl ConnectionPool {
489500
Ok(conn) => conn,
490501
Err(err) => {
491502
error!("Banning instance {:?}, error: {:?}", address, err);
492-
self.ban(address, client_process_id);
503+
self.ban(address, BanReason::FailedCheckout, client_process_id);
493504
self.stats
494505
.client_checkout_error(client_process_id, address.id);
495506
continue;
@@ -582,14 +593,14 @@ impl ConnectionPool {
582593
// Don't leave a bad connection in the pool.
583594
server.mark_bad();
584595

585-
self.ban(&address, client_process_id);
596+
self.ban(&address, BanReason::FailedHealthCheck, client_process_id);
586597
return false;
587598
}
588599

589600
/// Ban an address (i.e. replica). It no longer will serve
590601
/// traffic for any new transactions. Existing transactions on that replica
591602
/// will finish successfully or error out to the clients.
592-
pub fn ban(&self, address: &Address, client_id: i32) {
603+
pub fn ban(&self, address: &Address, reason: BanReason, client_id: i32) {
593604
// Primary can never be banned
594605
if address.role == Role::Primary {
595606
return;
@@ -599,12 +610,12 @@ impl ConnectionPool {
599610
let mut guard = self.banlist.write();
600611
error!("Banning {:?}", address);
601612
self.stats.client_ban_error(client_id, address.id);
602-
guard[address.shard].insert(address.clone(), now);
613+
guard[address.shard].insert(address.clone(), (reason, now));
603614
}
604615

605616
/// Clear the replica to receive traffic again. Takes effect immediately
606617
/// for all new transactions.
607-
pub fn _unban(&self, address: &Address) {
618+
pub fn unban(&self, address: &Address) {
608619
let mut guard = self.banlist.write();
609620
guard[address.shard].remove(address);
610621
}
@@ -653,9 +664,14 @@ impl ConnectionPool {
653664
// Check if ban time is expired
654665
let read_guard = self.banlist.read();
655666
let exceeded_ban_time = match read_guard[address.shard].get(address) {
656-
Some(timestamp) => {
667+
Some((ban_reason, timestamp)) => {
657668
let now = chrono::offset::Utc::now().naive_utc();
658-
now.timestamp() - timestamp.timestamp() > self.settings.ban_time
669+
match ban_reason {
670+
BanReason::AdminBan(duration) => {
671+
now.timestamp() - timestamp.timestamp() > *duration
672+
}
673+
_ => now.timestamp() - timestamp.timestamp() > self.settings.ban_time,
674+
}
659675
}
660676
None => return true,
661677
};
@@ -679,6 +695,31 @@ impl ConnectionPool {
679695
self.databases.len()
680696
}
681697

698+
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
699+
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
700+
let guard = self.banlist.read();
701+
for banlist in guard.iter() {
702+
for (address, (reason, timestamp)) in banlist.iter() {
703+
bans.push((address.clone(), (reason.clone(), timestamp.clone())));
704+
}
705+
}
706+
return bans;
707+
}
708+
709+
/// Get the address from the host url
710+
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
711+
let mut addresses = Vec::new();
712+
for shard in 0..self.shards() {
713+
for server in 0..self.servers(shard) {
714+
let address = self.address(shard, server);
715+
if address.host == host {
716+
addresses.push(address.clone());
717+
}
718+
}
719+
}
720+
addresses
721+
}
722+
682723
/// Get the number of servers (primary and replicas)
683724
/// configured for a shard.
684725
pub fn servers(&self, shard: usize) -> usize {

0 commit comments

Comments
 (0)