Skip to content

Commit 6f27598

Browse files
committed
sqlx-postgres: Update copy
1 parent 2923e6f commit 6f27598

File tree

5 files changed

+70
-62
lines changed

5 files changed

+70
-62
lines changed

sqlx-core/src/io/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,19 @@ pub use futures_util::io::AsyncReadExt;
2424

2525
#[cfg(feature = "_rt-tokio")]
2626
pub use tokio::io::AsyncReadExt;
27+
28+
pub async fn read_from(
29+
mut source: impl AsyncRead + Unpin,
30+
data: &mut Vec<u8>,
31+
) -> std::io::Result<usize> {
32+
match () {
33+
// Tokio lets us read into the buffer without zeroing first
34+
#[cfg(feature = "_rt-tokio")]
35+
_ => source.read_buf(data).await,
36+
#[cfg(not(feature = "_rt-tokio"))]
37+
_ => {
38+
data.resize(data.capacity(), 0);
39+
source.read(data).await
40+
}
41+
}
42+
}

sqlx-core/src/net/socket/buffered.rs

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::pin::Pin;
77
use std::task::{ready, Context, Poll};
88
use std::{cmp, io};
99

10-
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
10+
use crate::io::{read_from, AsyncRead, ProtocolDecode, ProtocolEncode};
1111

