From ab7ac16974a39bc634acd99c6386369a945639f0 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 22 Apr 2023 07:40:21 -0700 Subject: [PATCH 1/3] reqs --- .circleci/run_tests.sh | 1 + src/client.rs | 52 ++++++++++++++++++++++++++------ src/pool.rs | 6 +++- src/server.rs | 28 ++++++++++++++--- tests/python/async_test.py | 57 +++++++++++++++++++++++++++++++++++ tests/python/requirements.txt | 11 ++++++- 6 files changed, 139 insertions(+), 16 deletions(-) create mode 100644 tests/python/async_test.py diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 4ba497c3..e44f80f8 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -107,6 +107,7 @@ cd ../.. # pip3 install -r tests/python/requirements.txt python3 tests/python/tests.py || exit 1 +python3 tests/python/async_test.py start_pgcat "info" diff --git a/src/client.rs b/src/client.rs index 5098ec6f..60dc41d7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -932,7 +932,7 @@ where } // Grab a server from the pool. - let connection = match pool + let mut connection = match pool .get(query_router.shard(), query_router.role(), &self.stats) .await { @@ -975,9 +975,8 @@ where } }; - let mut reference = connection.0; + let server = &mut *connection.0; let address = connection.1; - let server = &mut *reference; // Server is assigned to the client in case the client wants to // cancel a query later. @@ -1000,6 +999,7 @@ where // Set application_name. server.set_name(&self.application_name).await?; + server.switch_async(false); let mut initial_message = Some(message); @@ -1019,12 +1019,37 @@ where None => { trace!("Waiting for message inside transaction or in session mode"); - match tokio::time::timeout( - idle_client_timeout_duration, - read_message(&mut self.read), - ) - .await - { + let message = tokio::select! { + message = tokio::time::timeout( + idle_client_timeout_duration, + read_message(&mut self.read), + ) => message, + + server_message = server.recv() => { + debug!("Got async message"); + + let server_message = match server_message { + Ok(message) => message, + Err(err) => { + pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats)); + server.mark_bad(); + return Err(err); + } + }; + + match write_all_half(&mut self.write, &server_message).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; + + continue; + } + }; + + match message { Ok(Ok(message)) => message, Ok(Err(err)) => { // Client disconnected inside a transaction. @@ -1141,9 +1166,14 @@ where // Sync // Frontend (client) is asking for the query result now. - 'S' => { + 'S' | 'H' => { debug!("Sending query to server"); + if code == 'H' { + server.switch_async(true); + debug!("Client requested flush, going async"); + } + self.buffer.put(&message[..]); let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; @@ -1318,6 +1348,8 @@ where } }; + debug!("Wrote to client"); + if !server.is_data_available() { break; } diff --git a/src/pool.rs b/src/pool.rs index 7f8e41c0..abb123ed 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -777,6 +777,7 @@ impl ConnectionPool { self.databases.len() } + /// Retrieve all bans for all servers. pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> { let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new(); let guard = self.banlist.read(); @@ -788,7 +789,7 @@ impl ConnectionPool { return bans; } - /// Get the address from the host url + /// Get the address from the host url. pub fn get_addresses_from_host(&self, host: &str) -> Vec
{ let mut addresses = Vec::new(); for shard in 0..self.shards() { @@ -827,10 +828,13 @@ impl ConnectionPool { &self.addresses[shard][server] } + /// Get server settings retrieved at connection setup. pub fn server_info(&self) -> BytesMut { self.server_info.read().clone() } + /// Calculate how many used connections in the pool + /// for the given server. fn busy_connection_count(&self, address: &Address) -> u32 { let state = self.pool_state(address.shard, address.address_index); let idle = state.idle_connections; diff --git a/src/server.rs b/src/server.rs index 84bed6cc..8c1ab70e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -38,6 +38,7 @@ pub struct Server { /// Our server response buffer. We buffer data before we give it to the client. buffer: BytesMut, + is_async: bool, /// Server information the server sent us over on startup. server_info: BytesMut, @@ -450,6 +451,7 @@ impl Server { read: BufReader::new(read), write, buffer: BytesMut::with_capacity(8196), + is_async: false, server_info, process_id, secret_key, @@ -537,6 +539,16 @@ impl Server { } } + /// Switch to async mode, flushing messages as soon + /// as we receive them without buffering or waiting for "ReadyForQuery". + pub fn switch_async(&mut self, on: bool) { + if on { + self.is_async = true; + } else { + self.is_async = false; + } + } + /// Receive data from the server in response to a client request. /// This method must be called multiple times while `self.is_data_available()` is true /// in order to receive all data the server has to offer. @@ -557,8 +569,6 @@ impl Server { let code = message.get_u8() as char; let _len = message.get_i32(); - trace!("Message: {}", code); - match code { // ReadyForQuery 'Z' => { @@ -632,7 +642,10 @@ impl Server { // DataRow 'D' => { // More data is available after this message, this is not the end of the reply. - self.data_available = true; + // If we're async, flush to client now. + if !self.is_async { + self.data_available = true; + } // Don't flush yet, the more we buffer, the faster this goes...up to a limit. if self.buffer.len() >= 8196 { @@ -645,7 +658,10 @@ impl Server { // CopyOutResponse: copy is starting from the server to the client. 'H' => { - self.data_available = true; + // If we're in async mode, flush now. + if !self.is_async { + self.data_available = true; + } break; } @@ -665,6 +681,10 @@ impl Server { // Keep buffering until ReadyForQuery shows up. _ => (), }; + + if self.is_async { + break; + } } let bytes = self.buffer.clone(); diff --git a/tests/python/async_test.py b/tests/python/async_test.py new file mode 100644 index 00000000..27038ffb --- /dev/null +++ b/tests/python/async_test.py @@ -0,0 +1,57 @@ +import psycopg2 +import asyncio +import asyncpg + + +def regular_main(): + # Connect to the PostgreSQL database + conn = psycopg2.connect( + host="localhost", + database="sharded_db", + user="sharding_user", + password="sharding_user", + port=6432, + ) + + # Open a cursor to perform database operations + cur = conn.cursor() + + # Execute a SQL query + cur.execute("SELECT 1") + + # Fetch the results + rows = cur.fetchall() + + # Print the results + for row in rows: + print(row[0]) + + # Close the cursor and the database connection + cur.close() + conn.close() + + +async def main(): + # Connect to the PostgreSQL database + conn = await asyncpg.connect( + host="localhost", + database="sharded_db", + user="sharding_user", + password="sharding_user", + port=6432, + ) + + # Execute a SQL query + for _ in range(25): + rows = await conn.fetch("SELECT 1") + + # Print the results + for row in rows: + print(row[0]) + + # Close the database connection + await conn.close() + + +regular_main() +asyncio.run(main()) diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index eebd9c90..71c94103 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -1,2 +1,11 @@ +asyncio==3.4.3 +asyncpg==0.27.0 +black==23.3.0 +click==8.1.3 +mypy-extensions==1.0.0 +packaging==23.1 +pathspec==0.11.1 +platformdirs==3.2.0 +psutil==5.9.1 psycopg2==2.9.3 -psutil==5.9.1 \ No newline at end of file +tomli==2.0.1 From 088f1a7dae1fd7472d11544087817e1228901850 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 22 Apr 2023 07:47:19 -0700 Subject: [PATCH 2/3] remove debug msg --- src/client.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 60dc41d7..45455015 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1348,8 +1348,6 @@ where } }; - debug!("Wrote to client"); - if !server.is_data_available() { break; } From fd3623ff139c9669fc9344fe5888e59ff0eebabe Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Sat, 22 Apr 2023 08:02:20 -0700 Subject: [PATCH 3/3] mm --- .circleci/run_tests.sh | 3 +++ src/server.rs | 2 ++ tests/python/async_test.py | 11 +++++++---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index e44f80f8..ccbb71ac 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -107,6 +107,9 @@ cd ../.. # pip3 install -r tests/python/requirements.txt python3 tests/python/tests.py || exit 1 + +start_pgcat "info" + python3 tests/python/async_test.py start_pgcat "info" diff --git a/src/server.rs b/src/server.rs index 8c1ab70e..be731fb5 100644 --- a/src/server.rs +++ b/src/server.rs @@ -569,6 +569,8 @@ impl Server { let code = message.get_u8() as char; let _len = message.get_i32(); + trace!("Message: {}", code); + match code { // ReadyForQuery 'Z' => { diff --git a/tests/python/async_test.py b/tests/python/async_test.py index 27038ffb..34a65e03 100644 --- a/tests/python/async_test.py +++ b/tests/python/async_test.py @@ -2,15 +2,18 @@ import asyncio import asyncpg +PGCAT_HOST = "127.0.0.1" +PGCAT_PORT = "6432" + def regular_main(): # Connect to the PostgreSQL database conn = psycopg2.connect( - host="localhost", + host=PGCAT_HOST, database="sharded_db", user="sharding_user", password="sharding_user", - port=6432, + port=PGCAT_PORT, ) # Open a cursor to perform database operations @@ -34,11 +37,11 @@ def regular_main(): async def main(): # Connect to the PostgreSQL database conn = await asyncpg.connect( - host="localhost", + host=PGCAT_HOST, database="sharded_db", user="sharding_user", password="sharding_user", - port=6432, + port=PGCAT_PORT, ) # Execute a SQL query