diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index dab536e9..ea790e8a 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -192,6 +192,7 @@ pub struct BlockingDatabendCursor { // buffer is used to store only the first row after execute buffer: Vec, schema: Option, + rowcount: i64, } impl BlockingDatabendCursor { @@ -201,6 +202,7 @@ impl BlockingDatabendCursor { rows: None, buffer: Vec::new(), schema: None, + rowcount: -1, } } } @@ -210,6 +212,7 @@ impl BlockingDatabendCursor { self.rows = None; self.buffer.clear(); self.schema = None; + self.rowcount = -1; } } @@ -247,10 +250,9 @@ impl BlockingDatabendCursor { } } - /// Not supported currently #[getter] pub fn rowcount(&self, _py: Python) -> i64 { - -1 + self.rowcount } pub fn close(&mut self, py: Python) -> PyResult<()> { @@ -277,18 +279,40 @@ impl BlockingDatabendCursor { self.reset(); let conn = self.conn.clone(); - // fetch first row after execute - // then we could finish the query directly if there's no result let params = to_sql_params(params); + + // check if it is DML(INSERT, UPDATE, DELETE) + let sql_trimmed = operation.trim_start().to_lowercase(); + let is_dml = sql_trimmed.starts_with("insert") + || sql_trimmed.starts_with("update") + || sql_trimmed.starts_with("delete") + || sql_trimmed.starts_with("replace"); + + if is_dml { + let affected_rows = wait_for_future(py, async move { + conn.exec(&operation, params) + .await + .map_err(DriverError::new) + })?; + self.rowcount = affected_rows; + return Ok(py.None()); + } + + // for select, use query_iter let (first, rows) = wait_for_future(py, async move { let mut rows = conn.query_iter(&operation, params).await?; let first = rows.next().await.transpose()?; Ok::<_, databend_driver::Error>((first, rows)) }) .map_err(DriverError::new)?; + if let Some(first) = first { self.buffer.push(Row::new(first)); + self.rowcount = 1; + } else { + self.rowcount = 0; } + self.rows = Some(Arc::new(Mutex::new(rows))); self.set_schema(py); Ok(py.None()) @@ -375,9 +399,14 @@ impl BlockingDatabendCursor { for row in fetched { result.push(Row::new(row.map_err(DriverError::new)?)); } + + if self.rowcount == -1 { + self.rowcount = result.len() as i64; + } + Ok(result) } - None => Ok(vec![]), + None => Ok(result), } } diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index b6e7a5c3..87e25e56 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -64,8 +64,7 @@ impl IConnection for RestAPIConnection { async fn exec(&self, sql: &str) -> Result { info!("exec: {}", sql); - let page = self.client.query_all(sql).await?; - Ok(page.stats.progresses.write_progress.rows as i64) + self.calculate_affected_rows_from_iter(sql).await } async fn kill_query(&self, query_id: &str) -> Result<()> { @@ -197,6 +196,65 @@ impl<'o> RestAPIConnection { fn default_copy_options() -> BTreeMap<&'o str, &'o str> { vec![("purge", "true")].into_iter().collect() } + fn parse_row_count_string(value_str: &str) -> Result { + let trimmed = value_str.trim(); + + if trimmed.is_empty() { + return Ok(0); + } + + if let Ok(count) = trimmed.parse::() { + return Ok(count); + } + + if let Ok(count) = serde_json::from_str::(trimmed) { + return Ok(count); + } + + let unquoted = trimmed.trim_matches('"'); + if let Ok(count) = unquoted.parse::() { + return Ok(count); + } + + Err(format!( + "failed to parse affected rows from: '{}'", + value_str + )) + } + + async fn calculate_affected_rows_from_iter(&self, sql: &str) -> Result { + let mut rows = IConnection::query_iter(self, sql).await?; + let mut count = 0i64; + + use tokio_stream::StreamExt; + // Get the first row to check if it has affected rows info + if let Some(first_row) = rows.next().await { + let row = first_row?; + let schema = row.schema(); + + // Check if this is an affected rows response + if !schema.fields().is_empty() && schema.fields()[0].name.contains("number of rows") { + let values = row.values(); + if !values.is_empty() { + let value = &values[0]; + let s: String = value.clone().try_into().map_err(|e| { + Error::InvalidResponse(format!("Failed to convert value to string: {}", e)) + })?; + count = Self::parse_row_count_string(&s).map_err(Error::InvalidResponse)?; + } + } else { + // If it's not affected rows info, count normally + count = 1; + // Continue counting the rest + while let Some(row_result) = rows.next().await { + row_result?; + count += 1; + } + } + } + + Ok(count) + } } pub struct RestAPIRows {