@@ -9,20 +9,20 @@ use crate::server::Server;
9
9
10
10
use std:: collections:: HashMap ;
11
11
use std:: sync:: {
12
- atomic:: { AtomicUsize , Ordering } ,
12
+ // atomic::{AtomicUsize, Ordering},
13
13
Arc , Mutex ,
14
14
} ;
15
15
16
16
// Banlist: bad servers go in here.
17
17
pub type BanList = Arc < Mutex < Vec < HashMap < Address , NaiveDateTime > > > > ;
18
- pub type Counter = Arc < AtomicUsize > ;
18
+ // pub type Counter = Arc<AtomicUsize>;
19
19
pub type ClientServerMap = Arc < Mutex < HashMap < ( i32 , i32 ) , ( i32 , i32 , String , String ) > > > ;
20
20
21
21
#[ derive( Clone , Debug ) ]
22
22
pub struct ConnectionPool {
23
23
databases : Vec < Vec < Pool < ServerPool > > > ,
24
24
addresses : Vec < Vec < Address > > ,
25
- round_robin : Counter ,
25
+ round_robin : usize ,
26
26
banlist : BanList ,
27
27
healthcheck_timeout : u64 ,
28
28
ban_time : i64 ,
@@ -90,10 +90,13 @@ impl ConnectionPool {
90
90
banlist. push ( HashMap :: new ( ) ) ;
91
91
}
92
92
93
+ assert_eq ! ( shards. len( ) , addresses. len( ) ) ;
94
+ let address_len = addresses. len ( ) ;
95
+
93
96
ConnectionPool {
94
97
databases : shards,
95
98
addresses : addresses,
96
- round_robin : Arc :: new ( AtomicUsize :: new ( 0 ) ) ,
99
+ round_robin : rand :: random :: < usize > ( ) % address_len , // Start at a random replica
97
100
banlist : Arc :: new ( Mutex :: new ( banlist) ) ,
98
101
healthcheck_timeout : config. general . healthcheck_timeout ,
99
102
ban_time : config. general . ban_time ,
@@ -103,7 +106,7 @@ impl ConnectionPool {
103
106
104
107
/// Get a connection from the pool.
105
108
pub async fn get (
106
- & self ,
109
+ & mut self ,
107
110
shard : Option < usize > ,
108
111
role : Option < Role > ,
109
112
) -> Result < ( PooledConnection < ' _ , ServerPool > , Address ) , Error > {
@@ -115,40 +118,48 @@ impl ConnectionPool {
115
118
None => 0 , // TODO: pick a shard at random
116
119
} ;
117
120
118
- let mut allowed_attempts = match role {
119
- // Primary-specific queries get one attempt, if the primary is down,
120
- // nothing we should do about it I think. It's dangerous to retry
121
- // write queries.
122
- Some ( Role :: Primary ) => {
123
- // Make sure we have a primary in the pool configured.
124
- let primary_present = self . addresses [ shard]
121
+ let addresses = & self . addresses [ shard] ;
122
+
123
+ // Make sure if a specific role is requested, it's available in the pool.
124
+ match role {
125
+ Some ( role) => {
126
+ let role_count = addresses
125
127
. iter ( )
126
128
. filter ( |& db| db. role == Role :: Primary )
127
129
. count ( ) ;
128
130
129
- // TODO: return this error to the client, so people don't have to look in
130
- // the logs to figure out what happened.
131
- if primary_present == 0 {
132
- println ! ( ">> Error: Primary requested but none are configured." ) ;
131
+ if role_count == 0 {
132
+ println ! (
133
+ ">> Error: Role '{:?}' requested, but none are configured." ,
134
+ role
135
+ ) ;
136
+
133
137
return Err ( Error :: AllServersDown ) ;
134
138
}
135
-
136
- // Primary gets one attempt.
137
- 1
138
139
}
139
140
141
+ // Any role should be present.
142
+ _ => ( ) ,
143
+ } ;
144
+
145
+ let mut allowed_attempts = match role {
146
+ // Primary-specific queries get one attempt, if the primary is down,
147
+ // nothing we should do about it I think. It's dangerous to retry
148
+ // write queries.
149
+ Some ( Role :: Primary ) => 1 ,
150
+
140
151
// Replicas get to try as many times as there are replicas
141
152
// and connections in the pool.
142
153
_ => self . databases [ shard] . len ( ) * self . pool_size as usize ,
143
154
} ;
144
155
145
156
while allowed_attempts > 0 {
146
- // TODO: think about making this local, so multiple clients
147
- // don't compete for the same round-robin integer.
148
- // Especially since we're going to be skipping (see role selection below) .
149
- let index =
150
- self . round_robin . fetch_add ( 1 , Ordering :: SeqCst ) % self . databases [ shard ] . len ( ) ;
151
- let address = self . addresses [ shard ] [ index] . clone ( ) ;
157
+ // Round-robin each client's queries.
158
+ // If a client only sends one query and then disconnects, it doesn't matter
159
+ // which replica it'll go to.
160
+ self . round_robin += 1 ;
161
+ let index = self . round_robin % addresses . len ( ) ;
162
+ let address = & addresses[ index] ;
152
163
153
164
// Make sure you're getting a primary or a replica
154
165
// as per request.
@@ -158,14 +169,14 @@ impl ConnectionPool {
158
169
// we'll do our best to pick it, but if we only
159
170
// have one server in the cluster, it's probably only a primary
160
171
// (or only a replica), so the client will just get what we have.
161
- if address. role != role && self . addresses [ shard ] . len ( ) > 1 {
172
+ if address. role != role && addresses. len ( ) > 1 {
162
173
continue ;
163
174
}
164
175
}
165
176
None => ( ) ,
166
177
} ;
167
178
168
- if self . is_banned ( & address, shard, role) {
179
+ if self . is_banned ( address, shard, role) {
169
180
continue ;
170
181
}
171
182
@@ -177,13 +188,13 @@ impl ConnectionPool {
177
188
Ok ( conn) => conn,
178
189
Err ( err) => {
179
190
println ! ( ">> Banning replica {}, error: {:?}" , index, err) ;
180
- self . ban ( & address, shard) ;
191
+ self . ban ( address, shard) ;
181
192
continue ;
182
193
}
183
194
} ;
184
195
185
196
if !with_health_check {
186
- return Ok ( ( conn, address) ) ;
197
+ return Ok ( ( conn, address. clone ( ) ) ) ;
187
198
}
188
199
189
200
// // Check if this server is alive with a health check
@@ -197,7 +208,7 @@ impl ConnectionPool {
197
208
{
198
209
// Check if health check succeeded
199
210
Ok ( res) => match res {
200
- Ok ( _) => return Ok ( ( conn, address) ) ,
211
+ Ok ( _) => return Ok ( ( conn, address. clone ( ) ) ) ,
201
212
Err ( _) => {
202
213
println ! (
203
214
">> Banning replica {} because of failed health check" ,
@@ -206,7 +217,7 @@ impl ConnectionPool {
206
217
// Don't leave a bad connection in the pool.
207
218
server. mark_bad ( ) ;
208
219
209
- self . ban ( & address, shard) ;
220
+ self . ban ( address, shard) ;
210
221
continue ;
211
222
}
212
223
} ,
@@ -219,7 +230,7 @@ impl ConnectionPool {
219
230
// Don't leave a bad connection in the pool.
220
231
server. mark_bad ( ) ;
221
232
222
- self . ban ( & address, shard) ;
233
+ self . ban ( address, shard) ;
223
234
continue ;
224
235
}
225
236
}
0 commit comments