Skip to content

Commit 850df71

Browse files
committed
sqlx-postgres: Add worker
1 parent 57eb13a commit 850df71

File tree

3 files changed

+233
-2
lines changed

3 files changed

+233
-2
lines changed

sqlx-postgres/src/connection/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ mod request;
3131
mod sasl;
3232
mod stream;
3333
mod tls;
34+
mod worker;
3435

3536
/// A connection to a PostgreSQL database.
3637
///
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
use std::{
2+
collections::VecDeque,
3+
future::Future,
4+
ops::ControlFlow,
5+
pin::Pin,
6+
task::{ready, Context, Poll},
7+
};
8+
9+
use crate::message::{
10+
BackendMessageFormat, FrontendMessage, Notification, ReadyForQuery, ReceivedMessage, Terminate,
11+
};
12+
use futures_channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
13+
use futures_util::{SinkExt, StreamExt};
14+
use sqlx_core::{
15+
bytes::Buf,
16+
net::{BufferedSocket, Socket},
17+
rt::spawn,
18+
Result,
19+
};
20+
21+
use super::request::IoRequest;
22+
23+
#[derive(PartialEq, Debug)]
24+
enum WorkerState {
25+
// The connection is open and ready for business.
26+
Open,
27+
// Sent/sending a [Terminate] message but did not close the socket. Responding to the last
28+
// messages but not receiving new ones.
29+
Closing,
30+
// The connection is terminated, this step closes the socket and stops the background task.
31+
Closed,
32+
}
33+
34+
pub struct Worker {
35+
state: WorkerState,
36+
should_flush: bool,
37+
chan: UnboundedReceiver<IoRequest>,
38+
back_log: VecDeque<UnboundedSender<ReceivedMessage>>,
39+
socket: BufferedSocket<Box<dyn Socket>>,
40+
notif_chan: UnboundedSender<Notification>,
41+
}
42+
43+
impl Worker {
44+
pub fn spawn(
45+
socket: BufferedSocket<Box<dyn Socket>>,
46+
notif_chan: UnboundedSender<Notification>,
47+
) -> UnboundedSender<IoRequest> {
48+
let (tx, rx) = unbounded();
49+
50+
let worker = Worker {
51+
state: WorkerState::Open,
52+
should_flush: false,
53+
chan: rx,
54+
back_log: VecDeque::new(),
55+
socket,
56+
notif_chan,
57+
};
58+
59+
spawn(worker);
60+
tx
61+
}
62+
63+
// Tries to receive the next message from the channel. Also handles termination if needed.
64+
#[inline(always)]
65+
fn poll_next_request(&mut self, cx: &mut Context<'_>) -> Poll<IoRequest> {
66+
if self.state != WorkerState::Open {
67+
return Poll::Pending;
68+
}
69+
70+
match self.chan.poll_next_unpin(cx) {
71+
Poll::Pending => Poll::Pending,
72+
Poll::Ready(Some(request)) => Poll::Ready(request),
73+
Poll::Ready(None) => {
74+
// Channel was closed, explicitly or because the sender was dropped. Either way
75+
// we should start a gracefull shutdown.
76+
self.socket
77+
.write_buffer_mut()
78+
.put_slice(&[Terminate::FORMAT as u8, 0, 0, 0, 4]);
79+
80+
self.state = WorkerState::Closing;
81+
self.should_flush = true;
82+
Poll::Pending
83+
}
84+
}
85+
}
86+
87+
#[inline(always)]
88+
fn poll_receiver(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
89+
if self.state != WorkerState::Open {
90+
return Poll::Ready(Ok(()));
91+
}
92+
93+
loop {
94+
ready!(self.socket.poll_ready_unpin(cx))?;
95+
96+
let request = ready!(self.poll_next_request(cx));
97+
98+
self.socket.start_send_unpin(&request.data)?;
99+
self.should_flush = true;
100+
101+
if let Some(chan) = request.chan {
102+
// We should send the responses back
103+
self.back_log.push_back(chan);
104+
} else {
105+
}
106+
}
107+
}
108+
109+
#[inline(always)]
110+
fn handle_poll_flush(&mut self, cx: &mut Context<'_>) -> Result<()> {
111+
if self.should_flush && self.socket.poll_flush_unpin(cx).is_ready() {
112+
self.should_flush = false;
113+
}
114+
Ok(())
115+
}
116+
117+
#[inline(always)]
118+
fn send_back(&mut self, response: ReceivedMessage) -> Result<()> {
119+
if let Some(chan) = self.back_log.front_mut() {
120+
let _ = chan.unbounded_send(response);
121+
Ok(())
122+
} else {
123+
Err(err_protocol!("Received response but did not expect one."))
124+
}
125+
}
126+
127+
#[inline(always)]
128+
fn poll_backlog(&mut self, cx: &mut Context<'_>) -> Result<()> {
129+
while let Poll::Ready(response) = self.poll_next_message(cx)? {
130+
match response.format {
131+
BackendMessageFormat::ReadyForQuery => {
132+
let rfq: ReadyForQuery = response.clone().decode()?;
133+
self.send_back(response)?;
134+
// Remove from the backlog so we dont send more responses back.
135+
let _ = self.back_log.pop_front();
136+
}
137+
BackendMessageFormat::CopyInResponse => {
138+
// End of response
139+
self.send_back(response)?;
140+
// Remove from the backlog so we dont send more responses back.
141+
let _ = self.back_log.pop_front();
142+
}
143+
BackendMessageFormat::NotificationResponse => {
144+
// Notification
145+
let notif: Notification = response.decode()?;
146+
let _ = self.notif_chan.unbounded_send(notif);
147+
}
148+
BackendMessageFormat::ParameterStatus => {
149+
// Asynchronous response - todo
150+
}
151+
BackendMessageFormat::NoticeResponse => {
152+
// Asynchronous response - todo
153+
}
154+
_ => self.send_back(response)?,
155+
}
156+
}
157+
158+
if self.state != WorkerState::Open && self.back_log.is_empty() {
159+
self.state = WorkerState::Closed;
160+
}
161+
Ok(())
162+
}
163+
164+
#[inline(always)]
165+
fn poll_next_message(&mut self, cx: &mut Context<'_>) -> Poll<Result<ReceivedMessage>> {
166+
self.socket.poll_try_read(cx, |buf| {
167+
// all packets in postgres start with a 5-byte header
168+
// this header contains the message type and the total length of the message
169+
let Some(mut header) = buf.get(..5) else {
170+
return Ok(ControlFlow::Continue(5));
171+
};
172+
173+
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
174+
175+
let message_len = header.get_u32() as usize;
176+
177+
let expected_len = message_len
178+
.checked_add(1)
179+
// this shouldn't really happen but is mostly a sanity check
180+
.ok_or_else(|| err_protocol!("message_len + 1 overflows usize: {message_len}"))?;
181+
182+
if buf.len() < expected_len {
183+
return Ok(ControlFlow::Continue(expected_len));
184+
}
185+
186+
// `buf` SHOULD NOT be modified ABOVE this line
187+
188+
// pop off the format code since it's not counted in `message_len`
189+
buf.advance(1);
190+
191+
// consume the message, including the length prefix
192+
let mut contents = buf.split_to(message_len).freeze();
193+
194+
// cut off the length prefix
195+
contents.advance(4);
196+
197+
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
198+
})
199+
}
200+
201+
#[inline(always)]
202+
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
203+
if self.state == WorkerState::Closed {
204+
// The buffer is closed, a [Terminate] message has been sent, now try and close the socket.
205+
self.socket.poll_close_unpin(cx)
206+
} else {
207+
Poll::Pending
208+
}
209+
}
210+
}
211+
212+
impl Future for Worker {
213+
type Output = Result<()>;
214+
215+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
216+
// Try to receive responses from the database and handle them.
217+
self.poll_backlog(cx)?;
218+
219+
// Push as many new requests in the write buffer as we can.
220+
if let Poll::Ready(Err(e)) = self.poll_receiver(cx) {
221+
return Poll::Ready(Err(e));
222+
};
223+
224+
// Flush the write buffer if needed.
225+
self.handle_poll_flush(cx)?;
226+
227+
// Close this socket if we're done.
228+
self.poll_shutdown(cx)
229+
}
230+
}

sqlx-postgres/src/message/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ pub enum FrontendMessageFormat {
8787
Terminate = b'X',
8888
}
8989

90-
#[derive(Debug, PartialOrd, PartialEq)]
90+
#[derive(Debug, PartialOrd, PartialEq, Clone)]
9191
#[repr(u8)]
9292
pub enum BackendMessageFormat {
9393
Authentication,
@@ -113,7 +113,7 @@ pub enum BackendMessageFormat {
113113
RowDescription,
114114
}
115115

116-
#[derive(Debug)]
116+
#[derive(Debug, Clone)]
117117
pub struct ReceivedMessage {
118118
pub format: BackendMessageFormat,
119119
pub contents: Bytes,

0 commit comments

Comments
 (0)