Skip to content

(sync_db_pools) Postgres TLS #2701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions contrib/sync_db_pools/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ diesel_postgres_pool = ["diesel/postgres", "diesel/r2d2"]
diesel_mysql_pool = ["diesel/mysql", "diesel/r2d2"]
sqlite_pool = ["rusqlite", "r2d2_sqlite"]
postgres_pool = ["postgres", "r2d2_postgres"]
postgres_pool_tls = ["postgres_pool", "dep:postgres-native-tls", "dep:native-tls"]
memcache_pool = ["memcache", "r2d2-memcache"]

[dependencies]
Expand All @@ -27,6 +28,8 @@ diesel = { version = "2.0.0", default-features = false, optional = true }

postgres = { version = "0.19", optional = true }
r2d2_postgres = { version = "0.18", optional = true }
postgres-native-tls = { version = "0.5", optional = true }
native-tls = { version = "0.2", optional = true }

rusqlite = { version = "0.29.0", optional = true }
r2d2_sqlite = { version = "0.22.0", optional = true }
Expand Down
42 changes: 40 additions & 2 deletions contrib/sync_db_pools/lib/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::PathBuf;

use rocket::{Rocket, Build};
use rocket::figment::{self, Figment, providers::Serialized};

Expand All @@ -21,7 +23,8 @@ use serde::{Serialize, Deserialize};
/// Config {
/// url: "postgres://root:root@localhost/my_database".into(),
/// pool_size: 10,
/// timeout: 5
/// timeout: 5,
/// tls: None,
/// };
/// ```
///
Expand All @@ -39,6 +42,33 @@ pub struct Config {
/// Defaults to `5`.
// FIXME: Use `time`.
pub timeout: u8,
/// TLS configuration.
pub tls: Option<TlsConfig>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TlsConfig {
/// Allow TLS connections with invalid certificates.
///
/// _Default:_ `false`.
pub accept_invalid_certs: bool,
/// Allow TLS connections with invalid hostnames.
///
/// _Default:_ `false`.
pub accept_invalid_hostnames: bool,
/// Sets the name of a file containing SSL certificate authority (CA) certificate(s).
/// If the file exists, the server’s certificate will be verified to be signed by one of these authorities.
///
/// _Default:_ `None`.
pub ssl_root_cert: Option<PathBuf>,
/// Sets the name of a file containing SSL client certificate.
///
/// _Default:_ `None`.
pub ssl_client_cert: Option<PathBuf>,
/// Sets the name of a file containing SSL client key.
///
/// _Default:_ `None`.
pub ssl_client_key: Option<PathBuf>,
}

impl Config {
Expand Down Expand Up @@ -107,10 +137,18 @@ impl Config {
.map(|workers| workers * 4)
.ok();

let figment = Figment::from(rocket.figment())
let mut figment = Figment::from(rocket.figment())
.focus(&db_key)
.join(Serialized::default("timeout", 5));

if figment.find_value("tls").is_ok() {
figment = figment.join(Serialized::default("tls.accept_invalid_certs", false))
.join(Serialized::default("tls.accept_invalid_hostnames", false))
.join(Serialized::default("tls.ssl_root_cert", None::<PathBuf>))
.join(Serialized::default("tls.ssl_client_cert", None::<PathBuf>))
.join(Serialized::default("tls.ssl_client_key", None::<PathBuf>));
}

match default_pool_size {
Some(pool_size) => figment.join(Serialized::default("pool_size", pool_size)),
None => figment
Expand Down
2 changes: 2 additions & 0 deletions contrib/sync_db_pools/lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ impl<K: 'static, C: Poolable> ConnectionPool<K, C> {
Err(Error::Config(e)) => dberr!("config", db, "{}", e, rocket),
Err(Error::Pool(e)) => dberr!("pool init", db, "{}", e, rocket),
Err(Error::Custom(e)) => dberr!("pool manager", db, "{:?}", e, rocket),
Err(Error::Io(e)) => dberr!("io", db, "{:?}", e, rocket),
Err(Error::Tls(e)) => dberr!("tls", db, "{:?}", e, rocket),
}
}).await
})
Expand Down
10 changes: 10 additions & 0 deletions contrib/sync_db_pools/lib/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ pub enum Error<T> {
Pool(r2d2::Error),
/// An error occurred while extracting a `figment` configuration.
Config(figment::Error),
/// An IO error occurred.
Io(std::io::Error),
/// A TLS error occurred.
Tls(Box<dyn std::error::Error>),
}

