Skip to content

Commit 2b0f6c0

Browse files
authored
feat(driver): add try_collect for rows (#196)
1 parent d767425 commit 2b0f6c0

File tree

9 files changed

+77
-25
lines changed

9 files changed

+77
-25
lines changed

driver/src/conn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ pub trait Connection: DynClone + Send + Sync {
7676
let row = self.query_row("SELECT version()").await?;
7777
let version = match row {
7878
Some(row) => {
79-
let (version,): (String,) = row.try_into()?;
79+
let (version,): (String,) = row.try_into().map_err(Error::Parsing)?;
8080
version
8181
}
8282
None => "".to_string(),

driver/src/flight_sql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ impl Connection for FlightSQLConnection {
7878
Ok(_) => None,
7979
Err(err) => Some(Err(err)),
8080
});
81-
Ok(Box::pin(rows))
81+
Ok(RowIterator::new(Box::pin(rows)))
8282
}
8383

8484
async fn query_iter_ext(&self, sql: &str) -> Result<(Schema, RowProgressIterator)> {
@@ -92,7 +92,7 @@ impl Connection for FlightSQLConnection {
9292
.ok_or(Error::Protocol("Ticket is empty".to_string()))?;
9393
let flight_data = client.do_get(ticket.clone()).await?;
9494
let (schema, rows) = FlightSQLRows::try_from_flight_data(flight_data).await?;
95-
Ok((schema, Box::pin(rows)))
95+
Ok((schema, RowProgressIterator::new(Box::pin(rows))))
9696
}
9797

9898
async fn stream_load(

driver/src/rest_api.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ impl Connection for RestAPIConnection {
6262
Ok(_) => None,
6363
Err(err) => Some(Err(err)),
6464
});
65-
Ok(Box::pin(rows))
65+
Ok(RowIterator::new(Box::pin(rows)))
6666
}
6767

6868
async fn query_iter_ext(&self, sql: &str) -> Result<(Schema, RowProgressIterator)> {
6969
let resp = self.client.query(sql).await?;
7070
let (schema, rows) = RestAPIRows::from_response(self.client.clone(), resp)?;
71-
Ok((schema, Box::pin(rows)))
71+
Ok((schema, RowProgressIterator::new(Box::pin(rows))))
7272
}
7373

7474
async fn query_row(&self, sql: &str) -> Result<Option<Row>> {

driver/tests/driver/select_iter.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,19 +167,17 @@ async fn select_iter_struct() {
167167
];
168168

169169
let sql_select = format!("SELECT * FROM `{}`", table);
170-
let mut rows = conn.query_iter(&sql_select).await.unwrap();
171-
let mut row_count = 0;
172-
while let Some(row) = rows.next().await {
173-
let v: RowResult = row.unwrap().try_into().unwrap();
174-
let expected_row = &expected[row_count];
170+
let rows = conn.query_iter(&sql_select).await.unwrap();
171+
let results = rows.try_collect::<RowResult>().await.unwrap();
172+
for (idx, v) in results.iter().enumerate() {
173+
let expected_row = &expected[idx];
175174
assert_eq!(v.i64, expected_row.i64);
176175
assert_eq!(v.u64, expected_row.u64);
177176
assert_eq!(v.f64, expected_row.f64);
178177
assert_eq!(v.s, expected_row.s);
179178
assert_eq!(v.s2, expected_row.s2);
180179
assert_eq!(v.d, expected_row.d);
181180
assert_eq!(v.t, expected_row.t);
182-
row_count += 1;
183181
}
184182
}
185183

macros/src/from_row.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,18 @@ pub fn from_row_derive(tokens_input: TokenStream) -> TokenStream {
3838
// it is safe to unwrap
3939
let t = col_value.get_type();
4040
<#field_type>::try_from(col_value)
41-
.map_err(|_| #path::Error::InvalidResponse(format!("failed converting column {} from type({:?}) to type({})", col_ix, t, std::any::type_name::<#field_type>())))?
41+
.map_err(|_| format!("failed converting column {} from type({:?}) to type({})", col_ix, t, std::any::type_name::<#field_type>()))?
4242
},
4343
}
4444
});
4545

4646
let fields_count = struct_fields.named.len();
4747
let generated = quote! {
4848
impl #impl_generics TryFrom<#path::Row> for #struct_name #ty_generics #where_clause {
49-
type Error = #path::Error;
50-
fn try_from(row: #path::Row) -> #path::Result<Self> {
49+
type Error = String;
50+
fn try_from(row: #path::Row) -> #path::Result<Self, String> {
5151
if #fields_count != row.len() {
52-
return Err(#path::Error::InvalidResponse(format!("row size mismatch: expected {} columns, got {}", #fields_count, row.len())));
52+
return Err(format!("row size mismatch: expected {} columns, got {}", #fields_count, row.len()));
5353
}
5454
let mut vals_iter = row.into_iter().enumerate();
5555
Ok(#struct_name {

sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arro
1616
[dependencies]
1717
databend-client = { workspace = true }
1818

19+
async-trait = "0.1.68"
1920
chrono = { version = "0.4.26", default-features = false }
2021
serde = { version = "1.0.164", default-features = false, features = ["derive"] }
2122
serde_json = { version = "1.0.97", default-features = false, features = ["std"] }

sql/src/from_row.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use crate::error::{Error, Result};
15+
use crate::error::Result;
1616
use crate::rows::Row;
1717
use crate::value::Value;
1818

@@ -29,15 +29,15 @@ macro_rules! impl_tuple_from_row {
2929
where
3030
$($Ti: TryFrom<Value>),+
3131
{
32-
type Error = Error;
33-
fn try_from(row: Row) -> Result<Self> {
32+
type Error = String;
33+
fn try_from(row: Row) -> Result<Self, String> {
3434
// It is not possible yet to get the number of metavariable repetitions
3535
// ref: https://github.com/rust-lang/lang-team/issues/28#issue-644523674
3636
// This is a workaround
3737
let expected_len = <[()]>::len(&[$(replace_expr!(($Ti) ())),*]);
3838

3939
if expected_len != row.len() {
40-
return Err(Error::InvalidResponse(format!("row size mismatch: expected {} columns, got {}", expected_len, row.len())));
40+
return Err(format!("row size mismatch: expected {} columns, got {}", expected_len, row.len()));
4141
}
4242
let mut vals_iter = row.into_iter().enumerate();
4343

@@ -50,7 +50,7 @@ macro_rules! impl_tuple_from_row {
5050
// so it is safe to unwrap
5151
let t = col_value.get_type();
5252
$Ti::try_from(col_value)
53-
.map_err(|_| Error::InvalidResponse(format!("failed converting column {} from type({:?}) to type({})", col_ix, t, std::any::type_name::<$Ti>())))?
53+
.map_err(|_| format!("failed converting column {} from type({:?}) to type({})", col_ix, t, std::any::type_name::<$Ti>()))?
5454
}
5555
,)+
5656
))

sql/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ pub mod value;
2121
#[doc(hidden)]
2222
pub mod _macro_internal {
2323
pub use crate::error::{Error, Result};
24-
pub use crate::rows::Row;
24+
pub use crate::rows::{Row, RowIterator};
2525
pub use crate::schema::Schema;
2626
pub use crate::value::Value;
2727
}

sql/src/rows.rs

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
// limitations under the License.
1414

1515
use std::pin::Pin;
16+
use std::task::Context;
17+
use std::task::Poll;
1618

1719
use serde::Deserialize;
18-
use tokio_stream::Stream;
20+
use tokio_stream::{Stream, StreamExt};
1921

2022
#[cfg(feature = "flight-sql")]
2123
use arrow::record_batch::RecordBatch;
@@ -24,9 +26,6 @@ use crate::error::{Error, Result};
2426
use crate::schema::SchemaRef;
2527
use crate::value::Value;
2628

27-
pub type RowIterator = Pin<Box<dyn Stream<Item = Result<Row>> + Send>>;
28-
pub type RowProgressIterator = Pin<Box<dyn Stream<Item = Result<RowWithProgress>> + Send>>;
29-
3029
#[derive(Clone, Debug)]
3130
pub enum RowWithProgress {
3231
Row(Row),
@@ -149,3 +148,57 @@ impl IntoIterator for Rows {
149148
self.0.into_iter()
150149
}
151150
}
151+
152+
pub struct RowIterator(Pin<Box<dyn Stream<Item = Result<Row>> + Send>>);
153+
154+
impl RowIterator {
155+
pub fn new(it: Pin<Box<dyn Stream<Item = Result<Row>> + Send>>) -> Self {
156+
Self(it)
157+
}
158+
159+
pub async fn try_collect<T>(mut self) -> Result<Vec<T>>
160+
where
161+
T: TryFrom<Row>,
162+
T::Error: std::fmt::Display,
163+
{
164+
let mut ret = Vec::new();
165+
while let Some(row) = self.0.next().await {
166+
let v = T::try_from(row?).map_err(|e| Error::Parsing(e.to_string()))?;
167+
ret.push(v)
168+
}
169+
Ok(ret)
170+
}
171+
}
172+
173+
impl Stream for RowIterator {
174+
type Item = Result<Row>;
175+
176+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
177+
Pin::new(&mut self.0).poll_next(cx)
178+
}
179+
}
180+
181+
pub struct RowProgressIterator(Pin<Box<dyn Stream<Item = Result<RowWithProgress>> + Send>>);
182+
183+
impl RowProgressIterator {
184+
pub fn new(it: Pin<Box<dyn Stream<Item = Result<RowWithProgress>> + Send>>) -> Self {
185+
Self(it)
186+
}
187+
188+
pub async fn filter_rows(self) -> RowIterator {
189+
let rows = self.0.filter_map(|r| match r {
190+
Ok(RowWithProgress::Row(r)) => Some(Ok(r)),
191+
Ok(_) => None,
192+
Err(err) => Some(Err(err)),
193+
});
194+
RowIterator(Box::pin(rows))
195+
}
196+
}
197+
198+
impl Stream for RowProgressIterator {
199+
type Item = Result<RowWithProgress>;
200+
201+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
202+
Pin::new(&mut self.0).poll_next(cx)
203+
}
204+
}

0 commit comments

Comments
 (0)