Skip to content

Commit 1d9c92c

Browse files
committed
Add test_arrow_mysql_roundtrip
1 parent 38d1328 commit 1d9c92c

File tree

2 files changed

+105
-2
lines changed

2 files changed

+105
-2
lines changed

tests/mysql/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
const MYSQL_ROOT_PASSWORD: &str = "integration-test-pw";
1313
const MYSQL_DOCKER_CONTAINER: &str = "runtime-integration-test-mysql";
1414

15-
fn get_mysql_params(port: usize) -> HashMap<String, SecretString> {
15+
pub(super) fn get_mysql_params(port: usize) -> HashMap<String, SecretString> {
1616
let mut params = HashMap::new();
1717
params.insert(
1818
"mysql_host".to_string(),

tests/mysql/mod.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::arrow_record_batch_gen::*;
12
use datafusion::execution::context::SessionContext;
23
use datafusion_table_providers::{
34
mysql::DynMySQLConnectionPool, sql::sql_provider_datafusion::SqlTable,
@@ -9,8 +10,17 @@ use arrow::{
910
array::*,
1011
datatypes::{i256, DataType, Field, Schema, TimeUnit, UInt16Type},
1112
};
12-
13+
use arrow_schema::SchemaRef;
14+
use datafusion::catalog::TableProviderFactory;
15+
use datafusion::common::{Constraints, ToDFSchema};
16+
use datafusion_expr::CreateExternalTable;
17+
use datafusion_federation::schema_cast::record_convert::try_cast_to;
18+
use datafusion_physical_plan::collect;
19+
use datafusion_physical_plan::memory::MemoryExec;
20+
use datafusion_table_providers::mysql::MySQLTableProviderFactory;
1321
use datafusion_table_providers::sql::db_connection_pool::dbconnection::AsyncDbConnection;
22+
use secrecy::ExposeSecret;
23+
use tokio::sync::Mutex;
1424

1525
use crate::docker::RunningContainer;
1626

@@ -578,6 +588,73 @@ async fn arrow_mysql_one_way(
578588
record_batch
579589
}
580590

591+
async fn arrow_mysql_round_trip(
592+
port: usize,
593+
arrow_record: RecordBatch,
594+
source_schema: SchemaRef,
595+
table_name: &str,
596+
) {
597+
let factory = MySQLTableProviderFactory::new();
598+
let ctx = SessionContext::new();
599+
let cmd = CreateExternalTable {
600+
schema: Arc::new(arrow_record.schema().to_dfschema().expect("to df schema")),
601+
name: table_name.into(),
602+
location: "".to_string(),
603+
file_type: "".to_string(),
604+
table_partition_cols: vec![],
605+
if_not_exists: false,
606+
definition: None,
607+
order_exprs: vec![],
608+
unbounded: false,
609+
options: common::get_mysql_params(port)
610+
.into_iter()
611+
.map(|(k, v)| (k, v.expose_secret().to_string()))
612+
.collect(),
613+
constraints: Constraints::empty(),
614+
column_defaults: Default::default(),
615+
};
616+
let table_provider = factory
617+
.create(&ctx.state(), &cmd)
618+
.await
619+
.expect("table provider created");
620+
621+
let ctx = SessionContext::new();
622+
let mem_exec = MemoryExec::try_new(&[vec![arrow_record.clone()]], arrow_record.schema(), None)
623+
.expect("memory exec created");
624+
let insert_plan = table_provider
625+
.insert_into(&ctx.state(), Arc::new(mem_exec), true)
626+
.await
627+
.expect("insert plan created");
628+
629+
let _ = collect(insert_plan, ctx.task_ctx())
630+
.await
631+
.expect("insert done");
632+
ctx.register_table(table_name, table_provider)
633+
.expect("Table should be registered");
634+
let sql = format!("SELECT * FROM {table_name}");
635+
let df = ctx
636+
.sql(&sql)
637+
.await
638+
.expect("DataFrame should be created from query");
639+
640+
let record_batch = df.collect().await.expect("RecordBatch should be collected");
641+
642+
tracing::debug!("Original Arrow Record Batch: {:?}", arrow_record.columns());
643+
tracing::debug!(
644+
"MySQL returned Record Batch: {:?}",
645+
record_batch[0].columns()
646+
);
647+
648+
let casted_result =
649+
try_cast_to(record_batch[0].clone(), source_schema).expect("Failed to cast record batch");
650+
651+
// Check results
652+
assert_eq!(record_batch.len(), 1);
653+
assert_eq!(record_batch[0].num_rows(), arrow_record.num_rows());
654+
assert_eq!(record_batch[0].num_columns(), arrow_record.num_columns());
655+
assert_eq!(arrow_record, casted_result);
656+
}
657+
581658
async fn start_mysql_container(port: usize) -> RunningContainer {
582659
let running_container = common::start_mysql_docker_container(port)
583660
.await
@@ -588,6 +665,32 @@ async fn start_mysql_container(port: usize) -> RunningContainer {
588665
running_container
589666
}
590667

668+
#[rstest]
669+
#[case::binary(get_arrow_binary_record_batch(), "binary")]
670+
#[case::int(get_arrow_int_record_batch(), "int")]
671+
#[case::float(get_arrow_float_record_batch(), "float")]
672+
#[case::utf8(get_arrow_utf8_record_batch(), "utf8")]
673+
#[case::time(get_arrow_time_record_batch(), "time")]
674+
#[case::timestamp(get_arrow_timestamp_record_batch(), "timestamp")]
675+
#[case::date(get_arrow_date_record_batch(), "date")]
676+
#[case::struct_type(get_arrow_struct_record_batch(), "struct")]
677+
#[case::decimal(get_arrow_decimal_record_batch(), "decimal")]
678+
#[case::interval(get_arrow_interval_record_batch(), "interval")]
679+
#[case::duration(get_arrow_duration_record_batch(), "duration")]
680+
#[case::list(get_arrow_list_record_batch(), "list")]
681+
#[case::null(get_arrow_null_record_batch(), "null")]
682+
#[case::bytea_array(get_arrow_bytea_array_record_batch(), "bytea_array")]
683+
#[test_log::test(tokio::test)]
684+
async fn test_arrow_mysql_roundtrip(
685+
#[case] arrow_result: (RecordBatch, SchemaRef),
686+
#[case] table_name: &str,
687+
) {
688+
let port = crate::get_random_port();
689+
let mysql_container = start_mysql_container(port).await;
690+
691+
arrow_mysql_round_trip(port, arrow_result.0, arrow_result.1, table_name).await;
692+
}
693+
591694
#[rstest]
592695
#[test_log::test(tokio::test)]
593696
async fn test_mysql_arrow_oneway() {

0 commit comments

Comments
 (0)