Skip to content

Commit 1064baa

Browse files
committed
sqlx-core: Implement poll methods and Sink for BufferedSocket
1 parent 5fee4e5 commit 1064baa

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

sqlx-core/src/net/socket/buffered.rs

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use crate::error::Error;
2-
use crate::net::Socket;
2+
use crate::net::{Socket, SocketExt};
33
use bytes::BytesMut;
4+
use futures_util::Sink;
45
use std::ops::ControlFlow;
6+
use std::pin::Pin;
7+
use std::task::{ready, Context, Poll};
58
use std::{cmp, io};
69

710
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
@@ -13,6 +16,7 @@ pub struct BufferedSocket<S> {
1316
socket: S,
1417
write_buf: WriteBuffer,
1518
read_buf: ReadBuffer,
19+
wants_bytes: usize,
1620
}
1721

1822
pub struct WriteBuffer {
@@ -42,6 +46,7 @@ impl<S: Socket> BufferedSocket<S> {
4246
read: BytesMut::new(),
4347
available: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
4448
},
49+
wants_bytes: 0,
4550
}
4651
}
4752

@@ -56,6 +61,25 @@ impl<S: Socket> BufferedSocket<S> {
5661
.await
5762
}
5863

64+
pub fn poll_try_read<F, R>(
65+
&mut self,
66+
cx: &mut Context<'_>,
67+
mut try_read: F,
68+
) -> Poll<Result<R, Error>>
69+
where
70+
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
71+
{
72+
loop {
73+
// Read if we want bytes
74+
ready!(self.poll_handle_read(cx)?);
75+
76+
match try_read(&mut self.read_buf.read)? {
77+
ControlFlow::Continue(read_len) => self.wants_bytes = read_len,
78+
ControlFlow::Break(ret) => return Poll::Ready(Ok(ret)),
79+
};
80+
}
81+
}
82+
5983
/// Retryable read operation.
6084
///
6185
/// The callback should check the contents of the buffer passed to it and either:
@@ -125,8 +149,8 @@ impl<S: Socket> BufferedSocket<S> {
125149
pub async fn flush(&mut self) -> io::Result<()> {
126150
while !self.write_buf.is_empty() {
127151
let written = self.socket.write(self.write_buf.get()).await?;
152+
// Consume does the sanity check
128153
self.write_buf.consume(written);
129-
self.write_buf.sanity_check();
130154
}
131155

132156
self.socket.flush().await?;
@@ -154,8 +178,37 @@ impl<S: Socket> BufferedSocket<S> {
154178
socket: Box::new(self.socket),
155179
write_buf: self.write_buf,
156180
read_buf: self.read_buf,
181+
wants_bytes: self.wants_bytes,
157182
}
158183
}
184+
185+
fn poll_handle_read(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
186+
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
187+
// between `read` and `available` unless we have to read an oversize message.
188+
while self.read_buf.len() < self.wants_bytes {
189+
self.read_buf
190+
.reserve(self.wants_bytes - self.read_buf.len());
191+
192+
let read = ready!(self.socket.poll_read(cx, &mut self.read_buf.available)?);
193+
194+
if read == 0 {
195+
return Poll::Ready(Err(io::Error::new(
196+
io::ErrorKind::UnexpectedEof,
197+
format!(
198+
"expected to read {} bytes, got {} bytes at EOF",
199+
self.wants_bytes,
200+
self.read_buf.len()
201+
),
202+
)));
203+
}
204+
205+
self.read_buf.advance(read);
206+
}
207+
208+
// we've read at least enough for `wants_bytes`, so we don't want more.
209+
self.wants_bytes = 0;
210+
Poll::Ready(Ok(()))
211+
}
159212
}
160213

161214
impl WriteBuffer {
@@ -326,4 +379,41 @@ impl ReadBuffer {
326379
self.available = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
327380
}
328381
}
382+
383+
fn len(&self) -> usize {
384+
self.read.len()
385+
}
386+
}
387+
388+
impl<S: Socket> Sink<&[u8]> for BufferedSocket<S> {
389+
type Error = crate::Error;
390+
391+
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
392+
if self.write_buf.bytes_written >= DEFAULT_BUF_SIZE {
393+
self.poll_flush(cx)
394+
} else {
395+
Poll::Ready(Ok(()))
396+
}
397+
}
398+
399+
fn start_send(mut self: Pin<&mut Self>, item: &[u8]) -> crate::Result<()> {
400+
self.write_buffer_mut().put_slice(item);
401+
Ok(())
402+
}
403+
404+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
405+
let this = &mut *self;
406+
407+
while !this.write_buf.is_empty() {
408+
let written = ready!(this.socket.poll_write(cx, this.write_buf.get())?);
409+
// Consume does the sanity check
410+
this.write_buf.consume(written);
411+
}
412+
this.socket.poll_flush(cx).map_err(Into::into)
413+
}
414+
415+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
416+
ready!(self.as_mut().poll_flush(cx))?;
417+
self.socket.poll_shutdown(cx).map_err(Into::into)
418+
}
329419
}

sqlx-postgres/src/connection/tls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::error::Error;
22
use crate::net::tls::{self, TlsConfig};
3-
use crate::net::{Socket, SocketIntoBox, WithSocket};
3+
use crate::net::{Socket, SocketExt, SocketIntoBox, WithSocket};
44

55
use crate::message::SslRequest;
66
use crate::{PgConnectOptions, PgSslMode};

0 commit comments

Comments
 (0)