Skip to content

Commit 45cc979

Browse files
committed
feat: add transaction related method
1 parent 4c3041b commit 45cc979

File tree

5 files changed

+94
-4
lines changed

5 files changed

+94
-4
lines changed

driver/src/conn.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ pub trait Connection: DynClone + Send + Sync {
190190
))
191191
}
192192

193+
async fn begin(&self) -> Result<()>;
194+
async fn commit(&self) -> Result<()>;
195+
async fn rollback(&self) -> Result<()>;
196+
193197
async fn get_files(&self, stage: &str, local_file: &str) -> Result<RowStatsIterator> {
194198
let mut total_count: usize = 0;
195199
let mut total_size: usize = 0;

driver/src/flight_sql.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,21 @@ impl Connection for FlightSQLConnection {
143143
"STREAM LOAD unavailable for FlightSQL".to_string(),
144144
))
145145
}
146+
147+
async fn begin(&self) -> Result<()> {
148+
self.exec("BEGIN").await.unwrap();
149+
Ok(())
150+
}
151+
152+
async fn commit(&self) -> Result<()> {
153+
self.exec("COMMIT").await.unwrap();
154+
Ok(())
155+
}
156+
157+
async fn rollback(&self) -> Result<()> {
158+
self.exec("ROLLBACK").await.unwrap();
159+
Ok(())
160+
}
146161
}
147162

148163
impl FlightSQLConnection {
@@ -273,7 +288,7 @@ impl Args {
273288
return Err(Error::BadArgument(format!(
274289
"Invalid value for sslmode: {}",
275290
v.as_ref()
276-
)))
291+
)));
277292
}
278293
},
279294
"tls_ca_file" => args.tls_ca_file = Some(v.to_string()),

driver/src/rest_api.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,21 @@ impl Connection for RestAPIConnection {
187187
let stats = self.load_data(sql, reader, size, None, None).await?;
188188
Ok(stats)
189189
}
190+
191+
async fn begin(&self) -> Result<()> {
192+
self.exec("BEGIN").await.unwrap();
193+
Ok(())
194+
}
195+
196+
async fn commit(&self) -> Result<()> {
197+
self.exec("COMMIT").await.unwrap();
198+
Ok(())
199+
}
200+
201+
async fn rollback(&self) -> Result<()> {
202+
self.exec("ROLLBACK").await.unwrap();
203+
Ok(())
204+
}
190205
}
191206

192207
impl<'o> RestAPIConnection {
@@ -219,16 +234,16 @@ impl<'o> RestAPIConnection {
219234
("record_delimiter", "\n"),
220235
("skip_header", "0"),
221236
]
222-
.into_iter()
223-
.collect()
237+
.into_iter()
238+
.collect()
224239
}
225240

226241
fn default_copy_options() -> BTreeMap<&'o str, &'o str> {
227242
vec![("purge", "true")].into_iter().collect()
228243
}
229244
}
230245

231-
type PageFut = Pin<Box<dyn Future<Output = Result<QueryResponse>> + Send>>;
246+
type PageFut = Pin<Box<dyn Future<Output=Result<QueryResponse>> + Send>>;
232247

233248
pub struct RestAPIRows {
234249
client: APIClient,

driver/tests/driver/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@ mod load;
1919
mod select_iter;
2020
mod select_simple;
2121
mod session;
22+
mod transaction;

driver/tests/driver/transaction.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2021 Datafuse Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use databend_driver::Client;
16+
17+
use crate::common::DEFAULT_DSN;
18+
19+
20+
#[tokio::test]
21+
async fn test_commit() {
22+
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
23+
let client = Client::new(dsn.to_string());
24+
let conn = client.get_conn().await.unwrap();
25+
26+
conn.exec("CREATE OR REPLACE TABLE t(c int);").await.unwrap();
27+
conn.begin().await.unwrap();
28+
conn.exec("INSERT INTO t VALUES(1);").await.unwrap();
29+
let row = conn.query_row("SELECT * FROM t").await.unwrap();
30+
let row = row.unwrap();
31+
let (val, ): (i32, ) = row.try_into().unwrap();
32+
assert_eq!(val, 1);
33+
conn.commit().await.unwrap();
34+
}
35+
36+
#[tokio::test]
37+
async fn test_rollback() {
38+
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
39+
let client = Client::new(dsn.to_string());
40+
let conn = client.get_conn().await.unwrap();
41+
42+
conn.exec("CREATE OR REPLACE TABLE t(c int);").await.unwrap();
43+
conn.begin().await.unwrap();
44+
conn.exec("INSERT INTO t VALUES(1);").await.unwrap();
45+
let row = conn.query_row("SELECT * FROM t").await.unwrap();
46+
let row = row.unwrap();
47+
let (val, ): (i32, ) = row.try_into().unwrap();
48+
assert_eq!(val, 1);
49+
conn.rollback().await.unwrap();
50+
let row = conn.query_row("SELECT * FROM t").await.unwrap();
51+
let row = row.unwrap();
52+
let (val, ): (Option<i32>, ) = row.try_into().unwrap();
53+
assert_eq!(val, None)
54+
}
55+

0 commit comments

Comments
 (0)