1+ use crate :: arrow_record_batch_gen:: * ;
12use datafusion:: execution:: context:: SessionContext ;
23use datafusion_table_providers:: {
34 mysql:: DynMySQLConnectionPool , sql:: sql_provider_datafusion:: SqlTable ,
@@ -9,8 +10,17 @@ use arrow::{
910 array:: * ,
1011 datatypes:: { i256, DataType , Field , Schema , TimeUnit , UInt16Type } ,
1112} ;
12-
13+ use arrow_schema:: SchemaRef ;
14+ use datafusion:: catalog:: TableProviderFactory ;
15+ use datafusion:: common:: { Constraints , ToDFSchema } ;
16+ use datafusion_expr:: CreateExternalTable ;
17+ use datafusion_federation:: schema_cast:: record_convert:: try_cast_to;
18+ use datafusion_physical_plan:: collect;
19+ use datafusion_physical_plan:: memory:: MemoryExec ;
20+ use datafusion_table_providers:: mysql:: MySQLTableProviderFactory ;
1321use datafusion_table_providers:: sql:: db_connection_pool:: dbconnection:: AsyncDbConnection ;
22+ use secrecy:: ExposeSecret ;
23+ use tokio:: sync:: Mutex ;
1424
1525use crate :: docker:: RunningContainer ;
1626
@@ -578,6 +588,73 @@ async fn arrow_mysql_one_way(
578588 record_batch
579589}
580590
591+ async fn arrow_mysql_round_trip (
592+ port : usize ,
593+ arrow_record : RecordBatch ,
594+ source_schema : SchemaRef ,
595+ table_name : & str ,
596+ ) {
597+ let factory = MySQLTableProviderFactory :: new ( ) ;
598+ let ctx = SessionContext :: new ( ) ;
599+ let cmd = CreateExternalTable {
600+ schema : Arc :: new ( arrow_record. schema ( ) . to_dfschema ( ) . expect ( "to df schema" ) ) ,
601+ name : table_name. into ( ) ,
602+ location : "" . to_string ( ) ,
603+ file_type : "" . to_string ( ) ,
604+ table_partition_cols : vec ! [ ] ,
605+ if_not_exists : false ,
606+ definition : None ,
607+ order_exprs : vec ! [ ] ,
608+ unbounded : false ,
609+ options : common:: get_mysql_params ( port)
610+ . into_iter ( )
611+ . map ( |( k, v) | ( k, v. expose_secret ( ) . to_string ( ) ) )
612+ . collect ( ) ,
613+ constraints : Constraints :: empty ( ) ,
614+ column_defaults : Default :: default ( ) ,
615+ } ;
616+ let table_provider = factory
617+ . create ( & ctx. state ( ) , & cmd)
618+ . await
619+ . expect ( "table provider created" ) ;
620+
621+ let ctx = SessionContext :: new ( ) ;
622+ let mem_exec = MemoryExec :: try_new ( & [ vec ! [ arrow_record. clone( ) ] ] , arrow_record. schema ( ) , None )
623+ . expect ( "memory exec created" ) ;
624+ let insert_plan = table_provider
625+ . insert_into ( & ctx. state ( ) , Arc :: new ( mem_exec) , true )
626+ . await
627+ . expect ( "insert plan created" ) ;
628+
629+ let _ = collect ( insert_plan, ctx. task_ctx ( ) )
630+ . await
631+ . expect ( "insert done" ) ;
632+ ctx. register_table ( table_name, table_provider)
633+ . expect ( "Table should be registered" ) ;
634+ let sql = format ! ( "SELECT * FROM {table_name}" ) ;
635+ let df = ctx
636+ . sql ( & sql)
637+ . await
638+ . expect ( "DataFrame should be created from query" ) ;
639+
640+ let record_batch = df. collect ( ) . await . expect ( "RecordBatch should be collected" ) ;
641+
642+ tracing:: debug!( "Original Arrow Record Batch: {:?}" , arrow_record. columns( ) ) ;
643+ tracing:: debug!(
644+ "MySQL returned Record Batch: {:?}" ,
645+ record_batch[ 0 ] . columns( )
646+ ) ;
647+
648+ let casted_result =
649+ try_cast_to ( record_batch[ 0 ] . clone ( ) , source_schema) . expect ( "Failed to cast record batch" ) ;
650+
651+ // Check results
652+ assert_eq ! ( record_batch. len( ) , 1 ) ;
653+ assert_eq ! ( record_batch[ 0 ] . num_rows( ) , arrow_record. num_rows( ) ) ;
654+ assert_eq ! ( record_batch[ 0 ] . num_columns( ) , arrow_record. num_columns( ) ) ;
655+ assert_eq ! ( arrow_record, casted_result) ;
656+ }
657+
581658async fn start_mysql_container ( port : usize ) -> RunningContainer {
582659 let running_container = common:: start_mysql_docker_container ( port)
583660 . await
@@ -588,6 +665,32 @@ async fn start_mysql_container(port: usize) -> RunningContainer {
588665 running_container
589666}
590667
668+ #[ rstest]
669+ #[ case:: binary( get_arrow_binary_record_batch( ) , "binary" ) ]
670+ #[ case:: int( get_arrow_int_record_batch( ) , "int" ) ]
671+ #[ case:: float( get_arrow_float_record_batch( ) , "float" ) ]
672+ #[ case:: utf8( get_arrow_utf8_record_batch( ) , "utf8" ) ]
673+ #[ case:: time( get_arrow_time_record_batch( ) , "time" ) ]
674+ #[ case:: timestamp( get_arrow_timestamp_record_batch( ) , "timestamp" ) ]
675+ #[ case:: date( get_arrow_date_record_batch( ) , "date" ) ]
676+ #[ case:: struct_type( get_arrow_struct_record_batch( ) , "struct" ) ]
677+ #[ case:: decimal( get_arrow_decimal_record_batch( ) , "decimal" ) ]
678+ #[ case:: interval( get_arrow_interval_record_batch( ) , "interval" ) ]
679+ #[ case:: duration( get_arrow_duration_record_batch( ) , "duration" ) ]
680+ #[ case:: list( get_arrow_list_record_batch( ) , "list" ) ]
681+ #[ case:: null( get_arrow_null_record_batch( ) , "null" ) ]
682+ #[ case:: bytea_array( get_arrow_bytea_array_record_batch( ) , "bytea_array" ) ]
683+ #[ test_log:: test( tokio:: test) ]
684+ async fn test_arrow_mysql_roundtrip (
685+ #[ case] arrow_result : ( RecordBatch , SchemaRef ) ,
686+ #[ case] table_name : & str ,
687+ ) {
688+ let port = crate :: get_random_port ( ) ;
689+ let mysql_container = start_mysql_container ( port) . await ;
690+
691+ arrow_mysql_round_trip ( port, arrow_result. 0 , arrow_result. 1 , table_name) . await ;
692+ }
693+
591694#[ rstest]
592695#[ test_log:: test( tokio:: test) ]
593696async fn test_mysql_arrow_oneway ( ) {
0 commit comments