impl<T> From<figment::Error> for Error<T> {
Expand All @@ -27,3 +31,9 @@ impl<T> From<r2d2::Error> for Error<T> {
Error::Pool(error)
}
}

impl<T> From<std::io::Error> for Error<T> {
fn from(error: std::io::Error) -> Self {
Error::Io(error)
}
}
267 changes: 264 additions & 3 deletions contrib/sync_db_pools/lib/src/poolable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,277 @@ impl Poolable for diesel::MysqlConnection {
}
}

// TODO: Add a feature to enable TLS in `postgres`; parse a suitable `config`.
#[cfg(feature = "postgres_pool")]
pub mod pg {
use std::pin::Pin;
use std::task::{Context, Poll};
use std::io;

#[derive(Clone)]
pub enum MaybeTlsConnector {
NoTls(postgres::tls::NoTls),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::MakeTlsConnector)
}

impl postgres::tls::MakeTlsConnect<postgres::Socket> for MaybeTlsConnector {
type Stream = MaybeTlsConnector_Stream;
type TlsConnect = MaybeTlsConnector_TlsConnect;
type Error = MaybeTlsConnector_Error;

fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
match self {
MaybeTlsConnector::NoTls(connector) => {
<postgres::tls::NoTls as postgres::tls::MakeTlsConnect<postgres::Socket>>
::make_tls_connect(connector, domain)
.map(Self::TlsConnect::NoTls)
.map_err(Self::Error::NoTls)
},
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector::Tls(connector) => {
<
postgres_native_tls::MakeTlsConnector as
postgres::tls::MakeTlsConnect<postgres::Socket>
>::make_tls_connect(connector, domain)
.map(Self::TlsConnect::Tls)
.map_err(Self::Error::Tls)
},
}
}
}

// --- Stream ---

#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_Stream {
NoTls(postgres::tls::NoTlsStream),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::TlsStream<postgres::Socket>)
}

impl postgres::tls::TlsStream for MaybeTlsConnector_Stream {
fn channel_binding(&self) -> postgres::tls::ChannelBinding {
match self {
MaybeTlsConnector_Stream::NoTls(stream) => stream.channel_binding(),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(stream) => stream.channel_binding(),
}
}
}

impl tokio::io::AsyncRead for MaybeTlsConnector_Stream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>
) -> Poll<Result<(), io::Error>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) =>
Pin::new(stream).poll_read(cx, buf),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) =>
Pin::new(stream).poll_read(cx, buf),
}
}
}

impl tokio::io::AsyncWrite for MaybeTlsConnector_Stream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8]
) -> Poll<io::Result<usize>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) =>
Pin::new(stream).poll_write(cx, buf),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) =>
Pin::new(stream).poll_write(cx, buf),
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) => Pin::new(stream).poll_flush(cx),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {
MaybeTlsConnector_Stream::NoTls(ref mut stream) =>
Pin::new(stream).poll_shutdown(cx),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Stream::Tls(ref mut stream) =>
Pin::new(stream).poll_shutdown(cx),
}
}
}

// --- TlsConnect ---

#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_TlsConnect {
NoTls(postgres::tls::NoTls),
#[cfg(feature = "postgres_pool_tls")]
Tls(postgres_native_tls::TlsConnector)
}

impl postgres::tls::TlsConnect<postgres::Socket> for MaybeTlsConnector_TlsConnect {
type Stream = MaybeTlsConnector_Stream;
type Error = MaybeTlsConnector_Error;
type Future = MaybeTlsConnector_Future;

fn connect(self, socket: postgres::Socket) -> Self::Future {
match self {
MaybeTlsConnector_TlsConnect::NoTls(connector) =>
Self::Future::NoTls(connector.connect(socket)),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_TlsConnect::Tls(connector) =>
Self::Future::Tls(connector.connect(socket)),
}
}
}

// --- Error ---

#[allow(non_camel_case_types)]
#[derive(Debug)]
pub enum MaybeTlsConnector_Error {
NoTls(postgres::tls::NoTlsError),
#[cfg(feature = "postgres_pool_tls")]
Tls(native_tls::Error)
}

