Skip to content

Commit e629133

Browse files
committed
impl affect_rows in rest_api.rs and use query_iter
1 parent f88ef88 commit e629133

File tree

3 files changed

+59
-58
lines changed

3 files changed

+59
-58
lines changed

bindings/python/src/blocking.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,11 +294,12 @@ impl BlockingDatabendCursor {
294294
.await
295295
.map_err(DriverError::new)
296296
})?;
297+
297298
self.rowcount = affected_rows;
298299
return Ok(py.None());
299300
}
300301

301-
// for select, use query_iter
302+
// For SELECT, use query_iter as before
302303
let (first, rows) = wait_for_future(py, async move {
303304
let mut rows = conn.query_iter(&operation, params).await?;
304305
let first = rows.next().await.transpose()?;

core/src/pages.rs

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -47,58 +47,6 @@ impl Page {
4747
}
4848
self.stats = p.stats;
4949
}
50-
51-
pub fn affected_rows(&self) -> Result<i64, String> {
52-
if self.schema.is_empty() {
53-
return Ok(0);
54-
}
55-
56-
let first_field = &self.schema[0];
57-
58-
if !first_field.name.contains("number of rows") {
59-
return Ok(0);
60-
}
61-
62-
if self.data.is_empty() || self.data[0].is_empty() {
63-
return Ok(0);
64-
}
65-
66-
match &self.data[0][0] {
67-
Some(value_str) => self.parse_row_count_string(value_str),
68-
None => Ok(0),
69-
}
70-
}
71-
72-
fn parse_row_count_string(&self, value_str: &str) -> Result<i64, String> {
73-
let trimmed = value_str.trim();
74-
75-
if trimmed.is_empty() {
76-
return Ok(0);
77-
}
78-
79-
if let Ok(count) = trimmed.parse::<i64>() {
80-
return Ok(count);
81-
}
82-
83-
if let Ok(count) = serde_json::from_str::<i64>(trimmed) {
84-
return Ok(count);
85-
}
86-
87-
let unquoted = trimmed.trim_matches('"');
88-
if let Ok(count) = unquoted.parse::<i64>() {
89-
return Ok(count);
90-
}
91-
92-
Err(format!(
93-
"failed to parse affected rows from: '{}'",
94-
value_str
95-
))
96-
}
97-
98-
///the schema can be `number of rows inserted`, `number of rows deleted`, `number of rows updated` when sql start with `insert`, `delete`, `update`
99-
pub fn has_affected_rows(&self) -> bool {
100-
!self.schema.is_empty() && self.schema[0].name.contains("number of rows")
101-
}
10250
}
10351

10452
type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;

driver/src/rest_api.rs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use log::info;
2424
use tokio::fs::File;
2525
use tokio::io::BufReader;
2626
use tokio_stream::Stream;
27+
use tokio_stream::StreamExt;
2728

2829
use databend_client::APIClient;
2930
use databend_client::Pages;
@@ -64,11 +65,8 @@ impl IConnection for RestAPIConnection {
6465

6566
async fn exec(&self, sql: &str) -> Result<i64> {
6667
info!("exec: {}", sql);
67-
let page = self.client.query_all(sql).await?;
68-
69-
let affected_rows = page.affected_rows().map_err(Error::InvalidResponse)?;
70-
71-
Ok(affected_rows)
68+
// Use the new affected_rows method that internally uses query_iter
69+
self.calculate_affected_rows_from_iter(sql).await
7270
}
7371

7472
async fn kill_query(&self, query_id: &str) -> Result<()> {
@@ -200,6 +198,60 @@ impl<'o> RestAPIConnection {
200198
fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
201199
vec![("purge", "true")].into_iter().collect()
202200
}
201+
202+
fn parse_row_count_string(value_str: &str) -> Result<i64, String> {
203+
let trimmed = value_str.trim();
204+
205+
if trimmed.is_empty() {
206+
return Ok(0);
207+
}
208+
209+
if let Ok(count) = trimmed.parse::<i64>() {
210+
return Ok(count);
211+
}
212+
213+
if let Ok(count) = serde_json::from_str::<i64>(trimmed) {
214+
return Ok(count);
215+
}
216+
217+
let unquoted = trimmed.trim_matches('"');
218+
if let Ok(count) = unquoted.parse::<i64>() {
219+
return Ok(count);
220+
}
221+
222+
Err(format!(
223+
"failed to parse affected rows from: '{}'",
224+
value_str
225+
))
226+
}
227+
228+
async fn calculate_affected_rows_from_iter(&self, sql: &str) -> Result<i64> {
229+
let mut rows = IConnection::query_iter(self, sql).await?;
230+
let mut count = 0i64;
231+
232+
// Get the first row to check if it has affected rows info
233+
if let Some(first_row) = rows.next().await {
234+
let row = first_row?;
235+
let schema = row.schema();
236+
237+
// Check if this is an affected rows response
238+
if !schema.fields().is_empty() && schema.fields()[0].name.contains("number of rows") {
239+
let values = row.values();
240+
if !values.is_empty() {
241+
let value = &values[0];
242+
let s: String = value.clone().try_into().map_err(|e| {
243+
Error::InvalidResponse(format!("Failed to convert value to string: {}", e))
244+
})?;
245+
count = Self::parse_row_count_string(&s).map_err(Error::InvalidResponse)?;
246+
}
247+
} else {
248+
// If it's not affected rows info, count normally
249+
count = -1;
250+
}
251+
}
252+
253+
Ok(count)
254+
}
203255
}
204256

205257
pub struct RestAPIRows<T> {

0 commit comments

Comments
 (0)