Skip to content

Commit d112bda

Browse files
committed
Add test_arrow_mysql_roundtrip
1 parent 38d1328 commit d112bda

File tree

2 files changed

+106
-2
lines changed

2 files changed

+106
-2
lines changed

tests/mysql/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
const MYSQL_ROOT_PASSWORD: &str = "integration-test-pw";
1313
const MYSQL_DOCKER_CONTAINER: &str = "runtime-integration-test-mysql";
1414

15-
fn get_mysql_params(port: usize) -> HashMap<String, SecretString> {
15+
pub(super) fn get_mysql_params(port: usize) -> HashMap<String, SecretString> {
1616
let mut params = HashMap::new();
1717
params.insert(
1818
"mysql_host".to_string(),

tests/mysql/mod.rs

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::arrow_record_batch_gen::*;
12
use datafusion::execution::context::SessionContext;
23
use datafusion_table_providers::{
34
mysql::DynMySQLConnectionPool, sql::sql_provider_datafusion::SqlTable,
@@ -9,8 +10,16 @@ use arrow::{
910
array::*,
1011
datatypes::{i256, DataType, Field, Schema, TimeUnit, UInt16Type},
1112
};
12-
13+
use arrow_schema::SchemaRef;
14+
use datafusion::catalog::TableProviderFactory;
15+
use datafusion::common::{Constraints, ToDFSchema};
16+
use datafusion::logical_expr::CreateExternalTable;
17+
use datafusion::physical_plan::collect;
18+
use datafusion::physical_plan::memory::MemoryExec;
19+
use datafusion_federation::schema_cast::record_convert::try_cast_to;
20+
use datafusion_table_providers::mysql::MySQLTableProviderFactory;
1321
use datafusion_table_providers::sql::db_connection_pool::dbconnection::AsyncDbConnection;
22+
use secrecy::ExposeSecret;
1423

1524
use crate::docker::RunningContainer;
1625

@@ -578,6 +587,73 @@ async fn arrow_mysql_one_way(
578587
record_batch
579588
}
580589

590+
async fn arrow_mysql_round_trip(
591+
port: usize,
592+
arrow_record: RecordBatch,
593+
source_schema: SchemaRef,
594+
table_name: &str,
595+
) {
596+
let factory = MySQLTableProviderFactory::new();
597+
let ctx = SessionContext::new();
598+
let cmd = CreateExternalTable {
599+
schema: Arc::new(arrow_record.schema().to_dfschema().expect("to df schema")),
600+
name: table_name.into(),
601+
location: "".to_string(),
602+
file_type: "".to_string(),
603+
table_partition_cols: vec![],
604+
if_not_exists: false,
605+
definition: None,
606+
order_exprs: vec![],
607+
unbounded: false,
608+
options: common::get_mysql_params(port)
609+
.into_iter()
610+
.map(|(k, v)| (k, v.expose_secret().to_string()))
611+
.collect(),
612+
constraints: Constraints::empty(),
613+
column_defaults: Default::default(),
614+
};
615+
let table_provider = factory
616+
.create(&ctx.state(), &cmd)
617+
.await
618+
.expect("table provider created");
619+
620+
let ctx = SessionContext::new();
621+
let mem_exec = MemoryExec::try_new(&[vec![arrow_record.clone()]], arrow_record.schema(), None)
622+
.expect("memory exec created");
623+
let insert_plan = table_provider
624+
.insert_into(&ctx.state(), Arc::new(mem_exec), true)
625+
.await
626+
.expect("insert plan created");
627+
628+
let _ = collect(insert_plan, ctx.task_ctx())
629+
.await
630+
.expect("insert done");
631+
ctx.register_table(table_name, table_provider)
632+
.expect("Table should be registered");
633+
let sql = format!("SELECT * FROM {table_name}");
634+
let df = ctx
635+
.sql(&sql)
636+
.await
637+
.expect("DataFrame should be created from query");
638+
639+
let record_batch = df.collect().await.expect("RecordBatch should be collected");
640+
641+
tracing::debug!("Original Arrow Record Batch: {:?}", arrow_record.columns());
642+
tracing::debug!(
643+
"MySQL returned Record Batch: {:?}",
644+
record_batch[0].columns()
645+
);
646+
647+
let casted_result =
648+
try_cast_to(record_batch[0].clone(), source_schema).expect("Failed to cast record batch");
649+
650+
// Check results
651+
assert_eq!(record_batch.len(), 1);
652+
assert_eq!(record_batch[0].num_rows(), arrow_record.num_rows());
653+
assert_eq!(record_batch[0].num_columns(), arrow_record.num_columns());
654+
assert_eq!(arrow_record, casted_result);
655+
}
656+
581657
async fn start_mysql_container(port: usize) -> RunningContainer {
582658
let running_container = common::start_mysql_docker_container(port)
583659
.await
@@ -588,6 +664,34 @@ async fn start_mysql_container(port: usize) -> RunningContainer {
588664
running_container
589665
}
590666

667+
#[rstest]
668+
#[case::binary(get_arrow_binary_record_batch(), "binary")]
669+
#[case::int(get_arrow_int_record_batch(), "int")]
670+
#[case::float(get_arrow_float_record_batch(), "float")]
671+
#[case::utf8(get_arrow_utf8_record_batch(), "utf8")]
672+
#[case::time(get_arrow_time_record_batch(), "time")]
673+
#[case::timestamp(get_arrow_timestamp_record_batch(), "timestamp")]
674+
#[case::date(get_arrow_date_record_batch(), "date")]
675+
#[case::struct_type(get_arrow_struct_record_batch(), "struct")]
676+
#[case::decimal(get_arrow_decimal_record_batch(), "decimal")]
677+
#[case::interval(get_arrow_interval_record_batch(), "interval")]
678+
#[case::duration(get_arrow_duration_record_batch(), "duration")]
679+
#[case::list(get_arrow_list_record_batch(), "list")]
680+
#[case::null(get_arrow_null_record_batch(), "null")]
681+
#[case::bytea_array(get_arrow_bytea_array_record_batch(), "bytea_array")]
682+
#[test_log::test(tokio::test)]
683+
async fn test_arrow_mysql_roundtrip(
684+
#[case] arrow_result: (RecordBatch, SchemaRef),
685+
#[case] table_name: &str,
686+
) {
687+
let port = crate::get_random_port();
688+
let mysql_container = start_mysql_container(port).await;
689+
690+
arrow_mysql_round_trip(port, arrow_result.0, arrow_result.1, table_name).await;
691+
692+
mysql_container.remove().await.expect("container to stop");
693+
}
694+
591695
#[rstest]
592696
#[test_log::test(tokio::test)]
593697
async fn test_mysql_arrow_oneway() {

0 commit comments

Comments
 (0)