@@ -12,7 +12,7 @@ use crate::ext::async_stream::TryAsyncStream;
12
12
use crate :: io:: AsyncRead ;
13
13
use crate :: message:: {
14
14
BackendMessageFormat , CommandComplete , CopyData , CopyDone , CopyFail , CopyInResponse ,
15
- CopyOutResponse , CopyResponseData , Query , ReadyForQuery ,
15
+ CopyOutResponse , CopyResponseData , ReadyForQuery ,
16
16
} ;
17
17
use crate :: pool:: { Pool , PoolConnection } ;
18
18
use crate :: Postgres ;
@@ -146,13 +146,13 @@ pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
146
146
}
147
147
148
148
impl < C : DerefMut < Target = PgConnection > > PgCopyIn < C > {
149
- async fn begin ( mut conn : C , statement : & str ) -> Result < Self > {
150
- conn. inner . stream . send ( Query ( statement) ) . await ?;
149
+ async fn begin ( conn : C , statement : & str ) -> Result < Self > {
150
+ let mut pipe = conn. queue_simple_query ( statement) ?;
151
151
152
- let response = match conn . inner . stream . recv_expect :: < CopyInResponse > ( ) . await {
152
+ let response = match pipe . recv_expect :: < CopyInResponse > ( ) . await {
153
153
Ok ( res) => res. 0 ,
154
154
Err ( e) => {
155
- conn . inner . stream . recv ( ) . await ?;
155
+ pipe . recv_ready_for_query ( ) . await ?;
156
156
return Err ( e) ;
157
157
}
158
158
} ;
@@ -194,13 +194,11 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
194
194
/// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead.
195
195
pub async fn send ( & mut self , data : impl Deref < Target = [ u8 ] > ) -> Result < & mut Self > {
196
196
for chunk in data. deref ( ) . chunks ( PG_COPY_MAX_DATA_LEN ) {
197
+ // TODO: We should probably have some kind of back-pressure here
197
198
self . conn
198
199
. as_deref_mut ( )
199
200
. expect ( "send_data: conn taken" )
200
- . inner
201
- . stream
202
- . send ( CopyData ( chunk) )
203
- . await ?;
201
+ . pipe_and_forget ( CopyData ( chunk) ) ?;
204
202
}
205
203
206
204
Ok ( self )
@@ -223,26 +221,31 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
223
221
pub async fn read_from ( & mut self , mut source : impl AsyncRead + Unpin ) -> Result < & mut Self > {
224
222
let conn: & mut PgConnection = self . conn . as_deref_mut ( ) . expect ( "copy_from: conn taken" ) ;
225
223
loop {
226
- let buf = conn. inner . stream . write_buffer_mut ( ) ;
224
+ let read = conn
225
+ . pipe_and_forget_async ( async |buf| {
226
+ let write_buf = buf. buf_mut ( ) ;
227
227
228
- // Write the CopyData format code and reserve space for the length.
229
- // This may end up sending an empty `CopyData` packet if, after this point,
230
- // we get canceled or read 0 bytes, but that should be fine.
231
- buf . put_slice ( b"d\0 \0 \0 \x04 " ) ;
228
+ // Write the CopyData format code and reserve space for the length.
229
+ // This may end up sending an empty `CopyData` packet if, after this point,
230
+ // we get canceled or read 0 bytes, but that should be fine.
231
+ write_buf . put_slice ( b"d\0 \0 \0 \x04 " ) ;
232
232
233
- let read = buf . read_from ( & mut source) . await ?;
233
+ let read = sqlx_core :: io :: read_from ( & mut source, write_buf ) . await ?;
234
234
235
- if read == 0 {
236
- break ;
237
- }
235
+ // Write the length
236
+ let read32 = i32:: try_from ( read) . map_err ( |_| {
237
+ err_protocol ! ( "number of bytes read exceeds 2^31 - 1: {}" , read)
238
+ } ) ?;
238
239
239
- // Write the length
240
- let read32 = i32:: try_from ( read)
241
- . map_err ( |_| err_protocol ! ( "number of bytes read exceeds 2^31 - 1: {}" , read) ) ?;
240
+ ( & mut write_buf[ 1 ..] ) . put_i32 ( read32 + 4 ) ;
242
241
243
- ( & mut buf. get_mut ( ) [ 1 ..] ) . put_i32 ( read32 + 4 ) ;
242
+ Ok ( read32)
243
+ } )
244
+ . await ?;
244
245
245
- conn. inner . stream . flush ( ) . await ?;
246
+ if read == 0 {
247
+ break ;
248
+ }
246
249
}
247
250
248
251
Ok ( self )
@@ -254,14 +257,14 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
254
257
///
255
258
/// The server is expected to respond with an error, so only _unexpected_ errors are returned.
256
259
pub async fn abort ( mut self , msg : impl Into < String > ) -> Result < ( ) > {
257
- let mut conn = self
260
+ let conn = self
258
261
. conn
259
262
. take ( )
260
263
. expect ( "PgCopyIn::fail_with: conn taken illegally" ) ;
261
264
262
- conn. inner . stream . send ( CopyFail :: new ( msg) ) . await ?;
265
+ let mut pipe = conn. pipe ( |buf| buf . write_msg ( CopyFail :: new ( msg) ) ) ?;
263
266
264
- match conn . inner . stream . recv ( ) . await {
267
+ match pipe . recv ( ) . await {
265
268
Ok ( msg) => Err ( err_protocol ! (
266
269
"fail_with: expected ErrorResponse, got: {:?}" ,
267
270
msg. format
@@ -270,7 +273,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
270
273
match e. code ( ) {
271
274
Some ( Cow :: Borrowed ( "57014" ) ) => {
272
275
// postgres abort received error code
273
- conn . inner . stream . recv_expect :: < ReadyForQuery > ( ) . await ?;
276
+ pipe . recv_expect :: < ReadyForQuery > ( ) . await ?;
274
277
Ok ( ( ) )
275
278
}
276
279
_ => Err ( Error :: Database ( e) ) ,
@@ -284,60 +287,58 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
284
287
///
285
288
/// The number of rows affected is returned.
286
289
pub async fn finish ( mut self ) -> Result < u64 > {
287
- let mut conn = self
290
+ let conn = self
288
291
. conn
289
292
. take ( )
290
293
. expect ( "CopyWriter::finish: conn taken illegally" ) ;
291
294
292
- conn. inner . stream . send ( CopyDone ) . await ?;
293
- let cc: CommandComplete = match conn . inner . stream . recv_expect ( ) . await {
295
+ let mut pipe = conn. pipe ( |buf| buf . write_msg ( CopyDone ) ) ?;
296
+ let cc: CommandComplete = match pipe . recv_expect ( ) . await {
294
297
Ok ( cc) => cc,
295
298
Err ( e) => {
296
- conn . inner . stream . recv ( ) . await ?;
299
+ pipe . recv ( ) . await ?;
297
300
return Err ( e) ;
298
301
}
299
302
} ;
300
303
301
- conn . inner . stream . recv_expect :: < ReadyForQuery > ( ) . await ?;
304
+ pipe . recv_expect :: < ReadyForQuery > ( ) . await ?;
302
305
303
306
Ok ( cc. rows_affected ( ) )
304
307
}
305
308
}
306
309
307
310
impl < C : DerefMut < Target = PgConnection > > Drop for PgCopyIn < C > {
308
311
fn drop ( & mut self ) {
309
- if let Some ( mut conn) = self . conn . take ( ) {
310
- conn. inner
311
- . stream
312
- . write_msg ( CopyFail :: new (
313
- "PgCopyIn dropped without calling finish() or fail()" ,
314
- ) )
315
- . expect ( "BUG: PgCopyIn abort message should not be too large" ) ;
312
+ if let Some ( conn) = self . conn . take ( ) {
313
+ conn. pipe_and_forget ( CopyFail :: new (
314
+ "PgCopyIn dropped without calling finish() or fail()" ,
315
+ ) )
316
+ . expect ( "BUG: could not send PgCopyIn to background worker" ) ;
316
317
}
317
318
}
318
319
}
319
320
320
321
async fn pg_begin_copy_out < ' c , C : DerefMut < Target = PgConnection > + Send + ' c > (
321
- mut conn : C ,
322
+ conn : C ,
322
323
statement : & str ,
323
324
) -> Result < BoxStream < ' c , Result < Bytes > > > {
324
- conn. inner . stream . send ( Query ( statement) ) . await ?;
325
+ let mut pipe = conn. queue_simple_query ( statement) ?;
325
326
326
- let _: CopyOutResponse = conn . inner . stream . recv_expect ( ) . await ?;
327
+ let _: CopyOutResponse = pipe . recv_expect ( ) . await ?;
327
328
328
329
let stream: TryAsyncStream < ' c , Bytes > = try_stream ! {
329
330
loop {
330
- match conn . inner . stream . recv( ) . await {
331
+ match pipe . recv( ) . await {
331
332
Err ( e) => {
332
- conn . inner . stream . recv_expect:: <ReadyForQuery >( ) . await ?;
333
+ pipe . recv_expect:: <ReadyForQuery >( ) . await ?;
333
334
return Err ( e) ;
334
335
} ,
335
336
Ok ( msg) => match msg. format {
336
337
BackendMessageFormat :: CopyData => r#yield!( msg. decode:: <CopyData <Bytes >>( ) ?. 0 ) ,
337
338
BackendMessageFormat :: CopyDone => {
338
339
let _ = msg. decode:: <CopyDone >( ) ?;
339
- conn . inner . stream . recv_expect:: <CommandComplete >( ) . await ?;
340
- conn . inner . stream . recv_expect:: <ReadyForQuery >( ) . await ?;
340
+ pipe . recv_expect:: <CommandComplete >( ) . await ?;
341
+ pipe . recv_expect:: <ReadyForQuery >( ) . await ?;
341
342
return Ok ( ( ) )
342
343
} ,
343
344
_ => return Err ( err_protocol!( "unexpected message format during copy out: {:?}" , msg. format) )
0 commit comments