Skip to content

Commit 6a31e9c

Browse files
committed
Refactor ban/unban code to use BanService struct
1 parent c527db5 commit 6a31e9c

File tree

6 files changed

+253
-180
lines changed

6 files changed

+253
-180
lines changed

src/admin.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::pool::BanReason;
1+
use crate::ban_service::BanReason;
22
use crate::server::ServerParameters;
33
use crate::stats::pool::PoolStats;
44
use bytes::{Buf, BufMut, BytesMut};
@@ -357,7 +357,7 @@ where
357357
for server in 0..pool.servers(shard) {
358358
let address = pool.address(shard, server);
359359
let pool_state = pool.pool_state(shard, server);
360-
let banned = pool.is_banned(address);
360+
let banned = pool.ban_service.is_banned(address);
361361
let paused = pool.paused();
362362

363363
res.put(data_row(&vec![
@@ -440,8 +440,9 @@ where
440440

441441
for (id, pool) in get_all_pools().iter() {
442442
for address in pool.get_addresses_from_host(host) {
443-
if !pool.is_banned(&address) {
444-
pool.ban(&address, BanReason::AdminBan(duration_seconds), None);
443+
if !pool.ban_service.is_banned(&address) {
444+
pool.ban_service
445+
.ban(&address, BanReason::AdminBan(duration_seconds), None);
445446
res.put(data_row(&vec![
446447
id.db.clone(),
447448
id.user.clone(),
@@ -483,8 +484,8 @@ where
483484

484485
for (id, pool) in get_all_pools().iter() {
485486
for address in pool.get_addresses_from_host(host) {
486-
if pool.is_banned(&address) {
487-
pool.unban(&address);
487+
if pool.ban_service.is_banned(&address) {
488+
pool.ban_service.unban(&address);
488489
res.put(data_row(&vec![
489490
id.db.clone(),
490491
id.user.clone(),
@@ -530,10 +531,10 @@ where
530531
.as_secs() as i64;
531532

532533
for (id, pool) in get_all_pools().iter() {
533-
for (address, (ban_reason, ban_time)) in pool.get_bans().iter() {
534+
for (address, (ban_reason, ban_time)) in pool.ban_service.get_bans().iter() {
534535
let ban_duration = match ban_reason {
535536
BanReason::AdminBan(duration) => *duration,
536-
_ => pool.settings.ban_time,
537+
_ => pool.ban_service.ban_time,
537538
};
538539
let remaining = ban_duration - (now - ban_time.timestamp());
539540
if remaining <= 0 {

src/ban_service.rs

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
use crate::stats::ClientStats;
2+
use chrono::naive::NaiveDateTime;
3+
use log::{debug, error, warn};
4+
use parking_lot::RwLock;
5+
use std::sync::Arc;
6+
7+
use std::collections::HashMap;
8+
9+
use crate::config::{Address, Role};
10+
/// Shared, lock-protected ban state.
///
/// The outer `Vec` is indexed by `Address::shard`; each map holds the banned
/// addresses of that shard together with the reason for the ban and the UTC
/// time at which the ban was recorded.
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
11+
#[derive(Debug, Clone, Default)]
12+
pub struct BanService {
13+
/// List of banned addresses (see above)
14+
/// that should not be queried.
15+
banlist: BanList,
16+
17+
/// Whether or not we should use primary when replicas are unavailable
18+
pub replica_to_primary_failover_enabled: bool,
19+
20+
/// Ban time (in seconds)
21+
pub ban_time: i64,
22+
}
23+
24+
/// Reasons for banning a server.
///
/// The first four variants also bump the address error count when passed to
/// `BanService::ban`; `StatementTimeout` and `AdminBan` do not.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BanReason {
    /// A health check against the server failed.
    FailedHealthCheck,
    /// Sending a message to the server failed.
    MessageSendFailed,
    /// Receiving a message from the server failed.
    MessageReceiveFailed,
    /// Checking out a connection to the server failed.
    FailedCheckout,
    /// A statement ran past the pool's statement timeout.
    StatementTimeout,
    /// Ban requested through the admin interface; the payload is the ban
    /// duration in seconds and overrides the service-wide ban time.
    AdminBan(i64),
}
34+
35+
/// Why `BanService::should_unban` decided that an address may serve traffic
/// again.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnbanReason {
    /// Every replica of the shard is banned, so all of them get unbanned.
    AllReplicasBanned,
    /// The ban has lasted longer than its allowed duration.
    BanTimeExceeded,
    /// The address is a primary; primaries are never supposed to be banned.
    PrimaryBanned,
    /// The address is not in the ban list (e.g. it was unbanned meanwhile).
    NotBanned,
}
41+
42+
impl BanService {
43+
pub fn new(replica_to_primary_failover_enabled: bool, ban_time: i64) -> Self {
44+
BanService {
45+
banlist: Arc::new(RwLock::new(vec![HashMap::new()])),
46+
replica_to_primary_failover_enabled,
47+
ban_time,
48+
}
49+
}
50+
51+
/// Ban an address (i.e. replica). It no longer will serve
52+
/// traffic for any new transactions. Existing transactions on that replica
53+
/// will finish successfully or error out to the clients.
54+
pub fn ban(&self, address: &Address, reason: BanReason, client_info: Option<&ClientStats>) {
55+
// Count the number of errors since the last successful checkout
56+
// This is used to determine if the shard is down
57+
match reason {
58+
BanReason::FailedHealthCheck
59+
| BanReason::FailedCheckout
60+
| BanReason::MessageSendFailed
61+
| BanReason::MessageReceiveFailed => {
62+
address.increment_error_count();
63+
}
64+
_ => (),
65+
};
66+
67+
// Primary can never be banned
68+
if address.role == Role::Primary {
69+
return;
70+
}
71+
72+
let now = chrono::offset::Utc::now().naive_utc();
73+
error!("Banning instance {:?}, reason: {:?}", address, reason);
74+
let mut guard = self.banlist.write();
75+
76+
if let Some(client_info) = client_info {
77+
client_info.ban_error();
78+
address.stats.error();
79+
}
80+
81+
guard[address.shard].insert(address.clone(), (reason, now));
82+
}
83+
84+
/// Clear the replica to receive traffic again. Takes effect immediately
85+
/// for all new transactions.
86+
pub fn unban(&self, address: &Address) {
87+
warn!("Unbanning {:?}", address);
88+
let mut guard = self.banlist.write();
89+
guard[address.shard].remove(address);
90+
}
91+
92+
/// Check if address is banned
93+
/// true if banned, false otherwise
94+
pub fn is_banned(&self, address: &Address) -> bool {
95+
let guard = self.banlist.read();
96+
97+
match guard[address.shard].get(address) {
98+
Some(_) => true,
99+
None => {
100+
debug!("{:?} is ok", address);
101+
false
102+
}
103+
}
104+
}
105+
106+
/// Returns a list of banned replicas
107+
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
108+
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
109+
let guard = self.banlist.read();
110+
for banlist in guard.iter() {
111+
for (address, (reason, timestamp)) in banlist.iter() {
112+
bans.push((address.clone(), (reason.clone(), *timestamp)));
113+
}
114+
}
115+
bans
116+
}
117+
118+
/// Unban all replicas in the shard
119+
/// This is typically used when all replicas are banned and
120+
/// we don't allow sending traffic to primary.
121+
pub fn unban_all_replicas(&self, address: &Address) {
122+
let mut write_guard = self.banlist.write();
123+
warn!("Unbanning all replicas.");
124+
write_guard[address.shard].clear();
125+
}
126+
127+
/// Determines whether a replica should be unban and returns the reason
128+
/// why it should be unbanned.
129+
///
130+
/// UnbanReason:
131+
/// - All replicas are banned (AllReplicasBanned)
132+
/// - Ban time is exceeded (BanTimeExceeded)
133+
/// - Primary is banned (PrimaryBanned, this should never happen)
134+
/// - Not banned (NotBanned, the replica was unbanned while checking the conditions)
135+
///
136+
/// Returns:
137+
/// - Some(UnbanReason), if the replica should be unbanned
138+
/// - None, if the replica should not be unbanned
139+
pub fn should_unban(
140+
&self,
141+
pool_addresses: &[Vec<Address>],
142+
address: &Address,
143+
) -> Option<UnbanReason> {
144+
// If somehow primary ends up being banned we should return true here
145+
if address.role == Role::Primary {
146+
return Some(UnbanReason::PrimaryBanned);
147+
}
148+
149+
// If we have replica to primary failover we should not unban replicas
150+
// as we still have the primary to server traffic.
151+
if !self.replica_to_primary_failover_enabled {
152+
// Check if all replicas are banned, in that case unban all of them
153+
let replicas_available = pool_addresses[address.shard]
154+
.iter()
155+
.filter(|addr| addr.role == Role::Replica)
156+
.count();
157+
158+
debug!("Available targets: {}", replicas_available);
159+
160+
let read_guard = self.banlist.read();
161+
let all_replicas_banned = read_guard[address.shard].len() == replicas_available;
162+
drop(read_guard);
163+
164+
if all_replicas_banned {
165+
return Some(UnbanReason::AllReplicasBanned);
166+
}
167+
}
168+
169+
// Check if ban time is expired
170+
let read_guard = self.banlist.read();
171+
let exceeded_ban_time = match read_guard[address.shard].get(address) {
172+
Some((ban_reason, timestamp)) => {
173+
let now = chrono::offset::Utc::now().naive_utc();
174+
match ban_reason {
175+
BanReason::AdminBan(duration) => {
176+
now.timestamp() - timestamp.timestamp() > *duration
177+
}
178+
_ => now.timestamp() - timestamp.timestamp() > self.ban_time,
179+
}
180+
}
181+
None => return Some(UnbanReason::NotBanned),
182+
};
183+
drop(read_guard);
184+
185+
if exceeded_ban_time {
186+
Some(UnbanReason::BanTimeExceeded)
187+
} else {
188+
debug!("{:?} is banned", address);
189+
None
190+
}
191+
}
192+
}

src/client.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
use crate::ban_service::BanReason;
12
use crate::errors::{ClientIdentifier, Error};
2-
use crate::pool::BanReason;
33
/// Handle clients by pretending to be a PostgreSQL server.
44
use bytes::{Buf, BufMut, BytesMut};
55
use log::{debug, error, info, trace, warn};
@@ -1773,7 +1773,8 @@ where
17731773
// Don't ban for this.
17741774
Error::PreparedStatementError => (),
17751775
_ => {
1776-
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
1776+
pool.ban_service
1777+
.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
17771778
}
17781779
};
17791780

@@ -2014,7 +2015,8 @@ where
20142015
match server.send(message).await {
20152016
Ok(_) => Ok(()),
20162017
Err(err) => {
2017-
pool.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
2018+
pool.ban_service
2019+
.ban(address, BanReason::MessageSendFailed, Some(&self.stats));
20182020
Err(err)
20192021
}
20202022
}
@@ -2041,7 +2043,11 @@ where
20412043
Ok(result) => match result {
20422044
Ok(message) => Ok(message),
20432045
Err(err) => {
2044-
pool.ban(address, BanReason::MessageReceiveFailed, Some(client_stats));
2046+
pool.ban_service.ban(
2047+
address,
2048+
BanReason::MessageReceiveFailed,
2049+
Some(client_stats),
2050+
);
20452051
error_response_terminal(
20462052
&mut self.write,
20472053
&format!("error receiving data from server: {:?}", err),
@@ -2058,7 +2064,8 @@ where
20582064
)
20592065
.as_str(),
20602066
);
2061-
pool.ban(address, BanReason::StatementTimeout, Some(client_stats));
2067+
pool.ban_service
2068+
.ban(address, BanReason::StatementTimeout, Some(client_stats));
20622069
error_response_terminal(&mut self.write, "pool statement timeout").await?;
20632070
Err(Error::StatementTimeout)
20642071
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pub mod admin;
22
pub mod auth_passthrough;
3+
pub mod ban_service;
34
pub mod client;
45
pub mod cmd_args;
56
pub mod config;

0 commit comments

Comments
 (0)