Skip to content

Commit 3d4494b

Browse files
committed
z
1 parent 81e1b14 commit 3d4494b

File tree

1 file changed

+63
-52
lines changed

1 file changed

+63
-52
lines changed

driver/src/rest_api.rs

Lines changed: 63 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ impl IConnection for RestAPIConnection {
6464

6565
async fn exec(&self, sql: &str) -> Result<i64> {
6666
info!("exec: {}", sql);
67-
let page = self.client.query_all(sql).await?;
68-
let affected_rows = parse_affected_rows_from_page(&page)?;
69-
Ok(affected_rows)
67+
self.calculate_affected_rows_from_iter(sql).await
7068
}
7169

7270
async fn kill_query(&self, query_id: &str) -> Result<()> {
@@ -178,13 +176,13 @@ impl IConnection for RestAPIConnection {
178176
}
179177
}
180178

181-
impl<'o> RestAPIConnection {
179+
impl RestAPIConnection {
182180
pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
183181
let client = APIClient::new(dsn, Some(name)).await?;
184182
Ok(Self { client })
185183
}
186184

187-
fn default_file_format_options() -> BTreeMap<&'o str, &'o str> {
185+
fn default_file_format_options() -> BTreeMap<&'static str, &'static str> {
188186
vec![
189187
("type", "CSV"),
190188
("field_delimiter", ","),
@@ -195,9 +193,68 @@ impl<'o> RestAPIConnection {
195193
.collect()
196194
}
197195

198-
fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
196+
fn default_copy_options() -> BTreeMap<&'static str, &'static str> {
199197
vec![("purge", "true")].into_iter().collect()
200198
}
199+
fn parse_row_count_string(value_str: &str) -> Result<i64, String> {
200+
let trimmed = value_str.trim();
201+
202+
if trimmed.is_empty() {
203+
return Ok(0);
204+
}
205+
206+
if let Ok(count) = trimmed.parse::<i64>() {
207+
return Ok(count);
208+
}
209+
210+
if let Ok(count) = serde_json::from_str::<i64>(trimmed) {
211+
return Ok(count);
212+
}
213+
214+
let unquoted = trimmed.trim_matches('"');
215+
if let Ok(count) = unquoted.parse::<i64>() {
216+
return Ok(count);
217+
}
218+
219+
Err(format!(
220+
"failed to parse affected rows from: '{}'",
221+
value_str
222+
))
223+
}
224+
225+
async fn calculate_affected_rows_from_iter(&self, sql: &str) -> Result<i64> {
226+
let mut rows = IConnection::query_iter(self, sql).await?;
227+
let mut count = 0i64;
228+
229+
use tokio_stream::StreamExt;
230+
// Get the first row to check if it has affected rows info
231+
if let Some(first_row) = rows.next().await {
232+
let row = first_row?;
233+
let schema = row.schema();
234+
235+
// Check if this is an affected rows response
236+
if !schema.fields().is_empty() && schema.fields()[0].name.contains("number of rows") {
237+
let values = row.values();
238+
if !values.is_empty() {
239+
let value = &values[0];
240+
let s: String = value.clone().try_into().map_err(|e| {
241+
Error::InvalidResponse(format!("Failed to convert value to string: {}", e))
242+
})?;
243+
count = Self::parse_row_count_string(&s).map_err(Error::InvalidResponse)?;
244+
}
245+
} else {
246+
// If it's not affected rows info, count normally
247+
count = 1;
248+
// Continue counting the rest
249+
while let Some(row_result) = rows.next().await {
250+
row_result?;
251+
count += 1;
252+
}
253+
}
254+
}
255+
256+
Ok(count)
257+
}
201258
}
202259

203260
pub struct RestAPIRows<T> {
@@ -288,49 +345,3 @@ impl FromRowStats for RawRowWithStats {
288345
Ok(RawRowWithStats::Row(RawRow::new(rows, row)))
289346
}
290347
}
291-
292-
fn parse_affected_rows_from_page(page: &databend_client::Page) -> Result<i64> {
293-
if page.schema.is_empty() {
294-
return Ok(0);
295-
}
296-
297-
let first_field = &page.schema[0];
298-
if !first_field.name.contains("number of rows") {
299-
return Ok(0);
300-
}
301-
302-
if page.data.is_empty() || page.data[0].is_empty() {
303-
return Ok(0);
304-
}
305-
306-
match &page.data[0][0] {
307-
Some(value_str) => parse_row_count_string(value_str).map_err(Error::InvalidResponse),
308-
None => Ok(0),
309-
}
310-
}
311-
312-
fn parse_row_count_string(value_str: &str) -> Result<i64, String> {
313-
let trimmed = value_str.trim();
314-
315-
if trimmed.is_empty() {
316-
return Ok(0);
317-
}
318-
319-
if let Ok(count) = trimmed.parse::<i64>() {
320-
return Ok(count);
321-
}
322-
323-
if let Ok(count) = serde_json::from_str::<i64>(trimmed) {
324-
return Ok(count);
325-
}
326-
327-
let unquoted = trimmed.trim_matches('"');
328-
if let Ok(count) = unquoted.parse::<i64>() {
329-
return Ok(count);
330-
}
331-
332-
Err(format!(
333-
"failed to parse affected rows from: '{}'",
334-
value_str
335-
))
336-
}

0 commit comments

Comments
 (0)