impl std::fmt::Display for MaybeTlsConnector_Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MaybeTlsConnector_Error::NoTls(e) => e.fmt(f),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Error::Tls(e) => e.fmt(f),
}
}
}

impl std::error::Error for MaybeTlsConnector_Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
MaybeTlsConnector_Error::NoTls(e) => e.source(),
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Error::Tls(e) => e.source(),
}
}
}

// --- Future ---

#[allow(non_camel_case_types)]
pub enum MaybeTlsConnector_Future {
NoTls(postgres::tls::NoTlsFuture),
#[cfg(feature = "postgres_pool_tls")]
Tls(<postgres_native_tls::TlsConnector as
postgres::tls::TlsConnect<postgres::Socket>>::Future)
}

impl std::future::Future for MaybeTlsConnector_Future {
type Output = Result<MaybeTlsConnector_Stream, MaybeTlsConnector_Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match *self {
MaybeTlsConnector_Future::NoTls(ref mut future) => {
Pin::new(future)
.poll(cx)
.map(|v| v.map(MaybeTlsConnector_Stream::NoTls))
.map_err(MaybeTlsConnector_Error::NoTls)
},
#[cfg(feature = "postgres_pool_tls")]
MaybeTlsConnector_Future::Tls(ref mut future) => {
Pin::new(future)
.poll(cx)
.map(|v| v.map(MaybeTlsConnector_Stream::Tls))
.map_err(MaybeTlsConnector_Error::Tls)
}
}
}
}
}

#[cfg(feature = "postgres_pool")]
impl Poolable for postgres::Client {
type Manager = r2d2_postgres::PostgresConnectionManager<postgres::tls::NoTls>;
type Manager = r2d2_postgres::PostgresConnectionManager<pg::MaybeTlsConnector>;
type Error = postgres::Error;

fn pool(db_name: &str, rocket: &Rocket<Build>) -> PoolResult<Self> {
let config = Config::from(db_name, rocket)?;
let url = config.url.parse().map_err(Error::Custom)?;
let manager = r2d2_postgres::PostgresConnectionManager::new(url, postgres::tls::NoTls);

let tls_connector = match config.tls {
// `tls_config` is unused when `postgres_pool_tls` is disabled.
#[allow(unused_variables)]
Some(ref tls_config) => {

#[cfg(feature = "postgres_pool_tls")]
{
let mut connector_builder = native_tls::TlsConnector::builder();
if let Some(ref cert) = tls_config.ssl_root_cert {
let cert_file_bytes = std::fs::read(cert)?;
let cert = native_tls::Certificate::from_pem(&cert_file_bytes)
.map_err(|e| Error::Tls(e.into()))?;
connector_builder.add_root_certificate(cert);

// Client certs
match (
tls_config.ssl_client_cert.as_ref(),
tls_config.ssl_client_key.as_ref(),
) {
(Some(cert), Some(key)) => {
let cert_file_bytes = std::fs::read(cert)?;
let key_file_bytes = std::fs::read(key)?;
let cert = native_tls::Identity::from_pkcs8(
&cert_file_bytes,
&key_file_bytes
).map_err(|e| Error::Tls(e.into()))?;
connector_builder.identity(cert);
},
(Some(_), None) => {
return Err(Error::Tls(
"Client certificate provided without client key".into()
))
},
(None, Some(_)) => {
return Err(Error::Tls(
"Client key provided without client certificate".into()
))
},
(None, None) => {},
}
}

connector_builder
.danger_accept_invalid_certs(tls_config.accept_invalid_certs);
connector_builder
.danger_accept_invalid_hostnames(tls_config.accept_invalid_hostnames);

pg::MaybeTlsConnector::Tls(postgres_native_tls::MakeTlsConnector::new(
connector_builder.build().map_err(|e| Error::Tls(e.into()))?
))
}

#[cfg(not(feature = "postgres_pool_tls"))]
{
// TODO: Should this be an error?
rocket::warn!("The `postgres_pool_tls` feature is disabled. \
Postgres TLS configuration will be ignored.");
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
}
},
None => {
pg::MaybeTlsConnector::NoTls(postgres::tls::NoTls)
}
};

let manager = r2d2_postgres::PostgresConnectionManager::new(url, tls_connector);
let pool = r2d2::Pool::builder()
.max_size(config.pool_size)
.connection_timeout(Duration::from_secs(config.timeout as u64))
Expand Down
Loading