Skip to content

Commit 526b9eb

Browse files
authored
Pass real server info to the client (#10)
1 parent ab8573c commit 526b9eb

File tree

5 files changed

+56
-9
lines changed

5 files changed

+56
-9
lines changed

src/client.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ impl Client {
6363
client_server_map: ClientServerMap,
6464
transaction_mode: bool,
6565
default_server_role: Option<Role>,
66+
server_info: BytesMut,
6667
) -> Result<Client, Error> {
6768
loop {
6869
// Could be StartupMessage or SSLRequest
@@ -102,7 +103,7 @@ impl Client {
102103
let secret_key: i32 = rand::random();
103104

104105
auth_ok(&mut stream).await?;
105-
server_parameters(&mut stream).await?;
106+
write_all(&mut stream, server_info).await?;
106107
backend_key_data(&mut stream, process_id, secret_key).await?;
107108
ready_for_query(&mut stream).await?;
108109

src/config.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
9090
let mut dup_check = HashSet::new();
9191
let mut primary_count = 0;
9292

93+
if shard.1.servers.len() == 0 {
94+
println!("> Shard {} has no servers configured", shard.0);
95+
return Err(Error::BadConfig);
96+
}
97+
9398
for server in &shard.1.servers {
9499
dup_check.insert(server);
95100

src/main.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ async fn main() {
8686
);
8787
println!("> Connection timeout: {}ms", config.general.connect_timeout);
8888

89-
let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
89+
let mut pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
9090
let transaction_mode = config.general.pool_mode == "transaction";
9191
let default_server_role = match config.query_router.default_role.as_ref() {
9292
"any" => None,
@@ -98,11 +98,20 @@ async fn main() {
9898
}
9999
};
100100

101+
let server_info = match pool.validate().await {
102+
Ok(info) => info,
103+
Err(err) => {
104+
println!("> Could not validate connection pool: {:?}", err);
105+
return;
106+
}
107+
};
108+
101109
println!("> Waiting for clients...");
102110

103111
loop {
104112
let pool = pool.clone();
105113
let client_server_map = client_server_map.clone();
114+
let server_info = server_info.clone();
106115

107116
let (socket, addr) = match listener.accept().await {
108117
Ok((socket, addr)) => (socket, addr),
@@ -124,6 +133,7 @@ async fn main() {
124133
client_server_map,
125134
transaction_mode,
126135
default_server_role,
136+
server_info,
127137
)
128138
.await
129139
{

src/messages.rs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@ use crate::errors::Error;
88

99
// This is a funny one. `psql` parses this to figure out which
1010
// queries to send when using shortcuts, e.g. \d+.
11-
//
12-
// TODO: Actually get the version from the server itself.
13-
//
14-
const SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
11+
// No longer used. Keeping it here until I'm sure we don't need it again.
12+
const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
1513

1614
/// Tell the client that authentication handshake completed successfully.
1715
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
@@ -27,12 +25,12 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
2725
/// Send server parameters to the client. This will tell the client
2826
/// what server version and what's the encoding we're using.
2927
//
30-
// TODO: Forward these from the server instead of hardcoding.
28+
// No longer used. Keeping it here until I'm sure we don't need it again.
3129
//
32-
pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
30+
pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
3331
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
3432
let server_version =
35-
BytesMut::from(&format!("server_version\0{}\0", SERVER_VESION).as_bytes()[..]);
33+
BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]);
3634

3735
// Client encoding
3836
let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here

src/pool.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
/// Pooling and failover and banlist.
22
use async_trait::async_trait;
33
use bb8::{ManageConnection, Pool, PooledConnection};
4+
use bytes::BytesMut;
45
use chrono::naive::NaiveDateTime;
56

67
use crate::config::{Address, Config, Role, User};
@@ -105,6 +106,38 @@ impl ConnectionPool {
105106
}
106107
}
107108

109+
/// Connect to all shards and grab server information.
110+
/// Return server information we will pass to the clients
111+
/// when they connect.
112+
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
113+
let mut server_infos = Vec::new();
114+
115+
for shard in 0..self.shards() {
116+
// TODO: query all primary and replicas in the shard configuration.
117+
let connection = match self.get(Some(shard), None).await {
118+
Ok(conn) => conn,
119+
Err(err) => {
120+
println!("> Shard {} down or misconfigured.", shard);
121+
return Err(err);
122+
}
123+
};
124+
125+
let mut proxy = connection.0;
126+
let _address = connection.1;
127+
let server = &mut *proxy;
128+
129+
server_infos.push(server.server_info());
130+
}
131+
132+
// TODO: compare server information to make sure
133+
// all shards are running identical configurations.
134+
if server_infos.len() == 0 {
135+
return Err(Error::AllServersDown);
136+
}
137+
138+
Ok(server_infos[0].clone())
139+
}
140+
108141
/// Get a connection from the pool.
109142
pub async fn get(
110143
&mut self,

0 commit comments

Comments
 (0)