Skip to content

Commit 7da5f63

Browse files
committed
sqlx-postgres: Startup with background worker
1 parent ccdadea commit 7da5f63

File tree

5 files changed

+135
-57
lines changed

5 files changed

+135
-57
lines changed

sqlx-postgres/src/connection/establish.rs

Lines changed: 32 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
1-
use crate::HashMap;
2-
3-
use crate::common::StatementCache;
41
use crate::connection::{sasl, stream::PgStream};
52
use crate::error::Error;
6-
use crate::io::StatementId;
73
use crate::message::{
84
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
95
};
106
use crate::{PgConnectOptions, PgConnection};
7+
use futures_channel::mpsc::unbounded;
118

12-
use super::PgConnectionInner;
9+
use super::worker::Worker;
1310

1411
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
1512
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.11
1613

1714
impl PgConnection {
1815
pub(crate) async fn establish(options: &PgConnectOptions) -> Result<Self, Error> {
1916
// Upgrade to TLS if we were asked to and the server supports it
20-
let mut stream = PgStream::connect(options).await?;
17+
let pg_stream = PgStream::connect(options).await?;
18+
19+
let stream = PgStream::connect(options).await?;
20+
21+
let (notif_tx, notif_rx) = unbounded();
22+
23+
let x = Worker::spawn(stream.into_inner(), notif_tx);
24+
25+
let mut conn = PgConnection::new(pg_stream, options, x, notif_rx);
2126

2227
// To begin a session, a frontend opens a connection to the server
2328
// and sends a startup message.
@@ -45,14 +50,14 @@ impl PgConnection {
4550
params.push(("options", options));
4651
}
4752

48-
stream.write(Startup {
49-
username: Some(&options.username),
50-
database: options.database.as_deref(),
51-
params: &params,
53+
let mut pipe = conn.pipe(|buf| {
54+
buf.write(Startup {
55+
username: Some(&options.username),
56+
database: options.database.as_deref(),
57+
params: &params,
58+
})
5259
})?;
5360

54-
stream.flush().await?;
55-
5661
// The server then uses this information and the contents of
5762
// its configuration files (such as pg_hba.conf) to determine whether the connection is
5863
// provisionally acceptable, and what additional
@@ -63,7 +68,7 @@ impl PgConnection {
6368
let transaction_status;
6469

6570
loop {
66-
let message = stream.recv().await?;
71+
let message = pipe.recv().await?;
6772
match message.format {
6873
BackendMessageFormat::Authentication => match message.decode()? {
6974
Authentication::Ok => {
@@ -75,11 +80,9 @@ impl PgConnection {
7580
// The frontend must now send a [PasswordMessage] containing the
7681
// password in clear-text form.
7782

78-
stream
79-
.send(Password::Cleartext(
80-
options.password.as_deref().unwrap_or_default(),
81-
))
82-
.await?;
83+
conn.pipe_and_forget(Password::Cleartext(
84+
options.password.as_deref().unwrap_or_default(),
85+
))?;
8386
}
8487

8588
Authentication::Md5Password(body) => {
@@ -88,17 +91,15 @@ impl PgConnection {
8891
// using the 4-byte random salt specified in the
8992
// [AuthenticationMD5Password] message.
9093

91-
stream
92-
.send(Password::Md5 {
93-
username: &options.username,
94-
password: options.password.as_deref().unwrap_or_default(),
95-
salt: body.salt,
96-
})
97-
.await?;
94+
conn.pipe_and_forget(Password::Md5 {
95+
username: &options.username,
96+
password: options.password.as_deref().unwrap_or_default(),
97+
salt: body.salt,
98+
})?;
9899
}
99100

100101
Authentication::Sasl(body) => {
101-
sasl::authenticate(&mut stream, options, body).await?;
102+
sasl::authenticate(&conn, &mut pipe, options, body).await?;
102103
}
103104

104105
method => {
@@ -135,21 +136,10 @@ impl PgConnection {
135136
}
136137
}
137138

138-
Ok(PgConnection {
139-
inner: Box::new(PgConnectionInner {
140-
stream,
141-
process_id,
142-
secret_key,
143-
transaction_status,
144-
transaction_depth: 0,
145-
pending_ready_for_query_count: 0,
146-
next_statement_id: StatementId::NAMED_START,
147-
cache_statement: StatementCache::new(options.statement_cache_capacity),
148-
cache_type_oid: HashMap::new(),
149-
cache_type_info: HashMap::new(),
150-
cache_elem_type_to_array: HashMap::new(),
151-
log_settings: options.log_settings.clone(),
152-
}),
153-
})
139+
conn.inner.transaction_status = transaction_status;
140+
conn.inner.secret_key = secret_key;
141+
conn.inner.process_id = process_id;
142+
143+
Ok(conn)
154144
}
155145
}

sqlx-postgres/src/connection/mod.rs

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@ use std::fmt::{self, Debug, Formatter};
33
use std::sync::Arc;
44

55
use crate::HashMap;
6+
use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
67
use futures_core::future::BoxFuture;
78
use futures_util::FutureExt;
9+
use pipe::Pipe;
10+
use request::{IoRequest, MessageBuf};
811

912
use crate::common::StatementCache;
1013
use crate::error::Error;
1114
use crate::ext::ustr::UStr;
1215
use crate::io::StatementId;
1316
use crate::message::{
14-
BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
15-
TransactionStatus,
17+
BackendMessageFormat, Close, FrontendMessage, Notification, Query, ReadyForQuery,
18+
ReceivedMessage, Terminate, TransactionStatus,
1619
};
1720
use crate::statement::PgStatementMetadata;
1821
use crate::transaction::Transaction;
@@ -46,6 +49,10 @@ pub struct PgConnectionInner {
4649
// wrapped in a buffered stream
4750
pub(crate) stream: PgStream,
4851

52+
chan: UnboundedSender<IoRequest>,
53+
54+
notifications: UnboundedReceiver<Notification>,
55+
4956
// process id of this backend
5057
// used to send cancel requests
5158
#[allow(dead_code)]
@@ -140,6 +147,83 @@ impl PgConnection {
140147
TransactionStatus::Error | TransactionStatus::Idle => false,
141148
}
142149
}
150+
151+
fn new(
152+
stream: PgStream,
153+
options: &PgConnectOptions,
154+
chan: UnboundedSender<IoRequest>,
155+
notifications: UnboundedReceiver<Notification>,
156+
) -> Self {
157+
Self {
158+
inner: Box::new(PgConnectionInner {
159+
chan,
160+
notifications,
161+
log_settings: options.log_settings.clone(),
162+
process_id: 0,
163+
secret_key: 0,
164+
next_statement_id: StatementId::NAMED_START,
165+
cache_statement: StatementCache::new(options.statement_cache_capacity),
166+
cache_type_info: HashMap::new(),
167+
cache_type_oid: HashMap::new(),
168+
cache_elem_type_to_array: HashMap::new(),
169+
transaction_depth: 0,
170+
stream,
171+
pending_ready_for_query_count: 0,
172+
transaction_status: TransactionStatus::Idle,
173+
}),
174+
}
175+
}
176+
177+
fn create_request<F>(&self, callback: F) -> sqlx_core::Result<IoRequest>
178+
where
179+
F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>,
180+
{
181+
let mut buffer = MessageBuf::new();
182+
(callback)(&mut buffer)?;
183+
Ok(buffer.finish())
184+
}
185+
186+
fn send_request(&self, request: IoRequest) -> sqlx_core::Result<()> {
187+
self.inner
188+
.chan
189+
.unbounded_send(request)
190+
.map_err(|_| sqlx_core::Error::WorkerCrashed)
191+
}
192+
193+
pub(crate) fn pipe<F>(&self, callback: F) -> sqlx_core::Result<Pipe>
194+
where
195+
F: FnOnce(&mut MessageBuf) -> sqlx_core::Result<()>,
196+
{
197+
let mut req = self.create_request(callback)?;
198+
let (tx, rx) = unbounded();
199+
req.chan = Some(tx);
200+
201+
self.send_request(req)?;
202+
Ok(Pipe::new(rx))
203+
}
204+
205+
pub(crate) fn pipe_and_forget<T>(&self, value: T) -> sqlx_core::Result<()>
206+
where
207+
T: FrontendMessage,
208+
{
209+
let req = self.create_request(|buf| buf.write_msg(value))?;
210+
self.send_request(req)
211+
}
212+
213+
pub(crate) async fn start_pipe_async<F, R>(&self, callback: F) -> sqlx_core::Result<(R, Pipe)>
214+
where
215+
F: AsyncFnOnce(&mut MessageBuf) -> sqlx_core::Result<R>,
216+
{
217+
let mut buffer = MessageBuf::new();
218+
let result = (callback)(&mut buffer).await?;
219+
let mut req = buffer.finish();
220+
let (tx, rx) = unbounded();
221+
req.chan = Some(tx);
222+
223+
self.send_request(req)?;
224+
225+
Ok((result, Pipe::new(rx)))
226+
}
143227
}
144228

145229
impl Debug for PgConnection {

sqlx-postgres/src/connection/sasl.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use crate::connection::stream::PgStream;
21
use crate::error::Error;
32
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
43
use crate::PgConnectOptions;
@@ -9,14 +8,18 @@ use stringprep::saslprep;
98

109
use base64::prelude::{Engine as _, BASE64_STANDARD};
1110

11+
use super::pipe::Pipe;
12+
use super::PgConnection;
13+
1214
const GS2_HEADER: &str = "n,,";
1315
const CHANNEL_ATTR: &str = "c";
1416
const USERNAME_ATTR: &str = "n";
1517
const CLIENT_PROOF_ATTR: &str = "p";
1618
const NONCE_ATTR: &str = "r";
1719

1820
pub(crate) async fn authenticate(
19-
stream: &mut PgStream,
21+
conn: &PgConnection,
22+
pipe: &mut Pipe,
2023
options: &PgConnectOptions,
2124
data: AuthenticationSasl,
2225
) -> Result<(), Error> {
@@ -67,14 +70,12 @@ pub(crate) async fn authenticate(
6770

6871
let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
6972

70-
stream
71-
.send(SaslInitialResponse {
72-
response: &client_first_message,
73-
plus: false,
74-
})
75-
.await?;
73+
conn.pipe_and_forget(SaslInitialResponse {
74+
response: &client_first_message,
75+
plus: false,
76+
})?;
7677

77-
let cont = match stream.recv_expect().await? {
78+
let cont = match pipe.recv_expect().await? {
7879
Authentication::SaslContinue(data) => data,
7980

8081
auth => {
@@ -143,9 +144,9 @@ pub(crate) async fn authenticate(
143144
let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
144145
BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
145146

146-
stream.send(SaslResponse(&client_final_message)).await?;
147+
conn.pipe_and_forget(SaslResponse(&client_final_message))?;
147148

148-
let data = match stream.recv_expect().await? {
149+
let data = match pipe.recv_expect().await? {
149150
Authentication::SaslFinal(data) => data,
150151

151152
auth => {

sqlx-postgres/src/connection/stream.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub struct PgStream {
4141
}
4242

4343
impl PgStream {
44+
pub fn into_inner(self) -> BufferedSocket<Box<dyn Socket>> {
45+
self.inner
46+
}
47+
4448
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
4549
let socket_result = match options.fetch_socket() {
4650
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,

sqlx-postgres/src/connection/worker.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::{
77
};
88

99
use crate::message::{
10-
BackendMessageFormat, FrontendMessage, Notification, ReadyForQuery, ReceivedMessage, Terminate,
10+
BackendMessageFormat, FrontendMessage, Notification, ReceivedMessage, Terminate,
1111
};
1212
use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
1313
use futures_util::{SinkExt, StreamExt};
@@ -129,7 +129,6 @@ impl Worker {
129129
while let Poll::Ready(response) = self.poll_next_message(cx)? {
130130
match response.format {
131131
BackendMessageFormat::ReadyForQuery => {
132-
let rfq: ReadyForQuery = response.clone().decode()?;
133132
self.send_back(response)?;
134133
// Remove from the backlog so we dont send more responses back.
135134
let _ = self.back_log.pop_front();

0 commit comments

Comments
 (0)