@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
30
30
import org .apache .spark .sql .execution .ExplainUtils .stripAQEPlan
31
31
import org .apache .spark .sql .execution .datasources .v2 .{AlterTableExec , CreateTableExec , DataSourceV2Relation , ReplaceTableExec }
32
32
import org .apache .spark .sql .internal .SQLConf
33
- import org .apache .spark .sql .types .{BooleanType , CalendarIntervalType , IntegerType , StringType }
33
+ import org .apache .spark .sql .types .{BooleanType , CalendarIntervalType , DoubleType , IntegerType , StringType , TimestampType }
34
34
import org .apache .spark .sql .util .QueryExecutionListener
35
35
import org .apache .spark .unsafe .types .UTF8String
36
36
@@ -498,43 +498,32 @@ class DataSourceV2DataFrameSuite
498
498
| """ .stripMargin)
499
499
500
500
val alterExecCol1 = executeAndKeepPhysicalPlan[AlterTableExec ] {
501
- sql(s " ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT (123 + 56) " )
502
- }
503
- checkDefaultValue(
504
- alterExecCol1.changes.collect {
505
- case u : UpdateColumnDefaultValue => u
506
- }.head,
507
- new DefaultValue (
508
- " (123 + 56)" ,
509
- new GeneralScalarExpression (
510
- " +" ,
511
- Array (LiteralValue (123 , IntegerType ), LiteralValue (56 , IntegerType )))))
512
-
513
- val alterExecCol2 = executeAndKeepPhysicalPlan[AlterTableExec ] {
514
- sql(s " ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT ('r' || 'l') " )
515
- }
516
- checkDefaultValue(
517
- alterExecCol2.changes.collect {
518
- case u : UpdateColumnDefaultValue => u
519
- }.head,
520
- new DefaultValue (
521
- " ('r' || 'l')" ,
522
- new GeneralScalarExpression (
523
- " CONCAT" ,
524
- Array (
525
- LiteralValue (UTF8String .fromString(" r" ), StringType ),
526
- LiteralValue (UTF8String .fromString(" l" ), StringType )))))
527
-
528
- val alterExecCol3 = executeAndKeepPhysicalPlan[AlterTableExec ] {
529
- sql(s " ALTER TABLE $tableName ALTER COLUMN active SET DEFAULT CAST(0 AS BOOLEAN) " )
501
+ sql(
502
+ s """
503
+ |ALTER TABLE $tableName ALTER COLUMN
504
+ | salary SET DEFAULT (123 + 56),
505
+ | dep SET DEFAULT ('r' || 'l'),
506
+ | active SET DEFAULT CAST(0 AS BOOLEAN)
507
+ | """ .stripMargin)
530
508
}
531
- checkDefaultValue(
532
- alterExecCol3.changes.collect {
533
- case u : UpdateColumnDefaultValue => u
534
- }.head,
535
- new DefaultValue (
536
- " CAST(0 AS BOOLEAN)" ,
537
- new V2Cast (LiteralValue (0 , IntegerType ), IntegerType , BooleanType )))
509
+ checkDefaultValues(
510
+ alterExecCol1.changes.map(_.asInstanceOf [UpdateColumnDefaultValue ]).toArray,
511
+ Array (
512
+ new DefaultValue (
513
+ " (123 + 56)" ,
514
+ new GeneralScalarExpression (
515
+ " +" ,
516
+ Array (LiteralValue (123 , IntegerType ), LiteralValue (56 , IntegerType )))),
517
+ new DefaultValue (
518
+ " ('r' || 'l')" ,
519
+ new GeneralScalarExpression (
520
+ " CONCAT" ,
521
+ Array (
522
+ LiteralValue (UTF8String .fromString(" r" ), StringType ),
523
+ LiteralValue (UTF8String .fromString(" l" ), StringType )))),
524
+ new DefaultValue (
525
+ " CAST(0 AS BOOLEAN)" ,
526
+ new V2Cast (LiteralValue (0 , IntegerType ), IntegerType , BooleanType ))))
538
527
}
539
528
}
540
529
}
@@ -666,13 +655,9 @@ class DataSourceV2DataFrameSuite
666
655
sql(s " ALTER TABLE $tableName ALTER COLUMN cat SET DEFAULT current_catalog() " )
667
656
}
668
657
669
- checkDefaultValue(
670
- alterExec.changes.collect {
671
- case u : UpdateColumnDefaultValue => u
672
- }.head,
673
- new DefaultValue (
674
- " current_catalog()" ,
675
- null /* No V2 Expression */ ))
658
+ checkDefaultValues(
659
+ alterExec.changes.map(_.asInstanceOf [UpdateColumnDefaultValue ]).toArray,
660
+ Array (new DefaultValue (" current_catalog()" , null /* No V2 Expression */ )))
676
661
677
662
val df1 = Seq (1 ).toDF(" dummy" )
678
663
df1.writeTo(tableName).append()
@@ -683,6 +668,109 @@ class DataSourceV2DataFrameSuite
683
668
}
684
669
}
685
670
671
+ test(" create/replace table default value expression should have a cast" ) {
672
+ val tableName = " testcat.ns1.ns2.tbl"
673
+ withTable(tableName) {
674
+
675
+ val createExec = executeAndKeepPhysicalPlan[CreateTableExec ] {
676
+ sql(
677
+ s """
678
+ |CREATE TABLE $tableName (
679
+ | col1 int,
680
+ | col2 timestamp DEFAULT '2018-11-17 13:33:33',
681
+ | col3 double DEFAULT 1)
682
+ | """ .stripMargin)
683
+ }
684
+ checkDefaultValues(
685
+ createExec.columns,
686
+ Array (
687
+ null ,
688
+ new ColumnDefaultValue (
689
+ " '2018-11-17 13:33:33'" ,
690
+ new LiteralValue (1542490413000000L , TimestampType ),
691
+ new LiteralValue (1542490413000000L , TimestampType )),
692
+ new ColumnDefaultValue (
693
+ " 1" ,
694
+ new V2Cast (LiteralValue (1 , IntegerType ), IntegerType , DoubleType ),
695
+ LiteralValue (1.0 , DoubleType ))))
696
+
697
+ val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec ] {
698
+ sql(
699
+ s """
700
+ |REPLACE TABLE $tableName (
701
+ | col1 int,
702
+ | col2 timestamp DEFAULT '2022-02-23 05:55:55',
703
+ | col3 double DEFAULT (1 + 1))
704
+ | """ .stripMargin)
705
+ }
706
+ checkDefaultValues(
707
+ replaceExec.columns,
708
+ Array (
709
+ null ,
710
+ new ColumnDefaultValue (
711
+ " '2022-02-23 05:55:55'" ,
712
+ LiteralValue (1645624555000000L , TimestampType ),
713
+ LiteralValue (1645624555000000L , TimestampType )),
714
+ new ColumnDefaultValue (
715
+ " (1 + 1)" ,
716
+ new V2Cast (
717
+ new GeneralScalarExpression (" +" , Array (LiteralValue (1 , IntegerType ),
718
+ LiteralValue (1 , IntegerType ))),
719
+ IntegerType ,
720
+ DoubleType ),
721
+ LiteralValue (2.0 , DoubleType ))))
722
+ }
723
+ }
724
+
725
+ test(" alter table default value expression should have a cast" ) {
726
+ val tableName = " testcat.ns1.ns2.tbl"
727
+ withTable(tableName) {
728
+
729
+ sql(s " CREATE TABLE $tableName (col1 int) using foo " )
730
+ val alterExec = executeAndKeepPhysicalPlan[AlterTableExec ] {
731
+ sql(
732
+ s """
733
+ |ALTER TABLE $tableName ADD COLUMNS (
734
+ | col2 timestamp DEFAULT '2018-11-17 13:33:33',
735
+ | col3 double DEFAULT 1)
736
+ | """ .stripMargin)
737
+ }
738
+
739
+ checkDefaultValues(
740
+ alterExec.changes.map(_.asInstanceOf [AddColumn ]).toArray,
741
+ Array (
742
+ new ColumnDefaultValue (
743
+ " '2018-11-17 13:33:33'" ,
744
+ LiteralValue (1542490413000000L , TimestampType ),
745
+ LiteralValue (1542490413000000L , TimestampType )),
746
+ new ColumnDefaultValue (
747
+ " 1" ,
748
+ new V2Cast (LiteralValue (1 , IntegerType ), IntegerType , DoubleType ),
749
+ LiteralValue (1.0 , DoubleType ))))
750
+
751
+ val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec ] {
752
+ sql(
753
+ s """
754
+ |ALTER TABLE $tableName ALTER COLUMN
755
+ | col2 SET DEFAULT '2022-02-23 05:55:55',
756
+ | col3 SET DEFAULT (1 + 1)
757
+ | """ .stripMargin)
758
+ }
759
+ checkDefaultValues(
760
+ alterCol1.changes.map(_.asInstanceOf [UpdateColumnDefaultValue ]).toArray,
761
+ Array (
762
+ new DefaultValue (" '2022-02-23 05:55:55'" ,
763
+ LiteralValue (1645624555000000L , TimestampType )),
764
+ new DefaultValue (
765
+ " (1 + 1)" ,
766
+ new V2Cast (
767
+ new GeneralScalarExpression (" +" , Array (LiteralValue (1 , IntegerType ),
768
+ LiteralValue (1 , IntegerType ))),
769
+ IntegerType ,
770
+ DoubleType ))))
771
+ }
772
+ }
773
+
686
774
private def executeAndKeepPhysicalPlan [T <: SparkPlan ](func : => Unit ): T = {
687
775
val qe = withQueryExecutionsCaptured(spark) {
688
776
func
@@ -718,13 +806,18 @@ class DataSourceV2DataFrameSuite
718
806
}
719
807
}
720
808
721
- private def checkDefaultValue (
722
- column : UpdateColumnDefaultValue ,
723
- expectedDefault : DefaultValue ): Unit = {
724
- assert(
725
- column.newCurrentDefault() == expectedDefault,
726
- s " Default value mismatch for column ' ${column.toString}': " +
727
- s " expected $expectedDefault but found ${column.newCurrentDefault()}" )
809
+ private def checkDefaultValues (
810
+ columns : Array [UpdateColumnDefaultValue ],
811
+ expectedDefaultValues : Array [DefaultValue ]): Unit = {
812
+ assert(columns.length == expectedDefaultValues.length)
813
+
814
+ columns.zip(expectedDefaultValues).foreach {
815
+ case (column, expectedDefault) =>
816
+ assert(
817
+ column.newCurrentDefault() == expectedDefault,
818
+ s " Default value mismatch for column ' ${column.toString}': " +
819
+ s " expected $expectedDefault but found ${column.newCurrentDefault}" )
820
+ }
728
821
}
729
822
730
823
private def checkDropDefaultValue (
0 commit comments