Skip to content

Commit a34c7d1

Browse files
committed
feat: Allow for async row streams to be Send
1 parent f246df6 commit a34c7d1

File tree

6 files changed

+99
-10
lines changed

6 files changed

+99
-10
lines changed

Cargo.lock

Lines changed: 41 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

odbc-api/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ tempfile = "3.21.0"
120120
criterion = { version = "0.7.0", features = ["html_reports"] }
121121
tokio = { version = "1.47.1", features = ["rt", "macros", "time"] }
122122
stdext = "0.3.3" # Used for function_name macro to generate unique table names for tests
123+
tokio-stream = "0.1.17"
124+
async-stream = "0.3.6"
123125

124126

125127
[[bench]]

odbc-api/src/execute.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::mem::transmute;
22

33
use crate::{
44
CursorImpl, CursorPolling, Error, ParameterCollectionRef, Sleep,
5-
handles::{AsStatementRef, SqlText, Statement},
5+
handles::{AsStatementRef, SqlText, Statement, StatementRef},
66
parameter::Blob,
77
sleep::wait_for,
88
};
@@ -165,10 +165,7 @@ where
165165
if need_data {
166166
// Check if any delayed parameters have been bound which stream data to the database at
167167
// statement execution time. Loops over each bound stream.
168-
while let Some(blob_ptr) = stmt.param_data().into_result(&stmt)? {
169-
// The safe interfaces currently exclusively bind pointers to `Blob` trait objects
170-
let blob_ptr: *mut &mut dyn Blob = transmute(blob_ptr);
171-
let blob_ref = &mut *blob_ptr;
168+
while let Some(blob_ref) = next_blob_param(&mut stmt)? {
172169
// Loop over all batches within each blob
173170
while let Some(batch) = blob_ref.next_batch().map_err(Error::FailedReadingInput)? {
174171
let result = wait_for(|| stmt.put_binary_batch(batch), &mut sleep).await;
@@ -191,6 +188,20 @@ where
191188
}
192189
}
193190

191+
unsafe fn next_blob_param<'a>(
192+
stmt: &mut StatementRef<'a>,
193+
) -> Result<Option<&'a mut dyn Blob>, Error> {
194+
let maybe_ptr = stmt.param_data().into_result(stmt)?;
195+
if let Some(blob_ptr) = maybe_ptr {
196+
// The safe interfaces currently exclusively bind pointers to `Blob` trait objects
197+
let blob_ptr: *mut &mut dyn Blob = unsafe { std::mem::transmute(blob_ptr) };
198+
let blob_ref = unsafe { &mut *blob_ptr };
199+
Ok(Some(*blob_ref))
200+
} else {
201+
Ok(None)
202+
}
203+
}
204+
194205
/// Shared implementation for executing a columns query between [`crate::Connection`] and
195206
/// [`crate::Preallocated`].
196207
pub fn execute_columns<S>(

odbc-api/src/parameter.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@
244244
//! fn insert_image_to_db(
245245
//! conn: &Connection<'_>,
246246
//! id: &str,
247-
//! image_data: impl BufRead) -> Result<(), Error>
247+
//! image_data: impl BufRead + Send) -> Result<(), Error>
248248
//! {
249249
//! const MAX_IMAGE_SIZE: usize = 4 * 1024 * 1024;
250250
//! let mut blob = BlobRead::with_upper_bound(image_data, MAX_IMAGE_SIZE);

odbc-api/src/parameter/blob.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ use std::{
1212
path::Path,
1313
};
1414

15-
/// A `Blob` can stream its contents to the database batch by batch and may therefore be used to
15+
/// A [`Blob`] can stream its contents to the database batch by batch and may therefore be used to
1616
/// transfer large amounts of data, exceeding the drivers capabilities for normal input parameters.
1717
///
1818
/// # Safety
1919
///
2020
/// If a hint is implemented for `blob_size` it must be accurate before the first call to
2121
/// `next_batch`.
22-
pub unsafe trait Blob: HasDataType {
22+
pub unsafe trait Blob: HasDataType + Send {
2323
/// CData type of the binary data returned in the batches. Likely to be either
2424
/// [`crate::sys::CDataType::Binary`], [`crate::sys::CDataType::Char`] or
2525
/// [`crate::sys::CDataType::WChar`].
@@ -266,7 +266,7 @@ impl<R> BlobRead<R> {
266266
/// fn insert_image_to_db(
267267
/// conn: &Connection<'_>,
268268
/// id: &str,
269-
/// image_data: impl BufRead) -> Result<(), Error>
269+
/// image_data: impl BufRead + Send) -> Result<(), Error>
270270
/// {
271271
/// const MAX_IMAGE_SIZE: usize = 4 * 1024 * 1024;
272272
/// let mut blob = BlobRead::with_upper_bound(image_data, MAX_IMAGE_SIZE);
@@ -360,7 +360,7 @@ where
360360

361361
unsafe impl<R> Blob for BlobRead<R>
362362
where
363-
R: BufRead,
363+
R: BufRead + Send,
364364
{
365365
fn c_data_type(&self) -> CDataType {
366366
CDataType::Binary

odbc-api/tests/integration.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod common;
22
mod connection_strings;
33

4+
use async_stream::stream;
45
use odbc_sys::{Date, Time};
56
use stdext::function_name;
67
use sys::{CDataType, NULL_DATA, Numeric, Pointer, SqlDataType, Timestamp};
@@ -34,6 +35,7 @@ use odbc_api::{
3435
},
3536
sys,
3637
};
38+
use tokio_stream::{Stream, StreamExt as _};
3739
use widestring::Utf16String;
3840

3941
use std::{
@@ -6063,6 +6065,39 @@ fn fetch_decimal_as_numeric_struct_using_bind_col(profile: &Profile) {
60636065
assert_eq!(0, target.val[2]);
60646066
}
60656067

6068+
#[test_case(MSSQL; "Microsoft SQL Server")]
6069+
#[test_case(MARIADB; "Maria DB")]
6070+
#[test_case(SQLITE_3; "SQLite 3")]
6071+
#[test_case(POSTGRES; "PostgreSQL")]
6072+
#[tokio::test]
6073+
async fn async_stream_of_rows_from_other_thread(profile: &Profile) {
6074+
let table_name = table_name!();
6075+
let (conn, table) = Given::new(&table_name)
6076+
.column_types(&["INT"])
6077+
.values_by_column(&[&[Some("42")]])
6078+
.build(profile)
6079+
.unwrap();
6080+
6081+
// When
6082+
fn stream_of_send_rows(
6083+
connection: Connection<'static>,
6084+
query: String,
6085+
) -> impl Stream<Item = (i32,)> + Send {
6086+
let stmt = connection.into_preallocated().unwrap();
6087+
let mut stmt = stmt.into_polling().unwrap();
6088+
stream! {
6089+
let sleep = || tokio::time::sleep(Duration::from_millis(10));
6090+
let _ = stmt.execute(&query, (), sleep).await;
6091+
yield (42, )
6092+
}
6093+
}
6094+
6095+
// Then
6096+
let stream = stream_of_send_rows(conn, table.sql_all_ordered_by_id());
6097+
let rows = stream.collect::<Vec<_>>().await;
6098+
assert_eq!([(42i32,)].as_slice(), rows)
6099+
}
6100+
60666101
/// Learning test to see how scrolling cursors behave
60676102
#[test_case(MSSQL; "Microsoft SQL Server")]
60686103
#[test_case(MARIADB; "Maria DB")]

0 commit comments

Comments
 (0)