From c878c0ce8cb778a7e7b379cb021e08c8d7e8fef7 Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Fri, 6 Sep 2024 14:00:13 -0500 Subject: [PATCH 1/4] add client connection limits --- src/auth_passthrough.rs | 1 + src/client.rs | 39 ++++++++++++- src/config.rs | 5 ++ src/stats/pool.rs | 4 ++ tests/python/conftest.py | 0 tests/python/test_client_limit.py | 95 +++++++++++++++++++++++++++++++ tests/python/test_pgcat.py | 1 - tests/python/utils.py | 28 ++++++++- 8 files changed, 170 insertions(+), 3 deletions(-) delete mode 100644 tests/python/conftest.py create mode 100644 tests/python/test_client_limit.py diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs index 159847ed..5ef59d26 100644 --- a/src/auth_passthrough.rs +++ b/src/auth_passthrough.rs @@ -74,6 +74,7 @@ impl AuthPassthrough { password: Some(self.password.clone()), server_username: None, server_password: None, + max_clients: None, pool_size: 1, statement_timeout: 0, pool_mode: None, diff --git a/src/client.rs b/src/client.rs index 23392b73..19c0255a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -18,9 +18,10 @@ use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, use crate::constants::*; use crate::messages::*; use crate::plugins::PluginOutput; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; +use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolIdentifier}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; +use crate::stats::pool::PoolStats; use crate::stats::{ClientStats, ServerStats}; use crate::tls::Tls; @@ -570,6 +571,42 @@ where } }; + let config = get_config(); + if config.general.max_clients.is_some() { + let max_clients = config.general.max_clients.unwrap(); + let mut pgcat_client_connections = 0; + let all_pool_stats = PoolStats::construct_pool_lookup(); + for (_, pool_stats) in all_pool_stats.into_iter() { + pgcat_client_connections += pool_stats.total_client_connection(); + } + if max_clients <= pgcat_client_connections { + error_response_terminal(&mut write, "pgcat client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + + if pool.settings.user.max_clients.is_some() { + let max_clients = pool.settings.user.max_clients.unwrap(); + // get pool stats + let pool_stats = PoolStats::construct_pool_lookup() + .get(&PoolIdentifier::new(pool_name, username)) + .cloned() + .unwrap(); + let current_pool_connections = pool_stats.total_client_connection(); + if max_clients <= current_pool_connections { + error_response_terminal(&mut write, "pool client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + // Obtain the hash to compare, we give preference to that written in cleartext in config // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained // when the pool was created. If there is no hash there, we try to fetch it one more time. diff --git a/src/config.rs b/src/config.rs index c7aaf4c3..c7043b39 100644 --- a/src/config.rs +++ b/src/config.rs @@ -212,6 +212,7 @@ pub struct User { pub server_password: Option, pub pool_size: u32, pub min_pool_size: Option, + pub max_clients: Option, pub pool_mode: Option, pub server_lifetime: Option, #[serde(default)] // 0 @@ -229,6 +230,7 @@ impl Default for User { server_password: None, pool_size: 15, min_pool_size: None, + max_clients: None, statement_timeout: 0, pool_mode: None, server_lifetime: None, @@ -289,6 +291,8 @@ pub struct General { #[serde(default)] // False pub log_client_disconnections: bool, + pub max_clients: Option, + #[serde(default)] // False pub dns_cache_enabled: bool, @@ -437,6 +441,7 @@ impl Default for General { tcp_keepalives_count: Self::default_tcp_keepalives_count(), tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), tcp_user_timeout: Self::default_tcp_user_timeout(), + max_clients: None, log_client_connections: false, log_client_disconnections: false, dns_cache_enabled: false, diff --git a/src/stats/pool.rs b/src/stats/pool.rs index b5c6ff5b..66a79579 100644 --- a/src/stats/pool.rs +++ b/src/stats/pool.rs @@ -111,6 +111,10 @@ impl PoolStats { ] } + pub fn total_client_connection(&self) -> u64 { + self.cl_idle + self.cl_active + self.cl_waiting + self.cl_cancel_req + } + pub fn generate_row(&self) -> Vec { vec![ self.identifier.db.clone(), diff --git a/tests/python/conftest.py b/tests/python/conftest.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/python/test_client_limit.py b/tests/python/test_client_limit.py new file mode 100644 index 00000000..2cb33a97 --- /dev/null +++ b/tests/python/test_client_limit.py @@ -0,0 +1,95 @@ +import psycopg2 +import pytest +import time +import utils + + +def setup_module(module) -> None: + pgcat_conf = """ + [general] + + host = "0.0.0.0" + port = 6433 + admin_username = "pgcat" + admin_password = "pgcat" + max_clients = 4 + + [pools.pgml.users.0] + username = "limited_user" + password = "limited_user" + server_username = "sharding_user" + pool_size = 10 + max_clients = 2 + min_pool_size = 1 + pool_mode = "transaction" + + [pools.pgml.users.1] + username = "unlimited_user" + password = "unlimited_user" + server_username = "sharding_user" + pool_size = 10 + min_pool_size = 1 + pool_mode = "transaction" + + [pools.pgml.shards.0] + servers = [ + ["127.0.0.1", 5432, "primary"] + ] + database = "some_db" + """ + module.pgcat_process = utils.pgcat_start_with_config(pgcat_conf) + time.sleep(2) + + +def teardown_module(module) -> None: + module.pgcat_process.terminate() + module.pgcat_process.wait() + time.sleep(2) + + +def test_pgcat_limit() -> None: + # Open 4 connections (2 for each user) + conns = [ + utils.connect_db_generic( + username=user, password=user, host='127.0.0.1', database='pgml', port=6433) + for user in ['unlimited_user', 'unlimited_user'] * 2] + + # Verify 5th connection does not work for both users + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + # Close 4th connection + (conn, curr) = conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + +def test_user_limit() -> None: + # Open 2 connections for limited User + limited_conns = [ + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + for _ in range(2)] + + # Validate 3rd connection does not work + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + + # Validate unlimited user can still open connection + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + # Close 2nd Connection + (conn, curr) = limited_conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) diff --git a/tests/python/test_pgcat.py b/tests/python/test_pgcat.py index dc2f11e5..d01d7374 100644 --- a/tests/python/test_pgcat.py +++ b/tests/python/test_pgcat.py @@ -1,4 +1,3 @@ -import os import signal import time diff --git a/tests/python/utils.py b/tests/python/utils.py index 5c49bce9..ad04e157 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -2,6 +2,8 @@ import os import psutil import signal +import subprocess +import tempfile import time import psycopg2 @@ -9,16 +11,39 @@ PGCAT_HOST = "127.0.0.1" PGCAT_PORT = "6432" + def pgcat_start(): pg_cat_send_signal(signal.SIGTERM) os.system("./target/debug/pgcat .circleci/pgcat.toml &") time.sleep(2) +def pgcat_start_with_config(config: str): + config_file = tempfile.NamedTemporaryFile(delete=False) + config_file.write(str.encode(config)) + config_file.close() + process = subprocess.Popen(["./target/debug/pgcat", config_file.name], shell=False) + time.sleep(2) + return process + + +def connect_db_generic( + username: str, password: str, host: str, port: int, + database: str, autocommit: bool = True) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: + + conn = psycopg2.connect( + f"postgres://{username}:{password}@{host}:{port}/{database}?application_name=testing_pgcat", + connect_timeout=2, + ) + conn.autocommit = autocommit + cur = conn.cursor() + return (conn, cur) + + def pg_cat_send_signal(signal: signal.Signals): try: for proc in psutil.process_iter(["pid", "name"]): - if "pgcat" == proc.name(): + if proc.name() == "pgcat": os.kill(proc.pid, signal) except Exception as e: # The process can be gone when we send this signal @@ -28,6 +53,7 @@ def pg_cat_send_signal(signal: signal.Signals): # Returns 0 if pgcat process exists time.sleep(2) if not os.system('pgrep pgcat'): + breakpoint() raise Exception("pgcat not closed after SIGTERM") From d103556f97cbdf046653252fb7e9e7359ab6a1bc Mon Sep 17 00:00:00 2001 From: Andrew Jackson <46945903+AndrewJackson2020@users.noreply.github.com> Date: Tue, 10 Sep 2024 09:38:20 -0500 Subject: [PATCH 2/4] Update utils.py Removed breakpoint --- tests/python/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/utils.py b/tests/python/utils.py index 84fddc9f..83ad20ed 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -81,7 +81,6 @@ def pg_cat_send_signal(signal: signal.Signals): # Returns 0 if pgcat process exists time.sleep(2) if not os.system('pgrep pgcat'): - breakpoint() raise Exception("pgcat not closed after SIGTERM") From 2507deac20356355fd6450c8553fe1c0204c1e08 Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Tue, 10 Sep 2024 10:25:04 -0500 Subject: [PATCH 3/4] rebase --- tests/python/test_auth.py | 22 +++++++++++++-------- tests/python/test_client_limit.py | 2 +- tests/python/test_pgcat.py | 16 +++++++-------- tests/python/utils.py | 33 ++++++++++++------------------- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/tests/python/test_auth.py b/tests/python/test_auth.py index bd943429..649e077b 100644 --- a/tests/python/test_auth.py +++ b/tests/python/test_auth.py @@ -1,10 +1,11 @@ +import time import utils -import signal + class TestTrustAuth: @classmethod def setup_method(cls): - config= """ + config = """ [general] host = "0.0.0.0" port = 6432 @@ -26,11 +27,13 @@ def setup_method(cls): ] database = "shard0" """ - utils.pgcat_generic_start(config) + cls.process = utils.pgcat_generic_start_from_string(config) @classmethod - def teardown_method(self): - utils.pg_cat_send_signal(signal.SIGTERM) + def teardown_method(cls): + cls.process.kill() + cls.process.wait() + time.sleep(2) def test_admin_trust_auth(self): conn, cur = utils.connect_db_trust(admin=True) @@ -46,14 +49,17 @@ def test_normal_trust_auth(self): print(res) utils.cleanup_conn(conn, cur) + class TestMD5Auth: @classmethod def setup_method(cls): - utils.pgcat_start() + cls.process = utils.pgcat_start_from_file_path("./.circleci/pgcat.toml") @classmethod - def teardown_method(self): - utils.pg_cat_send_signal(signal.SIGTERM) + def teardown_method(cls): + cls.process.kill() + cls.process.wait() + time.sleep(2) def test_normal_db_access(self): conn, cur = utils.connect_db(autocommit=False) diff --git a/tests/python/test_client_limit.py b/tests/python/test_client_limit.py index 2cb33a97..2f26df0b 100644 --- a/tests/python/test_client_limit.py +++ b/tests/python/test_client_limit.py @@ -37,7 +37,7 @@ def setup_module(module) -> None: ] database = "some_db" """ - module.pgcat_process = utils.pgcat_start_with_config(pgcat_conf) + module.pgcat_process = utils.pgcat_generic_start_from_string(pgcat_conf) time.sleep(2) diff --git a/tests/python/test_pgcat.py b/tests/python/test_pgcat.py index 773715d4..14c27801 100644 --- a/tests/python/test_pgcat.py +++ b/tests/python/test_pgcat.py @@ -14,7 +14,7 @@ def test_shutdown_logic(): # NO ACTIVE QUERIES SIGINT HANDLING # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and send query (not in transaction) conn, cur = utils.connect_db() @@ -43,7 +43,7 @@ def test_shutdown_logic(): # NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -74,7 +74,7 @@ def test_shutdown_logic(): # HANDLE TRANSACTION WITH SIGINT # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -100,7 +100,7 @@ def test_shutdown_logic(): # HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -129,7 +129,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -161,7 +161,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -186,7 +186,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -213,7 +213,7 @@ def test_shutdown_logic(): # HANDLE SHUTDOWN TIMEOUT WITH SIGINT # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached conn, cur = utils.connect_db() diff --git a/tests/python/utils.py b/tests/python/utils.py index 83ad20ed..c27b912d 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -1,10 +1,10 @@ +from typing import Tuple import os +import pathlib import signal import subprocess import tempfile import time -from typing import Tuple -import tempfile import psutil import psycopg2 @@ -13,17 +13,9 @@ PGCAT_PORT = "6432" -def _pgcat_start(config_path: str): +def pgcat_start_from_file_path(config_path: str): pg_cat_send_signal(signal.SIGTERM) - os.system(f"./target/debug/pgcat {config_path} &") - time.sleep(2) - - -def pgcat_start_with_config(config: str): - config_file = tempfile.NamedTemporaryFile(delete=False) - config_file.write(str.encode(config)) - config_file.close() - process = subprocess.Popen(["./target/debug/pgcat", config_file.name], shell=False) + process = subprocess.Popen(["./target/debug/pgcat", config_path], shell=False) time.sleep(2) return process @@ -39,17 +31,18 @@ def connect_db_generic( conn.autocommit = autocommit cur = conn.cursor() return (conn, cur) - - -def pgcat_start(): - _pgcat_start(config_path='.circleci/pgcat.toml') -def pgcat_generic_start(config: str): +def pgcat_start_background(): + os.system("./target/debug/pgcat .circleci/pgcat.toml &") + time.sleep(2) + + +def pgcat_generic_start_from_string(config: str): tmp = tempfile.NamedTemporaryFile() - with open(tmp.name, 'w') as f: + with pathlib.Path(tmp.name).open("w") as f: f.write(config) - _pgcat_start(config_path=tmp.name) + return pgcat_start_from_file_path(config_path=tmp.name) def glauth_send_signal(signal: signal.Signals): @@ -71,7 +64,7 @@ def glauth_send_signal(signal: signal.Signals): def pg_cat_send_signal(signal: signal.Signals): try: for proc in psutil.process_iter(["pid", "name"]): - if proc.name() == "pgcat": + if "pgcat" == proc.name(): os.kill(proc.pid, signal) except Exception as e: # The process can be gone when we send this signal From 48d252b2248dec701a73430edfa95eec90151d14 Mon Sep 17 00:00:00 2001 From: Andrew Jackson Date: Tue, 10 Sep 2024 13:21:32 -0500 Subject: [PATCH 4/4] added admin user to limits --- src/client.rs | 60 +++++++++++++++++++++---------- src/config.rs | 2 ++ tests/python/test_client_limit.py | 24 ++++++++++++- 3 files changed, 66 insertions(+), 20 deletions(-) diff --git a/src/client.rs b/src/client.rs index 02712bca..3956c430 100644 --- a/src/client.rs +++ b/src/client.rs @@ -24,7 +24,7 @@ use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolIdentifier}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; use crate::stats::pool::PoolStats; -use crate::stats::{ClientStats, ServerStats}; +use crate::stats::{get_client_stats, ClientStats, ServerStats}; use crate::tls::Tls; use tokio_rustls::server::TlsStream; @@ -460,6 +460,24 @@ where let client_identifier = ClientIdentifier::new(application_name, username, pool_name); + let config = get_config(); + if config.general.max_clients.is_some() { + let max_clients = config.general.max_clients.unwrap(); + let mut pgcat_client_connections = 0; + let client_map = get_client_stats(); + for (_, _) in client_map.into_iter() { + pgcat_client_connections += 1; + } + if max_clients <= pgcat_client_connections { + error_response_terminal(&mut write, "pgcat client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + let admin = ["pgcat", "pgbouncer"] .iter() .filter(|db| *db == pool_name) @@ -489,6 +507,28 @@ where // Authenticate admin user. let (transaction_mode, mut server_parameters) = if admin { let config = get_config(); + + if config.general.admin_max_clients.is_some() { + let max_clients = config.general.admin_max_clients.unwrap(); + let mut pgcat_client_connections = 0; + let client_map = get_client_stats(); + for (_, client) in client_map.into_iter() { + if client.username() == config.general.admin_username + && ["pgcat", "pgbouncer"].contains(&client.pool_name().as_str()) + { + pgcat_client_connections += 1; + } + } + if max_clients <= pgcat_client_connections { + error_response_terminal(&mut write, "pgcat client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + // TODO: Add SASL support. // Perform MD5 authentication. match config.general.admin_auth_type { @@ -577,24 +617,6 @@ where } }; - let config = get_config(); - if config.general.max_clients.is_some() { - let max_clients = config.general.max_clients.unwrap(); - let mut pgcat_client_connections = 0; - let all_pool_stats = PoolStats::construct_pool_lookup(); - for (_, pool_stats) in all_pool_stats.into_iter() { - pgcat_client_connections += pool_stats.total_client_connection(); - } - if max_clients <= pgcat_client_connections { - error_response_terminal(&mut write, "pgcat client connection limit exceeded") - .await?; - return Err(Error::ClientGeneralError( - "Max Clients exceeded".into(), - client_identifier, - )); - } - }; - if pool.settings.user.max_clients.is_some() { let max_clients = pool.settings.user.max_clients.unwrap(); // get pool stats diff --git a/src/config.rs b/src/config.rs index 306e3660..b696c9cf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -345,6 +345,7 @@ pub struct General { pub admin_username: String, pub admin_password: String, + pub admin_max_clients: Option, #[serde(default = "General::default_admin_auth_type")] pub admin_auth_type: AuthType, @@ -476,6 +477,7 @@ impl Default for General { verify_server_certificate: false, admin_username: String::from("admin"), admin_password: String::from("admin"), + admin_max_clients: None, admin_auth_type: AuthType::MD5, validate_config: true, auth_query: None, diff --git a/tests/python/test_client_limit.py b/tests/python/test_client_limit.py index 2f26df0b..ca4ce8b3 100644 --- a/tests/python/test_client_limit.py +++ b/tests/python/test_client_limit.py @@ -10,9 +10,11 @@ def setup_module(module) -> None: host = "0.0.0.0" port = 6433 + max_clients = 4 + admin_username = "pgcat" admin_password = "pgcat" - max_clients = 4 + admin_max_clients = 2 [pools.pgml.users.0] username = "limited_user" @@ -71,6 +73,26 @@ def test_pgcat_limit() -> None: username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) +def test_admin_user_limit(): + # Open 2 connections for limited User + limited_conns = [ + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + for _ in range(2)] + + # Validate 3rd connection does not work + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + + # Close 2nd Connection + (conn, curr) = limited_conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + + def test_user_limit() -> None: # Open 2 connections for limited User limited_conns = [