Skip to content

Commit c73652c

Browse files
committed
Unban address in the background when ban_time exceeded
1 parent 622e210 commit c73652c

File tree

4 files changed

+124
-8
lines changed

4 files changed

+124
-8
lines changed

src/ban_service.rs

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,20 @@ use log::{debug, error, warn};
44
use parking_lot::RwLock;
55
use std::sync::Arc;
66

7-
use std::collections::HashMap;
7+
use std::collections::{HashMap, HashSet};
88

99
use crate::config::{Address, Role};
10+
use crate::pool::ServerPool;
11+
use bb8::Pool;
12+
1013
pub type BanList = Arc<RwLock<Vec<HashMap<Address, (BanReason, NaiveDateTime)>>>>;
1114
#[derive(Debug, Clone, Default)]
1215
pub struct BanService {
16+
/// A set of addresses that are being unbanned.
17+
/// This is used to prevent multiple unbanning tasks
18+
/// and to prevent interfiere with normal traffic, when unbanning due to ban_time expired.
19+
addresses_being_unbanned: Arc<RwLock<HashSet<Address>>>,
20+
1321
/// List of banned addresses (see above)
1422
/// that should not be queried.
1523
banlist: BanList,
@@ -40,9 +48,14 @@ pub enum UnbanReason {
4048
}
4149

4250
impl BanService {
43-
pub fn new(replica_to_primary_failover_enabled: bool, ban_time: i64) -> Self {
51+
pub fn new(
52+
number_of_shards: usize,
53+
replica_to_primary_failover_enabled: bool,
54+
ban_time: i64,
55+
) -> Self {
4456
BanService {
45-
banlist: Arc::new(RwLock::new(vec![HashMap::new()])),
57+
addresses_being_unbanned: Arc::new(RwLock::new(HashSet::new())),
58+
banlist: Arc::new(RwLock::new(vec![HashMap::new(); number_of_shards])),
4659
replica_to_primary_failover_enabled,
4760
ban_time,
4861
}
@@ -89,6 +102,52 @@ impl BanService {
89102
guard[address.shard].remove(address);
90103
}
91104

105+
/// Starts a tokio::task that will try to connect to the server and
106+
/// if successful, it will unban the server.
107+
pub fn schedule_unban(
108+
&self,
109+
address: &Address,
110+
server_pool: &Pool<ServerPool>,
111+
client_stats: ClientStats,
112+
) {
113+
// If the address is in the backlog, then it means that
114+
// a healthcheck is being performed and we should do nothing.
115+
if self.addresses_being_unbanned.read().contains(address) {
116+
return;
117+
} else {
118+
self.addresses_being_unbanned
119+
.write()
120+
.insert(address.clone());
121+
}
122+
123+
tokio::spawn({
124+
let server_pool = server_pool.clone();
125+
let address = address.clone();
126+
let banlist = self.banlist.clone();
127+
let addresses_being_unbanned = self.addresses_being_unbanned.clone();
128+
129+
async move {
130+
match server_pool.get().await {
131+
Ok(_) => {
132+
address.reset_error_count();
133+
warn!("Unbanning {:?}", address);
134+
let mut guard = banlist.write();
135+
guard[address.shard].remove(&address);
136+
}
137+
Err(err) => {
138+
error!(
139+
"Connection checkout error while trying to unban instance {:?}, error: {:?}",
140+
address, err
141+
);
142+
address.stats.error();
143+
client_stats.checkout_error();
144+
}
145+
};
146+
addresses_being_unbanned.write().remove(&address);
147+
}
148+
});
149+
}
150+
92151
/// Check if address is banned
93152
/// true if banned, false otherwise
94153
pub fn is_banned(&self, address: &Address) -> bool {

src/pool.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,11 +489,12 @@ impl ConnectionPool {
489489
pool_name, user.username
490490
);
491491
}
492-
492+
let number_of_shards = shards.len();
493493
let pool = ConnectionPool {
494494
databases: Arc::new(shards),
495495
addresses: Arc::new(addresses),
496496
ban_service: Arc::new(BanService::new(
497+
number_of_shards,
497498
pool_config.replica_to_primary_failover_enabled,
498499
config.general.ban_time,
499500
)),
@@ -748,8 +749,12 @@ impl ConnectionPool {
748749
force_healthcheck = true;
749750
}
750751
Some(UnbanReason::BanTimeExceeded) => {
751-
self.ban_service.unban(address);
752-
force_healthcheck = true;
752+
self.ban_service.schedule_unban(
753+
address,
754+
&self.databases[address.shard][address.address_index],
755+
client_stats.clone(),
756+
);
757+
continue;
753758
}
754759
Some(UnbanReason::PrimaryBanned) | Some(UnbanReason::NotBanned) => {
755760
force_healthcheck = true;

src/query_router.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,7 +1453,6 @@ mod test {
14531453
automatic_sharding_key: Some(String::from("test.id")),
14541454
healthcheck_delay: PoolSettings::default().healthcheck_delay,
14551455
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
1456-
ban_time: PoolSettings::default().ban_time,
14571456
sharding_key_regex: None,
14581457
shard_id_regex: None,
14591458
default_shard: crate::config::DefaultShard::Shard(0),
@@ -1532,7 +1531,6 @@ mod test {
15321531
automatic_sharding_key: None,
15331532
healthcheck_delay: PoolSettings::default().healthcheck_delay,
15341533
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
1535-
ban_time: PoolSettings::default().ban_time,
15361534
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
15371535
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
15381536
default_shard: crate::config::DefaultShard::Shard(0),

tests/ruby/load_balancing_spec.rb

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@
253253

254254
expect(failed_count).to(eq(number_of_replicas))
255255
end
256+
256257
context("when banned replicas are tested for availability because they expired the ban time") do
257258
let(:ban_time) { 2 }
258259
it "should be done in the background without interfering with traffic" do
@@ -266,6 +267,16 @@
266267
failed_count = 0
267268
number_of_replicas = processes[:replicas].length
268269

270+
# We need to allow pgcat to open connections to replicas
271+
(number_of_replicas + 10).times do |n|
272+
response = conn.async_exec(select_server_port)
273+
expect(response[0]["port"].to_i).not_to(eq(primary_port))
274+
rescue
275+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
276+
failed_count += 1
277+
end
278+
expect(failed_count).to(eq(0))
279+
269280
# Take down all replicas
270281
processes[:replicas].each(&:take_down)
271282

@@ -295,6 +306,49 @@
295306
expect(response[0]["port"].to_i).to(eq(primary_port))
296307
expect(failed_count).to(eq(0))
297308
end
309+
310+
it "should unban replicas if they become available after the ban time" do
311+
select_server_port = "SELECT setting AS port FROM pg_settings WHERE name = 'port';"
312+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
313+
failed_count = 0
314+
primary_port = processes.primary.original_port;
315+
316+
number_of_replicas = processes[:replicas].length
317+
318+
# We need to allow pgcat to open connections to replicas
319+
(number_of_replicas + 10).times do |n|
320+
response = conn.async_exec(select_server_port)
321+
expect(response[0]["port"].to_i).not_to(eq(primary_port))
322+
rescue
323+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
324+
failed_count += 1
325+
end
326+
expect(failed_count).to(eq(0))
327+
328+
# Take down all replicas
329+
processes[:replicas].each(&:take_down)
330+
331+
(number_of_replicas).times do |n|
332+
conn.async_exec("SELECT 1 + 2")
333+
rescue
334+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
335+
failed_count += 1
336+
end
337+
expect(failed_count).to(eq(number_of_replicas))
338+
failed_count = 0
339+
340+
processes[:replicas].each(&:reset)
341+
sleep(ban_time + 1)
342+
response = nil
343+
number_of_replicas.times do
344+
response = conn.async_exec("SELECT 1 + 2")
345+
rescue
346+
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
347+
failed_count += 1
348+
end
349+
expect(response[0]["port"].to_i).not_to(eq(primary_port))
350+
expect(failed_count).to(eq(0))
351+
end
298352
end
299353
end
300354
end

0 commit comments

Comments
 (0)