Skip to content

Commit 684068a

Browse files
authored
Merge pull request #41 from launchbadge/ab/tls
implement TLS for Postgres and MySQL
2 parents 6c8fd94 + 114aaa5 commit 684068a

File tree

19 files changed

+800
-25
lines changed

19 files changed

+800
-25
lines changed

.github/workflows/mysql.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ jobs:
2525
# will assign a random free host port
2626
- 3306/tcp
2727
# needed because the container does not provide a healthcheck
28-
options: --health-cmd "mysqladmin ping --silent" --health-interval 30s --health-timeout 30s --health-retries 10
28+
options: >-
29+
--health-cmd "mysqladmin ping --silent" --health-interval 30s --health-timeout 30s
30+
--health-retries 10 -v /data/mysql:/var/lib/mysql
31+
2932
3033
steps:
3134
- uses: actions/checkout@v1
@@ -48,9 +51,11 @@ jobs:
4851

4952
# -----------------------------------------------------
5053

51-
- run: cargo test -p sqlx --no-default-features --features 'mysql macros chrono'
54+
- run: cargo test -p sqlx --no-default-features --features 'mysql macros chrono tls'
5255
env:
53-
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx
56+
# pass the path to the CA that the MySQL service generated
57+
# Github Actions' YML parser doesn't handle multiline strings correctly
58+
DATABASE_URL: mysql://root:password@localhost:${{ job.services.mysql.ports[3306] }}/sqlx?ssl-mode=VERIFY_CA&ssl-ca=%2Fdata%2Fmysql%2Fca.pem
5459

5560
# Rust ------------------------------------------------
5661

.github/workflows/postgres.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ jobs:
4949

5050
# -----------------------------------------------------
5151

52+
# Check that we build with TLS support (TODO: we need a postgres image with SSL certs to test)
53+
- run: cargo check -p sqlx-core --no-default-features --features 'postgres macros uuid chrono tls'
54+
5255
- run: cargo test -p sqlx --no-default-features --features 'postgres macros uuid chrono'
5356
env:
5457
DATABASE_URL: postgres://postgres:postgres@localhost:${{ job.services.postgres.ports[5432] }}/postgres

Cargo.lock

Lines changed: 181 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ all-features = true
3030
[features]
3131
default = [ "macros" ]
3232
macros = [ "sqlx-macros", "proc-macro-hack" ]
33+
tls = ["sqlx-core/tls"]
3334

3435
# database
3536
postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ]
@@ -48,6 +49,7 @@ hex = "0.4.0"
4849
[dev-dependencies]
4950
anyhow = "1.0.26"
5051
futures = "0.3.1"
52+
env_logger = "0.7"
5153
async-std = { version = "1.4.0", features = [ "attributes" ] }
5254
dotenv = "0.15.0"
5355

sqlx-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ default = []
2020
unstable = []
2121
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac" ]
2222
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
23+
tls = ["async-native-tls"]
2324

2425
[dependencies]
26+
async-native-tls = { version = "0.3", optional = true }
2527
async-std = "1.4.0"
2628
async-stream = { version = "0.2.0", default-features = false }
2729
base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] }

sqlx-core/src/error.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ pub enum Error {
4444
/// [Pool::close] was called while we were waiting in [Pool::acquire].
4545
PoolClosed,
4646

47+
/// An error occurred during a TLS upgrade.
48+
TlsUpgrade(Box<dyn StdError + Send + Sync>),
49+
4750
Decode(DecodeError),
4851

4952
// TODO: Remove and replace with `#[non_exhaustive]` when possible
@@ -62,6 +65,8 @@ impl StdError for Error {
6265

6366
Error::Decode(DecodeError::Other(error)) => Some(&**error),
6467

68+
Error::TlsUpgrade(error) => Some(&**error),
69+
6570
_ => None,
6671
}
6772
}
@@ -100,6 +105,8 @@ impl Display for Error {
100105

101106
Error::PoolClosed => f.write_str("attempted to acquire a connection on a closed pool"),
102107

108+
Error::TlsUpgrade(ref err) => write!(f, "error during TLS upgrade: {}", err),
109+
103110
Error::__Nonexhaustive => unreachable!(),
104111
}
105112
}
@@ -140,6 +147,21 @@ impl From<ProtocolError<'_>> for Error {
140147
}
141148
}
142149

150+
#[cfg(feature = "tls")]
151+
impl From<async_native_tls::Error> for Error {
152+
#[inline]
153+
fn from(err: async_native_tls::Error) -> Self {
154+
Error::TlsUpgrade(err.into())
155+
}
156+
}
157+
158+
impl From<TlsError<'_>> for Error {
159+
#[inline]
160+
fn from(err: TlsError<'_>) -> Self {
161+
Error::TlsUpgrade(err.args.to_string().into())
162+
}
163+
}
164+
143165
impl<T> From<T> for Error
144166
where
145167
T: 'static + DatabaseError,
@@ -189,6 +211,15 @@ macro_rules! protocol_err (
189211
}
190212
);
191213

214+
pub(crate) struct TlsError<'a> {
215+
pub args: fmt::Arguments<'a>,
216+
}
217+
218+
#[allow(unused_macros)]
219+
macro_rules! tls_err {
220+
($($args:tt)*) => { crate::error::TlsError { args: format_args!($($args)*)} };
221+
}
222+
192223
#[allow(unused_macros)]
193224
macro_rules! impl_fmt_error {
194225
($err:ty) => {

sqlx-core/src/io/buf_stream.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use async_std::io::{
33
Read, Write,
44
};
55
use std::io;
6+
use std::ops::{Deref, DerefMut};
67

78
const RBUF_SIZE: usize = 8 * 1024;
89

@@ -51,6 +52,12 @@ where
5152
Ok(())
5253
}
5354

55+
pub fn clear_bufs(&mut self) {
56+
self.rbuf_rindex = 0;
57+
self.rbuf_windex = 0;
58+
self.wbuf.clear();
59+
}
60+
5461
#[inline]
5562
pub fn consume(&mut self, cnt: usize) {
5663
self.rbuf_rindex += cnt;
@@ -118,6 +125,20 @@ where
118125
}
119126
}
120127

