diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs index 53ef93d4..d0d86f06 100644 --- a/src/auth_passthrough.rs +++ b/src/auth_passthrough.rs @@ -76,6 +76,7 @@ impl AuthPassthrough { password: Some(self.password.clone()), server_username: None, server_password: None, + max_clients: None, pool_size: 1, statement_timeout: 0, pool_mode: None, diff --git a/src/client.rs b/src/client.rs index 405d72be..3956c430 100644 --- a/src/client.rs +++ b/src/client.rs @@ -20,10 +20,11 @@ use crate::config::{ use crate::constants::*; use crate::messages::*; use crate::plugins::PluginOutput; -use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; +use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolIdentifier}; use crate::query_router::{Command, QueryRouter}; use crate::server::{Server, ServerParameters}; -use crate::stats::{ClientStats, ServerStats}; +use crate::stats::pool::PoolStats; +use crate::stats::{get_client_stats, ClientStats, ServerStats}; use crate::tls::Tls; use tokio_rustls::server::TlsStream; @@ -459,6 +460,24 @@ where let client_identifier = ClientIdentifier::new(application_name, username, pool_name); + let config = get_config(); + if config.general.max_clients.is_some() { + let max_clients = config.general.max_clients.unwrap(); + let mut pgcat_client_connections = 0; + let client_map = get_client_stats(); + for (_, _) in client_map.into_iter() { + pgcat_client_connections += 1; + } + if max_clients <= pgcat_client_connections { + error_response_terminal(&mut write, "pgcat client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + let admin = ["pgcat", "pgbouncer"] .iter() .filter(|db| *db == pool_name) @@ -488,6 +507,28 @@ where // Authenticate admin user. let (transaction_mode, mut server_parameters) = if admin { let config = get_config(); + + if config.general.admin_max_clients.is_some() { + let max_clients = config.general.admin_max_clients.unwrap(); + let mut pgcat_client_connections = 0; + let client_map = get_client_stats(); + for (_, client) in client_map.into_iter() { + if client.username() == config.general.admin_username + && ["pgcat", "pgbouncer"].contains(&client.pool_name().as_str()) + { + pgcat_client_connections += 1; + } + } + if max_clients <= pgcat_client_connections { + error_response_terminal(&mut write, "pgcat client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + // TODO: Add SASL support. // Perform MD5 authentication. match config.general.admin_auth_type { @@ -576,6 +617,24 @@ where } }; + if pool.settings.user.max_clients.is_some() { + let max_clients = pool.settings.user.max_clients.unwrap(); + // get pool stats + let pool_stats = PoolStats::construct_pool_lookup() + .get(&PoolIdentifier::new(pool_name, username)) + .cloned() + .unwrap(); + let current_pool_connections = pool_stats.total_client_connection(); + if max_clients <= current_pool_connections { + error_response_terminal(&mut write, "pool client connection limit exceeded") + .await?; + return Err(Error::ClientGeneralError( + "Max Clients exceeded".into(), + client_identifier, + )); + } + }; + // Obtain the hash to compare, we give preference to that written in cleartext in config // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained // when the pool was created. If there is no hash there, we try to fetch it one more time. diff --git a/src/config.rs b/src/config.rs index b0d98fb5..b696c9cf 100644 --- a/src/config.rs +++ b/src/config.rs @@ -215,6 +215,7 @@ pub struct User { pub server_password: Option, pub pool_size: u32, pub min_pool_size: Option, + pub max_clients: Option, pub pool_mode: Option, pub server_lifetime: Option, #[serde(default)] // 0 @@ -233,6 +234,7 @@ impl Default for User { server_password: None, pool_size: 15, min_pool_size: None, + max_clients: None, statement_timeout: 0, pool_mode: None, server_lifetime: None, @@ -297,6 +299,8 @@ pub struct General { #[serde(default)] // False pub log_client_disconnections: bool, + pub max_clients: Option, + #[serde(default)] // False pub dns_cache_enabled: bool, @@ -341,6 +345,7 @@ pub struct General { pub admin_username: String, pub admin_password: String, + pub admin_max_clients: Option, #[serde(default = "General::default_admin_auth_type")] pub admin_auth_type: AuthType, @@ -452,6 +457,7 @@ impl Default for General { tcp_keepalives_count: Self::default_tcp_keepalives_count(), tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), tcp_user_timeout: Self::default_tcp_user_timeout(), + max_clients: None, log_client_connections: false, log_client_disconnections: false, dns_cache_enabled: false, @@ -471,6 +477,7 @@ impl Default for General { verify_server_certificate: false, admin_username: String::from("admin"), admin_password: String::from("admin"), + admin_max_clients: None, admin_auth_type: AuthType::MD5, validate_config: true, auth_query: None, diff --git a/src/stats/pool.rs b/src/stats/pool.rs index b5c6ff5b..66a79579 100644 --- a/src/stats/pool.rs +++ b/src/stats/pool.rs @@ -111,6 +111,10 @@ impl PoolStats { ] } + pub fn total_client_connection(&self) -> u64 { + self.cl_idle + self.cl_active + self.cl_waiting + self.cl_cancel_req + } + pub fn generate_row(&self) -> Vec { vec![ self.identifier.db.clone(), diff --git a/tests/python/test_auth.py b/tests/python/test_auth.py index bd943429..649e077b 100644 --- a/tests/python/test_auth.py +++ b/tests/python/test_auth.py @@ -1,10 +1,11 @@ +import time import utils -import signal + class TestTrustAuth: @classmethod def setup_method(cls): - config= """ + config = """ [general] host = "0.0.0.0" port = 6432 @@ -26,11 +27,13 @@ def setup_method(cls): ] database = "shard0" """ - utils.pgcat_generic_start(config) + cls.process = utils.pgcat_generic_start_from_string(config) @classmethod - def teardown_method(self): - utils.pg_cat_send_signal(signal.SIGTERM) + def teardown_method(cls): + cls.process.kill() + cls.process.wait() + time.sleep(2) def test_admin_trust_auth(self): conn, cur = utils.connect_db_trust(admin=True) @@ -46,14 +49,17 @@ def test_normal_trust_auth(self): print(res) utils.cleanup_conn(conn, cur) + class TestMD5Auth: @classmethod def setup_method(cls): - utils.pgcat_start() + cls.process = utils.pgcat_start_from_file_path("./.circleci/pgcat.toml") @classmethod - def teardown_method(self): - utils.pg_cat_send_signal(signal.SIGTERM) + def teardown_method(cls): + cls.process.kill() + cls.process.wait() + time.sleep(2) def test_normal_db_access(self): conn, cur = utils.connect_db(autocommit=False) diff --git a/tests/python/test_client_limit.py b/tests/python/test_client_limit.py new file mode 100644 index 00000000..ca4ce8b3 --- /dev/null +++ b/tests/python/test_client_limit.py @@ -0,0 +1,117 @@ +import psycopg2 +import pytest +import time +import utils + + +def setup_module(module) -> None: + pgcat_conf = """ + [general] + + host = "0.0.0.0" + port = 6433 + max_clients = 4 + + admin_username = "pgcat" + admin_password = "pgcat" + admin_max_clients = 2 + + [pools.pgml.users.0] + username = "limited_user" + password = "limited_user" + server_username = "sharding_user" + pool_size = 10 + max_clients = 2 + min_pool_size = 1 + pool_mode = "transaction" + + [pools.pgml.users.1] + username = "unlimited_user" + password = "unlimited_user" + server_username = "sharding_user" + pool_size = 10 + min_pool_size = 1 + pool_mode = "transaction" + + [pools.pgml.shards.0] + servers = [ + ["127.0.0.1", 5432, "primary"] + ] + database = "some_db" + """ + module.pgcat_process = utils.pgcat_generic_start_from_string(pgcat_conf) + time.sleep(2) + + +def teardown_module(module) -> None: + module.pgcat_process.terminate() + module.pgcat_process.wait() + time.sleep(2) + + +def test_pgcat_limit() -> None: + # Open 4 connections (2 for each user) + conns = [ + utils.connect_db_generic( + username=user, password=user, host='127.0.0.1', database='pgml', port=6433) + for user in ['unlimited_user', 'unlimited_user'] * 2] + + # Verify 5th connection does not work for both users + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + # Close 4th connection + (conn, curr) = conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + +def test_admin_user_limit(): + # Open 2 connections for limited User + limited_conns = [ + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + for _ in range(2)] + + # Validate 3rd connection does not work + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + + # Close 2nd Connection + (conn, curr) = limited_conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433) + + +def test_user_limit() -> None: + # Open 2 connections for limited User + limited_conns = [ + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + for _ in range(2)] + + # Validate 3rd connection does not work + with pytest.raises(psycopg2.OperationalError): + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) + + # Validate unlimited user can still open connection + utils.connect_db_generic( + username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433) + + # Close 2nd Connection + (conn, curr) = limited_conns.pop(-1) + utils.cleanup_conn(conn, curr) + + utils.connect_db_generic( + username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433) diff --git a/tests/python/test_pgcat.py b/tests/python/test_pgcat.py index 773715d4..14c27801 100644 --- a/tests/python/test_pgcat.py +++ b/tests/python/test_pgcat.py @@ -14,7 +14,7 @@ def test_shutdown_logic(): # NO ACTIVE QUERIES SIGINT HANDLING # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and send query (not in transaction) conn, cur = utils.connect_db() @@ -43,7 +43,7 @@ def test_shutdown_logic(): # NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -74,7 +74,7 @@ def test_shutdown_logic(): # HANDLE TRANSACTION WITH SIGINT # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -100,7 +100,7 @@ def test_shutdown_logic(): # HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction conn, cur = utils.connect_db() @@ -129,7 +129,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -161,7 +161,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -186,7 +186,7 @@ def test_shutdown_logic(): # - - - - - - - - - - - - - - - - - - # ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction transaction_conn, transaction_cur = utils.connect_db() @@ -213,7 +213,7 @@ def test_shutdown_logic(): # HANDLE SHUTDOWN TIMEOUT WITH SIGINT # Start pgcat - utils.pgcat_start() + utils.pgcat_start_background() # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached conn, cur = utils.connect_db() diff --git a/tests/python/utils.py b/tests/python/utils.py index 9a1c6de9..c27b912d 100644 --- a/tests/python/utils.py +++ b/tests/python/utils.py @@ -1,8 +1,10 @@ +from typing import Tuple import os +import pathlib import signal -import time -from typing import Tuple +import subprocess import tempfile +import time import psutil import psycopg2 @@ -11,21 +13,36 @@ PGCAT_PORT = "6432" -def _pgcat_start(config_path: str): +def pgcat_start_from_file_path(config_path: str): pg_cat_send_signal(signal.SIGTERM) - os.system(f"./target/debug/pgcat {config_path} &") + process = subprocess.Popen(["./target/debug/pgcat", config_path], shell=False) time.sleep(2) + return process -def pgcat_start(): - _pgcat_start(config_path='.circleci/pgcat.toml') +def connect_db_generic( + username: str, password: str, host: str, port: int, + database: str, autocommit: bool = True) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]: + conn = psycopg2.connect( + f"postgres://{username}:{password}@{host}:{port}/{database}?application_name=testing_pgcat", + connect_timeout=2, + ) + conn.autocommit = autocommit + cur = conn.cursor() + return (conn, cur) -def pgcat_generic_start(config: str): + +def pgcat_start_background(): + os.system("./target/debug/pgcat .circleci/pgcat.toml &") + time.sleep(2) + + +def pgcat_generic_start_from_string(config: str): tmp = tempfile.NamedTemporaryFile() - with open(tmp.name, 'w') as f: + with pathlib.Path(tmp.name).open("w") as f: f.write(config) - _pgcat_start(config_path=tmp.name) + return pgcat_start_from_file_path(config_path=tmp.name) def glauth_send_signal(signal: signal.Signals): @@ -42,7 +59,7 @@ def glauth_send_signal(signal: signal.Signals): time.sleep(2) if not os.system('pgrep glauth'): raise Exception("glauth not closed after SIGTERM") - + def pg_cat_send_signal(signal: signal.Signals): try: