Skip to content

Commit c39501a

Browse files
committed
sqlx-postgres: Startup with background worker
1 parent a30c706 commit c39501a

File tree

5 files changed

+136
-58
lines changed

5 files changed

+136
-58
lines changed

sqlx-postgres/src/connection/establish.rs

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
1-
use crate::HashMap;
2-
3-
use crate::common::StatementCache;
41
use crate::connection::{sasl, stream::PgStream};
52
use crate::error::Error;
6-
use crate::io::StatementId;
73
use crate::message::{
84
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
95
};
106
use crate::{PgConnectOptions, PgConnection};
7+
use futures_channel::mpsc::unbounded;
118

12-
use super::PgConnectionInner;
9+
use super::worker::Worker;
1310

1411
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
1512
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11
1613

1714
impl PgConnection {
1815
pub(crate) async fn establish(options: &PgConnectOptions) -> Result<Self, Error> {
1916
// Upgrade to TLS if we were asked to and the server supports it
20-
let mut stream = PgStream::connect(options).await?;
17+
let pg_stream = PgStream::connect(options).await?;
18+
19+
let stream = PgStream::connect(options).await?;
20+
21+
let (notif_tx, notif_rx) = unbounded();
22+
23+
let x = Worker::spawn(stream.into_inner(), notif_tx);
24+
25+
let mut conn = PgConnection::new(pg_stream, options, x, notif_rx);
2126

2227
// To begin a session, a frontend opens a connection to the server
2328
// and sends a startup message.
@@ -45,14 +50,14 @@ impl PgConnection {
4550
params.push(("options", options));
4651
}
4752

48-
stream.write(Startup {
49-
username: Some(&options.username),
50-
database: options.database.as_deref(),
51-
params: &params,
53+
let mut pipe = conn.pipe(|buf| {
54+
buf.write(Startup {
55+
username: Some(&options.username),
56+
database: options.database.as_deref(),
57+
params: &params,
58+
})
5259
})?;
5360

54-
stream.flush().await?;
55-
5661
// The server then uses this information and the contents of
5762
// its configuration files (such as pg_hba.conf) to determine whether the connection is
5863
// provisionally acceptable, and what additional
@@ -63,7 +68,7 @@ impl PgConnection {
6368
let transaction_status;
6469

6570
loop {
66-
let message = stream.recv().await?;
71+
let message = pipe.recv().await?;
6772
match message.format {
6873
BackendMessageFormat::Authentication => match message.decode()? {
6974
Authentication::Ok => {
@@ -75,11 +80,9 @@ impl PgConnection {
7580
// The frontend must now send a [PasswordMessage] containing the
7681
// password in clear-text form.
7782

78-
stream
79-
.send(Password::Cleartext(
80-
options.password.as_deref().unwrap_or_default(),
81-
))
82-
.await?;
83+
conn.pipe_and_forget(Password::Cleartext(
84+
options.password.as_deref().unwrap_or_default(),
85+
))?;
8386
}
8487

8588
Authentication::Md5Password(body) => {
@@ -88,17 +91,15 @@ impl PgConnection {
8891
// using the 4-byte random salt specified in the
8992
// [AuthenticationMD5Password] message.
9093

91-
stream
92-
.send(Password::Md5 {
93-
username: &options.username,
94-
password: options.password.as_deref().unwrap_or_default(),
95-
salt: body.salt,
96-
})
97-
.await?;
94+
conn.pipe_and_forget(Password::Md5 {
95+
username: &options.username,
96+
password: options.password.as_deref().unwrap_or_default(),
97+
salt: body.salt,
98+
})?;
9899
}
99100

100101
Authentication::Sasl(body) => {
101-
sasl::authenticate(&mut stream, options, body).await?;
102+
sasl::authenticate(&conn, &mut pipe, options, body).await?;
102103
}
103104

104105
method => {
@@ -135,22 +136,10 @@ impl PgConnection {
135136
}
136137
}
137138

138-
Ok(PgConnection {
139-
inner: Box::new(PgConnectionInner {
140-
stream,
141-
process_id,
142-
secret_key,
143-
transaction_status,
144-
transaction_depth: 0,
145-
pending_ready_for_query_count: 0,
146-
next_statement_id: StatementId::NAMED_START,
147-
cache_statement: StatementCache::new(options.statement_cache_capacity),
148-
cache_type_oid: HashMap::new(),
149-
cache_type_info: HashMap::new(),
150-
cache_elem_type_to_array: HashMap::new(),
151-
cache_table_to_column_names: HashMap::new(),
152-
log_settings: options.log_settings.clone(),
153-
}),
154-
})
139+
conn.inner.transaction_status = transaction_status;
140+
conn.inner.secret_key = secret_key;
141+
conn.inner.process_id = process_id;
142+
143+
Ok(conn)
155144
}
156145
}

sqlx-postgres/src/connection/mod.rs

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@ use std::fmt::{self, Debug, Formatter};
44
use std::sync::Arc;
55

66
use crate::HashMap;
7+
use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
78
use futures_core::future::BoxFuture;
89
use futures_util::FutureExt;
10+
use pipe::Pipe;
11+
use request::{IoRequest, MessageBuf};
912

1013
use crate::common::StatementCache;
1114
use crate::error::Error;
1215
use crate::ext::ustr::UStr;
1316
use crate::io::StatementId;
1417
use crate::message::{
15-
BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
16-
TransactionStatus,
18+
BackendMessageFormat, Close, FrontendMessage, Notification, Query, ReadyForQuery,
19+
ReceivedMessage, Terminate, TransactionStatus,
1720
};
1821
use crate::statement::PgStatementMetadata;
1922
use crate::transaction::Transaction;
@@ -47,6 +50,10 @@ pub struct PgConnectionInner {
4750
// wrapped in a buffered stream
4851
pub(crate) stream: PgStream,
4952

53+
chan: UnboundedSender<IoRequest>,
54+
55+
notifications: UnboundedReceiver<Notification>,
56+
5057
// process id of this backend
5158
// used to send cancel requests
5259
#[allow(dead_code)]
@@ -148,6 +155,84 @@ impl PgConnection {
148155
TransactionStatus::Error | TransactionStatus::Idle => false,
149156
}
150157
}
158+
159+
fn new(
160+
stream: PgStream,
161+
options: &PgConnectOptions,
162+
chan: UnboundedSender<IoRequest>,
163+
notifications: UnboundedReceiver<Notification>,
164+
) -> Self {
165+
Self {
166+
inner: Box::new(PgConnectionInner {
167+
chan,
168+
notifications,
169+
log_settings: options.log_settings.clone(),
170+
process_id: 0,
171+
secret_key: 0,
172+
next_statement_id: StatementId::NAMED_START,
173+
cache_statement: StatementCache::new(options.statement_cache_capacity),
174+
cache_type_info: HashMap::new(),
175+
cache_type_oid: HashMap::new(),
176+
cache_elem_type_to_array: HashMap::new(),
177+
cache_table_to_column_names: HashMap::new(),
178+
transaction_depth: 0,
179+
stream,
180+
pending_ready_for_query_count: 0,
181+
transaction_status: TransactionStatus::Idle,
182+
}),
183+
}
184+
}
185+
186+
fn create_request<F>(&self, callback: F) -> sqlx_core::Result<IoRequest>
187+
where
188+
F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>,
189+
{
190+
let mut buffer = MessageBuf::new();
191+
(callback)(&mut buffer)?;
192+
Ok(buffer.finish())
193+
}
194+
195+
fn send_request(&self, request: IoRequest) -> sqlx_core::Result<()> {
196+
self.inner
197+
.chan
198+
.unbounded_send(request)
199+
.map_err(|_| sqlx_core::Error::WorkerCrashed)
200+
}
201+
202+
pub(crate) fn pipe<F>(&self, callback: F) -> sqlx_core::Result<Pipe>
203+
where
204+
F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>,
205+
{
206+
let mut req = self.create_request(callback)?;
207+
let (tx, rx) = unbounded();
208+
req.chan = Some(tx);
209+
210+
self.send_request(req)?;
211+
Ok(Pipe::new(rx))
212+
}
213+
214+
pub(crate) fn pipe_and_forget<T>(&self, value: T) -> sqlx_core::Result<()>
215+
where
216+
T: FrontendMessage,
217+
{
218+
let req = self.create_request(|buf| buf.write_msg(value))?;
219+
self.send_request(req)
220+
}
221+
222+
pub(crate) async fn start_pipe_async<F, R>(&self, callback: F) -> sqlx_core::Result<(R, Pipe)>
223+
where
224+
F: AsyncFnOnce(&mut MessageBuf) -> sqlx_core::Result<R>,
225+
{
226+
let mut buffer = MessageBuf::new();
227+
let result = (callback)(&mut buffer).await?;
228+
let mut req = buffer.finish();
229+
let (tx, rx) = unbounded();
230+
req.chan = Some(tx);
231+
232+
self.send_request(req)?;
233+
234+
Ok((result, Pipe::new(rx)))
235+
}
151236
}
152237

153238
impl Debug for PgConnection {

sqlx-postgres/src/connection/sasl.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::connection::stream::PgStream;
21
use crate::error::Error;
32
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
43
use crate::PgConnectOptions;
@@ -9,14 +8,18 @@ use stringprep::saslprep;
98

109
use base64::prelude::{Engine as _, BASE64_STANDARD};
1110

11+
use super::pipe::Pipe;
12+
use super::PgConnection;
13+
1214
const GS2_HEADER: &str = "n,,";
1315
const CHANNEL_ATTR: &str = "c";
1416
const USERNAME_ATTR: &str = "n";
1517
const CLIENT_PROOF_ATTR: &str = "p";
1618
const NONCE_ATTR: &str = "r";
1719

1820
pub(crate) async fn authenticate(
19-
stream: &mut PgStream,
21+
conn: &PgConnection,
22+
pipe: &mut Pipe,
2023
options: &PgConnectOptions,
2124
data: AuthenticationSasl,
2225
) -> Result<(), Error> {
@@ -67,14 +70,12 @@ pub(crate) async fn authenticate(
6770

6871
let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
6972

70-
stream
71-
.send(SaslInitialResponse {
72-
response: &client_first_message,
73-
plus: false,
74-
})
75-
.await?;
73+
conn.pipe_and_forget(SaslInitialResponse {
74+
response: &client_first_message,
75+
plus: false,
76+
})?;
7677

77-
let cont = match stream.recv_expect().await? {
78+
let cont = match pipe.recv_expect().await? {
7879
Authentication::SaslContinue(data) => data,
7980

8081
auth => {
@@ -143,9 +144,9 @@ pub(crate) async fn authenticate(
143144
let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
144145
BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
145146

146-
stream.send(SaslResponse(&client_final_message)).await?;
147+
conn.pipe_and_forget(SaslResponse(&client_final_message))?;
147148

148-
let data = match stream.recv_expect().await? {
149+
let data = match pipe.recv_expect().await? {
149150
Authentication::SaslFinal(data) => data,
150151

151152
auth => {

sqlx-postgres/src/connection/stream.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub struct PgStream {
4141
}
4242

4343
impl PgStream {
44+
pub fn into_inner(self) -> BufferedSocket<Box<dyn Socket>> {
45+
self.inner
46+
}
47+
4448
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
4549
let socket_result = match options.fetch_socket() {
4650
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,

sqlx-postgres/src/connection/worker.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
};
88

99
use crate::message::{
10-
BackendMessageFormat, FrontendMessage, Notification, ReadyForQuery, ReceivedMessage, Terminate,
10+
BackendMessageFormat, FrontendMessage, Notification, ReceivedMessage, Terminate,
1111
};
1212
use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
1313
use futures_util::{SinkExt, StreamExt};
@@ -129,7 +129,6 @@ impl Worker {
129129
while let Poll::Ready(response) = self.poll_next_message(cx)? {
130130
match response.format {
131131
BackendMessageFormat::ReadyForQuery => {
132-
let rfq: ReadyForQuery = response.clone().decode()?;
133132
self.send_back(response)?;
134133
// Remove from the backlog so we dont send more responses back.
135134
let _ = self.back_log.pop_front();

0 commit comments

Comments
 (0)