Skip to content

Commit 22c6f13

Browse files
committed
removed atomic round-robin
1 parent c1476d2 commit 22c6f13

File tree

2 files changed

+44
-33
lines changed

2 files changed

+44
-33
lines changed

src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ impl Client {
153153
}
154154

155155
/// Client loop. We handle all messages between the client and the database here.
156-
pub async fn handle(&mut self, pool: ConnectionPool) -> Result<(), Error> {
156+
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
157157
// Special: cancelling existing running query
158158
if self.cancel_mode {
159159
let (process_id, secret_key, address, port) = {

src/pool.rs

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@ use crate::server::Server;
99

1010
use std::collections::HashMap;
1111
use std::sync::{
12-
atomic::{AtomicUsize, Ordering},
12+
// atomic::{AtomicUsize, Ordering},
1313
Arc, Mutex,
1414
};
1515

1616
// Banlist: bad servers go in here.
1717
pub type BanList = Arc<Mutex<Vec<HashMap<Address, NaiveDateTime>>>>;
18-
pub type Counter = Arc<AtomicUsize>;
18+
// pub type Counter = Arc<AtomicUsize>;
1919
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
2020

2121
#[derive(Clone, Debug)]
2222
pub struct ConnectionPool {
2323
databases: Vec<Vec<Pool<ServerPool>>>,
2424
addresses: Vec<Vec<Address>>,
25-
round_robin: Counter,
25+
round_robin: usize,
2626
banlist: BanList,
2727
healthcheck_timeout: u64,
2828
ban_time: i64,
@@ -90,10 +90,13 @@ impl ConnectionPool {
9090
banlist.push(HashMap::new());
9191
}
9292

93+
assert_eq!(shards.len(), addresses.len());
94+
let address_len = addresses.len();
95+
9396
ConnectionPool {
9497
databases: shards,
9598
addresses: addresses,
96-
round_robin: Arc::new(AtomicUsize::new(0)),
99+
round_robin: rand::random::<usize>() % address_len, // Start at a random replica
97100
banlist: Arc::new(Mutex::new(banlist)),
98101
healthcheck_timeout: config.general.healthcheck_timeout,
99102
ban_time: config.general.ban_time,
@@ -103,7 +106,7 @@ impl ConnectionPool {
103106

104107
/// Get a connection from the pool.
105108
pub async fn get(
106-
&self,
109+
&mut self,
107110
shard: Option<usize>,
108111
role: Option<Role>,
109112
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
@@ -115,40 +118,48 @@ impl ConnectionPool {
115118
None => 0, // TODO: pick a shard at random
116119
};
117120

118-
let mut allowed_attempts = match role {
119-
// Primary-specific queries get one attempt, if the primary is down,
120-
// nothing we should do about it I think. It's dangerous to retry
121-
// write queries.
122-
Some(Role::Primary) => {
123-
// Make sure we have a primary in the pool configured.
124-
let primary_present = self.addresses[shard]
121+
let addresses = &self.addresses[shard];
122+
123+
// Make sure if a specific role is requested, it's available in the pool.
124+
match role {
125+
Some(role) => {
126+
let role_count = addresses
125127
.iter()
126128
.filter(|&db| db.role == Role::Primary)
127129
.count();
128130

129-
// TODO: return this error to the client, so people don't have to look in
130-
// the logs to figure out what happened.
131-
if primary_present == 0 {
132-
println!(">> Error: Primary requested but none are configured.");
131+
if role_count == 0 {
132+
println!(
133+
">> Error: Role '{:?}' requested, but none are configured.",
134+
role
135+
);
136+
133137
return Err(Error::AllServersDown);
134138
}
135-
136-
// Primary gets one attempt.
137-
1
138139
}
139140

141+
// Any role should be present.
142+
_ => (),
143+
};
144+
145+
let mut allowed_attempts = match role {
146+
// Primary-specific queries get one attempt, if the primary is down,
147+
// nothing we should do about it I think. It's dangerous to retry
148+
// write queries.
149+
Some(Role::Primary) => 1,
150+
140151
// Replicas get to try as many times as there are replicas
141152
// and connections in the pool.
142153
_ => self.databases[shard].len() * self.pool_size as usize,
143154
};
144155

145156
while allowed_attempts > 0 {
146-
// TODO: think about making this local, so multiple clients
147-
// don't compete for the same round-robin integer.
148-
// Especially since we're going to be skipping (see role selection below).
149-
let index =
150-
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
151-
let address = self.addresses[shard][index].clone();
157+
// Round-robin each client's queries.
158+
// If a client only sends one query and then disconnects, it doesn't matter
159+
// which replica it'll go to.
160+
self.round_robin += 1;
161+
let index = self.round_robin % addresses.len();
162+
let address = &addresses[index];
152163

153164
// Make sure you're getting a primary or a replica
154165
// as per request.
@@ -158,14 +169,14 @@ impl ConnectionPool {
158169
// we'll do our best to pick it, but if we only
159170
// have one server in the cluster, it's probably only a primary
160171
// (or only a replica), so the client will just get what we have.
161-
if address.role != role && self.addresses[shard].len() > 1 {
172+
if address.role != role && addresses.len() > 1 {
162173
continue;
163174
}
164175
}
165176
None => (),
166177
};
167178

168-
if self.is_banned(&address, shard, role) {
179+
if self.is_banned(address, shard, role) {
169180
continue;
170181
}
171182

@@ -177,13 +188,13 @@ impl ConnectionPool {
177188
Ok(conn) => conn,
178189
Err(err) => {
179190
println!(">> Banning replica {}, error: {:?}", index, err);
180-
self.ban(&address, shard);
191+
self.ban(address, shard);
181192
continue;
182193
}
183194
};
184195

185196
if !with_health_check {
186-
return Ok((conn, address));
197+
return Ok((conn, address.clone()));
187198
}
188199

189200
// // Check if this server is alive with a health check
@@ -197,7 +208,7 @@ impl ConnectionPool {
197208
{
198209
// Check if health check succeeded
199210
Ok(res) => match res {
200-
Ok(_) => return Ok((conn, address)),
211+
Ok(_) => return Ok((conn, address.clone())),
201212
Err(_) => {
202213
println!(
203214
">> Banning replica {} because of failed health check",
@@ -206,7 +217,7 @@ impl ConnectionPool {
206217
// Don't leave a bad connection in the pool.
207218
server.mark_bad();
208219

209-
self.ban(&address, shard);
220+
self.ban(address, shard);
210221
continue;
211222
}
212223
},
@@ -219,7 +230,7 @@ impl ConnectionPool {
219230
// Don't leave a bad connection in the pool.
220231
server.mark_bad();
221232

222-
self.ban(&address, shard);
233+
self.ban(address, shard);
223234
continue;
224235
}
225236
}

0 commit comments

Comments
 (0)