1+ use crate :: arrow_record_batch_gen:: * ;
12use datafusion:: execution:: context:: SessionContext ;
23use datafusion_table_providers:: {
34 mysql:: DynMySQLConnectionPool , sql:: sql_provider_datafusion:: SqlTable ,
@@ -9,8 +10,16 @@ use arrow::{
910 array:: * ,
1011 datatypes:: { i256, DataType , Field , Schema , TimeUnit , UInt16Type } ,
1112} ;
12-
13+ use arrow_schema:: SchemaRef ;
14+ use datafusion:: catalog:: TableProviderFactory ;
15+ use datafusion:: common:: { Constraints , ToDFSchema } ;
16+ use datafusion:: logical_expr:: CreateExternalTable ;
17+ use datafusion:: physical_plan:: collect;
18+ use datafusion:: physical_plan:: memory:: MemoryExec ;
19+ use datafusion_federation:: schema_cast:: record_convert:: try_cast_to;
20+ use datafusion_table_providers:: mysql:: MySQLTableProviderFactory ;
1321use datafusion_table_providers:: sql:: db_connection_pool:: dbconnection:: AsyncDbConnection ;
22+ use secrecy:: ExposeSecret ;
1423
1524use crate :: docker:: RunningContainer ;
1625
@@ -578,6 +587,73 @@ async fn arrow_mysql_one_way(
578587 record_batch
579588}
580589
590+ async fn arrow_mysql_round_trip (
591+ port : usize ,
592+ arrow_record : RecordBatch ,
593+ source_schema : SchemaRef ,
594+ table_name : & str ,
595+ ) {
596+ let factory = MySQLTableProviderFactory :: new ( ) ;
597+ let ctx = SessionContext :: new ( ) ;
598+ let cmd = CreateExternalTable {
599+ schema : Arc :: new ( arrow_record. schema ( ) . to_dfschema ( ) . expect ( "to df schema" ) ) ,
600+ name : table_name. into ( ) ,
601+ location : "" . to_string ( ) ,
602+ file_type : "" . to_string ( ) ,
603+ table_partition_cols : vec ! [ ] ,
604+ if_not_exists : false ,
605+ definition : None ,
606+ order_exprs : vec ! [ ] ,
607+ unbounded : false ,
608+ options : common:: get_mysql_params ( port)
609+ . into_iter ( )
610+ . map ( |( k, v) | ( k, v. expose_secret ( ) . to_string ( ) ) )
611+ . collect ( ) ,
612+ constraints : Constraints :: empty ( ) ,
613+ column_defaults : Default :: default ( ) ,
614+ } ;
615+ let table_provider = factory
616+ . create ( & ctx. state ( ) , & cmd)
617+ . await
618+ . expect ( "table provider created" ) ;
619+
620+ let ctx = SessionContext :: new ( ) ;
621+ let mem_exec = MemoryExec :: try_new ( & [ vec ! [ arrow_record. clone( ) ] ] , arrow_record. schema ( ) , None )
622+ . expect ( "memory exec created" ) ;
623+ let insert_plan = table_provider
624+ . insert_into ( & ctx. state ( ) , Arc :: new ( mem_exec) , true )
625+ . await
626+ . expect ( "insert plan created" ) ;
627+
628+ let _ = collect ( insert_plan, ctx. task_ctx ( ) )
629+ . await
630+ . expect ( "insert done" ) ;
631+ ctx. register_table ( table_name, table_provider)
632+ . expect ( "Table should be registered" ) ;
633+ let sql = format ! ( "SELECT * FROM {table_name}" ) ;
634+ let df = ctx
635+ . sql ( & sql)
636+ . await
637+ . expect ( "DataFrame should be created from query" ) ;
638+
639+ let record_batch = df. collect ( ) . await . expect ( "RecordBatch should be collected" ) ;
640+
641+ tracing:: debug!( "Original Arrow Record Batch: {:?}" , arrow_record. columns( ) ) ;
642+ tracing:: debug!(
643+ "MySQL returned Record Batch: {:?}" ,
644+ record_batch[ 0 ] . columns( )
645+ ) ;
646+
647+ let casted_result =
648+ try_cast_to ( record_batch[ 0 ] . clone ( ) , source_schema) . expect ( "Failed to cast record batch" ) ;
649+
650+ // Check results
651+ assert_eq ! ( record_batch. len( ) , 1 ) ;
652+ assert_eq ! ( record_batch[ 0 ] . num_rows( ) , arrow_record. num_rows( ) ) ;
653+ assert_eq ! ( record_batch[ 0 ] . num_columns( ) , arrow_record. num_columns( ) ) ;
654+ assert_eq ! ( arrow_record, casted_result) ;
655+ }
656+
581657async fn start_mysql_container ( port : usize ) -> RunningContainer {
582658 let running_container = common:: start_mysql_docker_container ( port)
583659 . await
@@ -588,6 +664,34 @@ async fn start_mysql_container(port: usize) -> RunningContainer {
588664 running_container
589665}
590666
667+ #[ rstest]
668+ #[ case:: binary( get_arrow_binary_record_batch( ) , "binary" ) ]
669+ #[ case:: int( get_arrow_int_record_batch( ) , "int" ) ]
670+ #[ case:: float( get_arrow_float_record_batch( ) , "float" ) ]
671+ #[ case:: utf8( get_arrow_utf8_record_batch( ) , "utf8" ) ]
672+ #[ case:: time( get_arrow_time_record_batch( ) , "time" ) ]
673+ #[ case:: timestamp( get_arrow_timestamp_record_batch( ) , "timestamp" ) ]
674+ #[ case:: date( get_arrow_date_record_batch( ) , "date" ) ]
675+ #[ case:: struct_type( get_arrow_struct_record_batch( ) , "struct" ) ]
676+ #[ case:: decimal( get_arrow_decimal_record_batch( ) , "decimal" ) ]
677+ #[ case:: interval( get_arrow_interval_record_batch( ) , "interval" ) ]
678+ #[ case:: duration( get_arrow_duration_record_batch( ) , "duration" ) ]
679+ #[ case:: list( get_arrow_list_record_batch( ) , "list" ) ]
680+ #[ case:: null( get_arrow_null_record_batch( ) , "null" ) ]
681+ #[ case:: bytea_array( get_arrow_bytea_array_record_batch( ) , "bytea_array" ) ]
682+ #[ test_log:: test( tokio:: test) ]
683+ async fn test_arrow_mysql_roundtrip (
684+ #[ case] arrow_result : ( RecordBatch , SchemaRef ) ,
685+ #[ case] table_name : & str ,
686+ ) {
687+ let port = crate :: get_random_port ( ) ;
688+ let mysql_container = start_mysql_container ( port) . await ;
689+
690+ arrow_mysql_round_trip ( port, arrow_result. 0 , arrow_result. 1 , table_name) . await ;
691+
692+ mysql_container. remove ( ) . await . expect ( "container to stop" ) ;
693+ }
694+
591695#[ rstest]
592696#[ test_log:: test( tokio:: test) ]
593697async fn test_mysql_arrow_oneway ( ) {
0 commit comments