1212
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
1313
const DEFAULT_BUF_SIZE: usize = 8192;
@@ -261,14 +261,8 @@ impl WriteBuffer {
261261
/// Read into the buffer from `source`, returning the number of bytes read.
262262
///
263263
/// The buffer is automatically advanced by the number of bytes read.
264-
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> io::Result<usize> {
265-
let read = match () {
266-
// Tokio lets us read into the buffer without zeroing first
267-
#[cfg(feature = "_rt-tokio")]
268-
_ => source.read_buf(self.buf_mut()).await?,
269-
#[cfg(not(feature = "_rt-tokio"))]
270-
_ => source.read(self.init_remaining_mut()).await?,
271-
};
264+
pub async fn read_from(&mut self, source: impl AsyncRead + Unpin) -> io::Result<usize> {
265+
let read = read_from(source, self.buf_mut()).await?;
272266

273267
if read > 0 {
274268
self.advance(read);

sqlx-postgres/src/connection/mod.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,19 +175,16 @@ impl PgConnection {
175175
self.send_request(req)
176176
}
177177

178-
pub(crate) async fn start_pipe_async<F, R>(&self, callback: F) -> sqlx_core::Result<(R, Pipe)>
178+
pub(crate) async fn pipe_and_forget_async<F, R>(&self, callback: F) -> sqlx_core::Result<R>
179179
where
180180
F: AsyncFnOnce(&mut MessageBuf) -> sqlx_core::Result<R>,
181181
{
182182
let mut buffer = MessageBuf::new();
183183
let result = (callback)(&mut buffer).await?;
184-
let mut req = buffer.finish();
185-
let (tx, rx) = unbounded();
186-
req.chan = Some(tx);
187-
184+
let req = buffer.finish();
188185
self.send_request(req)?;
189186

190-
Ok((result, Pipe::new(rx)))
187+
Ok(result)
191188
}
192189
}
193190

sqlx-postgres/src/connection/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ impl Worker {
152152
let _ = self.notif_chan.unbounded_send(notif);
153153
}
154154
BackendMessageFormat::ParameterStatus => {
155-
// Asynchronous response - todo
155+
// Asynchronous response
156156
//
157157
let ParameterStatus { name, value } = response.decode()?;
158158
self.shared.insert_parameter_status(name, value);

sqlx-postgres/src/copy.rs

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::ext::async_stream::TryAsyncStream;
1212
use crate::io::AsyncRead;
1313
use crate::message::{
1414
BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
15-
CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
15+
CopyOutResponse, CopyResponseData, ReadyForQuery,
1616
};
1717
use crate::pool::{Pool, PoolConnection};
1818
use crate::Postgres;
@@ -146,13 +146,13 @@ pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
146146
}
147147

148148
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
149-
async fn begin(mut conn: C, statement: &str) -> Result<Self> {
150-
conn.inner.stream.send(Query(statement)).await?;
149+
async fn begin(conn: C, statement: &str) -> Result<Self> {
150+
let mut pipe = conn.queue_simple_query(statement)?;
151151

152-
let response = match conn.inner.stream.recv_expect::<CopyInResponse>().await {
152+
let response = match pipe.recv_expect::<CopyInResponse>().await {
153153
Ok(res) => res.0,
154154
Err(e) => {
155-
conn.inner.stream.recv().await?;
155+
pipe.recv_ready_for_query().await?;
156156
return Err(e);
157157
}
158158
};
@@ -194,13 +194,11 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
194194
/// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
195195
pub async fn send(&mut self, data: impl Deref<Target = [u8]>) -> Result<&mut Self> {
196196
for chunk in data.deref().chunks(PG_COPY_MAX_DATA_LEN) {
197+
// TODO: We should probably have some kind of back-pressure here
197198
self.conn
198199
.as_deref_mut()
199200
.expect("send_data: conn taken")
200-
.inner
201-
.stream
202-
.send(CopyData(chunk))
203-
.await?;
201+
.pipe_and_forget(CopyData(chunk))?;
204202
}
205203

206204
Ok(self)
@@ -223,26 +221,31 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
223221
pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> {
224222
let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken");
225223
loop {
226-
let buf = conn.inner.stream.write_buffer_mut();
224+
let read = conn
225+
.pipe_and_forget_async(async |buf| {
226+
let write_buf = buf.buf_mut();
227227

228-
// Write the CopyData format code and reserve space for the length.
229-
// This may end up sending an empty `CopyData` packet if, after this point,
230-
// we get canceled or read 0 bytes, but that should be fine.
231-
buf.put_slice(b"d\0\0\0\x04");
228+
// Write the CopyData format code and reserve space for the length.
229+
// This may end up sending an empty `CopyData` packet if, after this point,
230+
// we get canceled or read 0 bytes, but that should be fine.
231+
write_buf.put_slice(b"d\0\0\0\x04");
232232

233-
let read = buf.read_from(&mut source).await?;
233+
let read = sqlx_core::io::read_from(&mut source, write_buf).await?;
234234

235-
if read == 0 {
236-
break;
237-
}
235+
// Write the length
236+
let read32 = i32::try_from(read).map_err(|_| {
237+
err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read)
238+
})?;
238239

239-
// Write the length
240-
let read32 = i32::try_from(read)
241-
.map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;
240+
(&mut write_buf[1..]).put_i32(read32 + 4);
242241

243-
(&mut buf.get_mut()[1..]).put_i32(read32 + 4);
242+
Ok(read32)
243+
})
244+
.await?;
244245

245-
conn.inner.stream.flush().await?;
246+
if read == 0 {
247+
break;
248+
}
246249
}
247250

248251
Ok(self)
@@ -254,14 +257,14 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
254257
///
255258
/// The server is expected to respond with an error, so only _unexpected_ errors are returned.
256259
pub async fn abort(mut self, msg: impl Into<String>) -> Result<()> {
257-
let mut conn = self
260+
let conn = self
258261
.conn
259262
.take()
260263
.expect("PgCopyIn::fail_with: conn taken illegally");
261264

262-
conn.inner.stream.send(CopyFail::new(msg)).await?;
265+
let mut pipe = conn.pipe(|buf| buf.write_msg(CopyFail::new(msg)))?;
263266

264-
match conn.inner.stream.recv().await {
267+
match pipe.recv().await {
265268
Ok(msg) => Err(err_protocol!(
266269
"fail_with: expected ErrorResponse, got: {:?}",
267270
msg.format
@@ -270,7 +273,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
270273
match e.code() {
271274
Some(Cow::Borrowed("57014")) => {
272275
// postgres abort received error code
273-
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
276+
pipe.recv_expect::<ReadyForQuery>().await?;
274277
Ok(())
275278
}
276279
_ => Err(Error::Database(e)),
@@ -284,60 +287,58 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
284287
///
285288
/// The number of rows affected is returned.
286289
pub async fn finish(mut self) -> Result<u64> {
287-
let mut conn = self
290+
let conn = self
288291
.conn
289292
.take()
290293
.expect("CopyWriter::finish: conn taken illegally");
291294

292-
conn.inner.stream.send(CopyDone).await?;
293-
let cc: CommandComplete = match conn.inner.stream.recv_expect().await {
295+
let mut pipe = conn.pipe(|buf| buf.write_msg(CopyDone))?;
296+
let cc: CommandComplete = match pipe.recv_expect().await {
294297
Ok(cc) => cc,
295298
Err(e) => {
296-
conn.inner.stream.recv().await?;
299+
pipe.recv().await?;
297300
return Err(e);
298301
}
299302
};
300303

301-
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
304+
pipe.recv_expect::<ReadyForQuery>().await?;
302305

303306
Ok(cc.rows_affected())
304307
}
305308
}
306309

307310
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
308311
fn drop(&mut self) {
309-
if let Some(mut conn) = self.conn.take() {
310-
conn.inner
311-
.stream
312-
.write_msg(CopyFail::new(
313-
"PgCopyIn dropped without calling finish() or fail()",
314-
))
315-
.expect("BUG: PgCopyIn abort message should not be too large");
312+
if let Some(conn) = self.conn.take() {
313+
conn.pipe_and_forget(CopyFail::new(
314+
"PgCopyIn dropped without calling finish() or fail()",
315+
))
316+
.expect("BUG: could not send PgCopyIn to background worker");
316317
}
317318
}
318319
}
319320

320321
async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
321-
mut conn: C,
322+
conn: C,
322323
statement: &str,
323324
) -> Result<BoxStream<'c, Result<Bytes>>> {
324-
conn.inner.stream.send(Query(statement)).await?;
325+
let mut pipe = conn.queue_simple_query(statement)?;
325326

326-
let _: CopyOutResponse = conn.inner.stream.recv_expect().await?;
327+
let _: CopyOutResponse = pipe.recv_expect().await?;
327328

328329
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
329330
loop {
330-
match conn.inner.stream.recv().await {
331+
match pipe.recv().await {
331332
Err(e) => {
332-
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
333+
pipe.recv_expect::<ReadyForQuery>().await?;
333334
return Err(e);
334335
},
335336
Ok(msg) => match msg.format {
336337
BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
337338
BackendMessageFormat::CopyDone => {
338339
let _ = msg.decode::<CopyDone>()?;
339-
conn.inner.stream.recv_expect::<CommandComplete>().await?;
340-
conn.inner.stream.recv_expect::<ReadyForQuery>().await?;
340+
pipe.recv_expect::<CommandComplete>().await?;
341+
pipe.recv_expect::<ReadyForQuery>().await?;
341342
return Ok(())
342343
},
343344
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))

0 commit comments

Comments
 (0)