128+
impl<S> Deref for BufStream<S> {
129+
type Target = S;
130+
131+
fn deref(&self) -> &Self::Target {
132+
&self.stream
133+
}
134+
}
135+
136+
impl<S> DerefMut for BufStream<S> {
137+
fn deref_mut(&mut self) -> &mut Self::Target {
138+
&mut self.stream
139+
}
140+
}
141+
121142
// TODO: Find a nicer way to do this
122143
// Return `Ok(None)` immediately from a function if the wrapped value is `None`
123144
#[allow(unused)]

sqlx-core/src/io/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ mod buf;
55
mod buf_mut;
66
mod byte_str;
77

8+
mod tls;
9+
810
pub use self::{
911
buf::{Buf, ToBuf},
1012
buf_mut::BufMut,
1113
buf_stream::BufStream,
1214
byte_str::ByteStr,
15+
tls::MaybeTlsStream,
1316
};
1417

1518
#[cfg(test)]

sqlx-core/src/io/tls.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
use std::io::{IoSlice, IoSliceMut};
2+
use std::pin::Pin;
3+
use std::task::{Context, Poll};
4+
5+
use async_std::io::{self, Read, Write};
6+
use async_std::net::{Shutdown, TcpStream};
7+
8+
use crate::url::Url;
9+
10+
use self::Inner::*;
11+
12+
pub struct MaybeTlsStream {
13+
inner: Inner,
14+
}
15+
16+
enum Inner {
17+
NotTls(TcpStream),
18+
#[cfg(feature = "tls")]
19+
Tls(async_native_tls::TlsStream<TcpStream>),
20+
#[cfg(feature = "tls")]
21+
Upgrading,
22+
}
23+
24+
impl MaybeTlsStream {
25+
pub async fn connect(url: &Url, default_port: u16) -> crate::Result<Self> {
26+
let conn = TcpStream::connect((url.host(), url.port(default_port))).await?;
27+
Ok(Self {
28+
inner: Inner::NotTls(conn),
29+
})
30+
}
31+
32+
#[allow(dead_code)]
33+
pub fn is_tls(&self) -> bool {
34+
match self.inner {
35+
Inner::NotTls(_) => false,
36+
#[cfg(feature = "tls")]
37+
Inner::Tls(_) => true,
38+
#[cfg(feature = "tls")]
39+
Inner::Upgrading => false,
40+
}
41+
}
42+
43+
#[cfg(feature = "tls")]
44+
pub async fn upgrade(
45+
&mut self,
46+
url: &Url,
47+
connector: async_native_tls::TlsConnector,
48+
) -> crate::Result<()> {
49+
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
50+
NotTls(conn) => conn,
51+
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
52+
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
53+
};
54+
55+
self.inner = Tls(connector.connect(url.host(), conn).await?);
56+
57+
Ok(())
58+
}
59+
60+
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
61+
match self.inner {
62+
NotTls(ref conn) => conn.shutdown(how),
63+
#[cfg(feature = "tls")]
64+
Tls(ref conn) => conn.get_ref().shutdown(how),
65+
#[cfg(feature = "tls")]
66+
// connection already closed
67+
Upgrading => Ok(()),
68+
}
69+
}
70+
}
71+
72+
macro_rules! forward_pin (
73+
($self:ident.$method:ident($($arg:ident),*)) => (
74+
match &mut $self.inner {
75+
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
76+
#[cfg(feature = "tls")]
77+
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
78+
#[cfg(feature = "tls")]
79+
Upgrading => Err(io::Error::new(io::ErrorKind::Other, "connection broken; TLS upgrade failed")).into(),
80+
}
81+
)
82+
);
83+
84+
impl Read for MaybeTlsStream {
85+
fn poll_read(
86+
mut self: Pin<&mut Self>,
87+
cx: &mut Context,
88+
buf: &mut [u8],
89+
) -> Poll<io::Result<usize>> {
90+
forward_pin!(self.poll_read(cx, buf))
91+
}
92+
93+
fn poll_read_vectored(
94+
mut self: Pin<&mut Self>,
95+
cx: &mut Context,
96+
bufs: &mut [IoSliceMut],
97+
) -> Poll<io::Result<usize>> {
98+
forward_pin!(self.poll_read_vectored(cx, bufs))
99+
}
100+
}
101+
102+
impl Write for MaybeTlsStream {
103+
fn poll_write(
104+
mut self: Pin<&mut Self>,
105+
cx: &mut Context,
106+
buf: &[u8],
107+
) -> Poll<io::Result<usize>> {
108+
forward_pin!(self.poll_write(cx, buf))
109+
}
110+
111+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
112+
forward_pin!(self.poll_flush(cx))
113+
}
114+
115+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
116+
forward_pin!(self.poll_close(cx))
117+
}
118+
119+
fn poll_write_vectored(
120+
mut self: Pin<&mut Self>,
121+
cx: &mut Context,
122+
bufs: &[IoSlice],
123+
) -> Poll<io::Result<usize>> {
124+
forward_pin!(self.poll_write_vectored(cx, bufs))
125+
}
126+
}

0 commit comments

Comments
 (0)