Skip to content

Add Client Connection Limits #802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/auth_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl AuthPassthrough {
password: Some(self.password.clone()),
server_username: None,
server_password: None,
max_clients: None,
pool_size: 1,
statement_timeout: 0,
pool_mode: None,
Expand Down
63 changes: 61 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use crate::config::{
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolIdentifier};
use crate::query_router::{Command, QueryRouter};
use crate::server::{Server, ServerParameters};
use crate::stats::{ClientStats, ServerStats};
use crate::stats::pool::PoolStats;
use crate::stats::{get_client_stats, ClientStats, ServerStats};
use crate::tls::Tls;

use tokio_rustls::server::TlsStream;
Expand Down Expand Up @@ -459,6 +460,24 @@ where

let client_identifier = ClientIdentifier::new(application_name, username, pool_name);

let config = get_config();
if config.general.max_clients.is_some() {
let max_clients = config.general.max_clients.unwrap();
let mut pgcat_client_connections = 0;
let client_map = get_client_stats();
for (_, _) in client_map.into_iter() {
pgcat_client_connections += 1;
}
if max_clients <= pgcat_client_connections {
error_response_terminal(&mut write, "pgcat client connection limit exceeded")
.await?;
return Err(Error::ClientGeneralError(
"Max Clients exceeded".into(),
client_identifier,
));
}
};

let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == pool_name)
Expand Down Expand Up @@ -488,6 +507,28 @@ where
// Authenticate admin user.
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();

if config.general.admin_max_clients.is_some() {
let max_clients = config.general.admin_max_clients.unwrap();
let mut pgcat_client_connections = 0;
let client_map = get_client_stats();
for (_, client) in client_map.into_iter() {
if client.username() == config.general.admin_username
&& ["pgcat", "pgbouncer"].contains(&client.pool_name().as_str())
{
pgcat_client_connections += 1;
}
}
if max_clients <= pgcat_client_connections {
error_response_terminal(&mut write, "pgcat client connection limit exceeded")
.await?;
return Err(Error::ClientGeneralError(
"Max Clients exceeded".into(),
client_identifier,
));
}
};

// TODO: Add SASL support.
// Perform MD5 authentication.
match config.general.admin_auth_type {
Expand Down Expand Up @@ -576,6 +617,24 @@ where
}
};

if pool.settings.user.max_clients.is_some() {
let max_clients = pool.settings.user.max_clients.unwrap();
// get pool stats
let pool_stats = PoolStats::construct_pool_lookup()
.get(&PoolIdentifier::new(pool_name, username))
.cloned()
.unwrap();
let current_pool_connections = pool_stats.total_client_connection();
if max_clients <= current_pool_connections {
error_response_terminal(&mut write, "pool client connection limit exceeded")
.await?;
return Err(Error::ClientGeneralError(
"Max Clients exceeded".into(),
client_identifier,
));
}
};

// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
Expand Down
7 changes: 7 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ pub struct User {
pub server_password: Option<String>,
pub pool_size: u32,
pub min_pool_size: Option<u32>,
pub max_clients: Option<u64>,
pub pool_mode: Option<PoolMode>,
pub server_lifetime: Option<u64>,
#[serde(default)] // 0
Expand All @@ -233,6 +234,7 @@ impl Default for User {
server_password: None,
pool_size: 15,
min_pool_size: None,
max_clients: None,
statement_timeout: 0,
pool_mode: None,
server_lifetime: None,
Expand Down Expand Up @@ -297,6 +299,8 @@ pub struct General {
#[serde(default)] // False
pub log_client_disconnections: bool,

pub max_clients: Option<u64>,

#[serde(default)] // False
pub dns_cache_enabled: bool,

Expand Down Expand Up @@ -341,6 +345,7 @@ pub struct General {

pub admin_username: String,
pub admin_password: String,
pub admin_max_clients: Option<u64>,

#[serde(default = "General::default_admin_auth_type")]
pub admin_auth_type: AuthType,
Expand Down Expand Up @@ -452,6 +457,7 @@ impl Default for General {
tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
max_clients: None,
log_client_connections: false,
log_client_disconnections: false,
dns_cache_enabled: false,
Expand All @@ -471,6 +477,7 @@ impl Default for General {
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
admin_max_clients: None,
admin_auth_type: AuthType::MD5,
validate_config: true,
auth_query: None,
Expand Down
4 changes: 4 additions & 0 deletions src/stats/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ impl PoolStats {
]
}

pub fn total_client_connection(&self) -> u64 {
self.cl_idle + self.cl_active + self.cl_waiting + self.cl_cancel_req
}

pub fn generate_row(&self) -> Vec<String> {
vec![
self.identifier.db.clone(),
Expand Down
22 changes: 14 additions & 8 deletions tests/python/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import time
import utils
import signal


class TestTrustAuth:
@classmethod
def setup_method(cls):
config= """
config = """
[general]
host = "0.0.0.0"
port = 6432
Expand All @@ -26,11 +27,13 @@ def setup_method(cls):
]
database = "shard0"
"""
utils.pgcat_generic_start(config)
cls.process = utils.pgcat_generic_start_from_string(config)

@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def teardown_method(cls):
cls.process.kill()
cls.process.wait()
time.sleep(2)

def test_admin_trust_auth(self):
conn, cur = utils.connect_db_trust(admin=True)
Expand All @@ -46,14 +49,17 @@ def test_normal_trust_auth(self):
print(res)
utils.cleanup_conn(conn, cur)


class TestMD5Auth:
@classmethod
def setup_method(cls):
utils.pgcat_start()
cls.process = utils.pgcat_start_from_file_path("./.circleci/pgcat.toml")

@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def teardown_method(cls):
cls.process.kill()
cls.process.wait()
time.sleep(2)

def test_normal_db_access(self):
conn, cur = utils.connect_db(autocommit=False)
Expand Down
117 changes: 117 additions & 0 deletions tests/python/test_client_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import psycopg2
import pytest
import time
import utils


def setup_module(module) -> None:
pgcat_conf = """
[general]

host = "0.0.0.0"
port = 6433
max_clients = 4

admin_username = "pgcat"
admin_password = "pgcat"
admin_max_clients = 2

[pools.pgml.users.0]
username = "limited_user"
password = "limited_user"
server_username = "sharding_user"
pool_size = 10
max_clients = 2
min_pool_size = 1
pool_mode = "transaction"

[pools.pgml.users.1]
username = "unlimited_user"
password = "unlimited_user"
server_username = "sharding_user"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"

[pools.pgml.shards.0]
servers = [
["127.0.0.1", 5432, "primary"]
]
database = "some_db"
"""
module.pgcat_process = utils.pgcat_generic_start_from_string(pgcat_conf)
time.sleep(2)


def teardown_module(module) -> None:
module.pgcat_process.terminate()
module.pgcat_process.wait()
time.sleep(2)


def test_pgcat_limit() -> None:
# Open 4 connections (2 for each user)
conns = [
utils.connect_db_generic(
username=user, password=user, host='127.0.0.1', database='pgml', port=6433)
for user in ['unlimited_user', 'unlimited_user'] * 2]

# Verify 5th connection does not work for both users
with pytest.raises(psycopg2.OperationalError):
utils.connect_db_generic(
username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433)

with pytest.raises(psycopg2.OperationalError):
utils.connect_db_generic(
username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433)

# Close 4th connection
(conn, curr) = conns.pop(-1)
utils.cleanup_conn(conn, curr)

utils.connect_db_generic(
username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433)


def test_admin_user_limit():
# Open 2 connections for limited User
limited_conns = [
utils.connect_db_generic(
username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433)
for _ in range(2)]

# Validate 3rd connection does not work
with pytest.raises(psycopg2.OperationalError):
utils.connect_db_generic(
username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433)

# Close 2nd Connection
(conn, curr) = limited_conns.pop(-1)
utils.cleanup_conn(conn, curr)

utils.connect_db_generic(
username='pgcat', password='pgcat', host='127.0.0.1', database='pgcat', port=6433)


def test_user_limit() -> None:
# Open 2 connections for limited User
limited_conns = [
utils.connect_db_generic(
username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433)
for _ in range(2)]

# Validate 3rd connection does not work
with pytest.raises(psycopg2.OperationalError):
utils.connect_db_generic(
username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433)

# Validate unlimited user can still open connection
utils.connect_db_generic(
username='unlimited_user', password='unlimited_user', host='127.0.0.1', database='pgml', port=6433)

# Close 2nd Connection
(conn, curr) = limited_conns.pop(-1)
utils.cleanup_conn(conn, curr)

utils.connect_db_generic(
username='limited_user', password='limited_user', host='127.0.0.1', database='pgml', port=6433)
Loading