
Commit 8d73424

szehon-ho authored and yhuang-db committed
[SPARK-52235][SQL] Add implicit cast to DefaultValue V2 Expressions passed to DSV2
### What changes were proposed in this pull request?

Add an implicit cast to DefaultValue V2 expressions passed to DSV2, by adding a rule to TypeCoercion.

### Why are the changes needed?

Default values are now passed as V2 expressions to DSV2 in CREATE/REPLACE/ALTER TABLE DDL statements. We should match what the normal default-value analysis path does for SQL strings (i.e., implicit cast).

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added a unit test to DataSourceV2DataFrameSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#50959 from szehon-ho/cast_default_value.

Authored-by: Szehon Ho <szehon.apache@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent a9e6831 commit 8d73424
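To illustrate the behavior this commit changes, here is a minimal sketch (assuming the suite's in-memory catalog registered as `testcat`, mirroring the new tests below). Before this change the connector saw the raw INT literal for `col3`; now it receives a well-typed V2 expression:

```scala
// Hypothetical spark-shell snippet against a DSV2 catalog registered as `testcat`.
// The INT default `1` on the DOUBLE column is now implicitly cast, so the
// connector receives CAST(1 AS DOUBLE), folded to LiteralValue(1.0, DoubleType).
spark.sql(
  """CREATE TABLE testcat.ns1.ns2.tbl (
    |  col1 INT,
    |  col2 TIMESTAMP DEFAULT '2018-11-17 13:33:33',
    |  col3 DOUBLE DEFAULT 1)
    |""".stripMargin)
```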

5 files changed: +221, −58 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -78,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new AnsiCombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -50,6 +50,7 @@ object TypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new CombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala

Lines changed: 67 additions & 0 deletions
```diff
@@ -32,16 +32,22 @@ import org.apache.spark.sql.catalyst.expressions.{
   WindowSpecDefinition
 }
 import org.apache.spark.sql.catalyst.plans.logical.{
+  AddColumns,
+  AlterColumns,
   Call,
+  CreateTable,
   Except,
   Intersect,
   LogicalPlan,
   Project,
+  ReplaceTable,
   Union,
   Unpivot
 }
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
 import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
 import org.apache.spark.sql.types.DataType
 
@@ -81,6 +87,67 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
     }
   }
 
+  /**
+   * A type coercion rule that implicitly casts default value expression in DDL statements
+   * to expected types.
+   */
+  object DefaultValueExpressionCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case createTable @ CreateTable(_, cols, _, _, _) if createTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "CREATE TABLE",
+              c.name,
+              d.originalSQL))))
+        }
+        createTable.copy(columns = newCols)
+
+      case replaceTable @ ReplaceTable(_, cols, _, _, _) if replaceTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "REPLACE TABLE",
+              c.name,
+              d.originalSQL))))
+        }
+        replaceTable.copy(columns = newCols)
+
+      case addColumns @ AddColumns(_, cols) if addColumns.resolved &&
+          cols.exists(_.default.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(default = c.default.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "ALTER TABLE ADD COLUMNS",
+              c.colName,
+              d.originalSQL))))
+        }
+        addColumns.copy(columnsToAdd = newCols)
+
+      case alterColumns @ AlterColumns(_, specs) if alterColumns.resolved &&
+          specs.exists(_.newDefaultExpression.isDefined) =>
+        val newSpecs = specs.map { c =>
+          val dataType = c.column.asInstanceOf[ResolvedFieldName].field.dataType
+          c.copy(newDefaultExpression = c.newDefaultExpression.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              dataType,
+              "ALTER TABLE ALTER COLUMN",
+              c.column.name.quoted,
+              d.originalSQL))))
+        }
+        alterColumns.copy(specs = newSpecs)
+    }
+  }
+
   /**
    * Widens the data types of the [[Unpivot]] values.
    */
```
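For connectors, the practical effect of this rule is that a default whose type does not match the column arrives wrapped in a V2 cast. A hedged sketch of the connector-facing expression, built from the same classes the updated test suite below uses; it shows the expected shape, not the rule's internals:

```scala
import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.{DoubleType, IntegerType}

// For `col3 double DEFAULT (1 + 1)`, the coerced default handed to the
// connector is CAST((1 + 1) AS DOUBLE) rather than the bare INT expression.
val coerced = new V2Cast(
  new GeneralScalarExpression(
    "+", Array(LiteralValue(1, IntegerType), LiteralValue(1, IntegerType))),
  IntegerType, // source type of the child expression
  DoubleType)  // target type: the column's declared type
```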

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala

Lines changed: 8 additions & 7 deletions
```diff
@@ -18,8 +18,9 @@
 package org.apache.spark.sql.catalyst.util
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.AnalysisException
@@ -432,20 +433,20 @@ object ResolveDefaultColumns extends QueryErrorsBase
       targetType: DataType,
       colName: String): Option[Expression] = {
     expr match {
-      case l: Literal if !Seq(targetType, l.dataType).exists(_ match {
+      case e if e.foldable && !Seq(targetType, e.dataType).exists(_ match {
        case _: BooleanType | _: ArrayType | _: StructType | _: MapType => true
        case _ => false
      }) =>
-        val casted = Cast(l, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
+        val casted = Cast(e, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
        try {
          Option(casted.eval(EmptyRow)).map(Literal(_, targetType))
        } catch {
-          case e @ ( _: SparkThrowable | _: RuntimeException) =>
-            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, l)}' " +
+          case NonFatal(ex) =>
+            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, e)}' " +
              log"for column ${MDC(COLUMN_NAME, colName)} " +
-              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, l.dataType)} " +
+              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, e.dataType)} " +
              log"to ${MDC(COLUMN_DATA_TYPE_TARGET, targetType)} " +
-              log"due to ${MDC(ERROR, e.getMessage)}", e)
+              log"due to ${MDC(ERROR, ex.getMessage)}", ex)
            None
        }
      case _ => None
```
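The key behavioral change above is that the fold now accepts any foldable expression, not just a bare `Literal`. A minimal standalone sketch of what that enables, assuming Catalyst test internals on the classpath and `"UTC"` standing in for the session time zone:

```scala
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.expressions.{Add, Cast, EmptyRow, EvalMode, Literal}
import org.apache.spark.sql.types.DoubleType

// `(1 + 1)` is foldable but not a bare Literal: the old match
// (`case l: Literal`) skipped it; the new one (`case e if e.foldable`) folds it.
val expr = Add(Literal(1), Literal(1))
assert(expr.foldable)

// TRY-mode cast evaluated against the empty row, as in coerceDefaultValue.
val casted = Cast(expr, DoubleType, Some("UTC"), evalMode = EvalMode.TRY)
val folded =
  try Option(casted.eval(EmptyRow)).map(Literal(_, DoubleType))
  catch { case NonFatal(_) => None }
// folded == Some(Literal(2.0, DoubleType))
```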

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 144 additions & 51 deletions
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
 import org.apache.spark.sql.execution.datasources.v2.{AlterTableExec, CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, IntegerType, StringType}
+import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, DoubleType, IntegerType, StringType, TimestampType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -498,43 +498,32 @@ class DataSourceV2DataFrameSuite
           |""".stripMargin)
 
       val alterExecCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT (123 + 56)")
-      }
-      checkDefaultValue(
-        alterExecCol1.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "(123 + 56)",
-          new GeneralScalarExpression(
-            "+",
-            Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))))
-
-      val alterExecCol2 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT ('r' || 'l')")
-      }
-      checkDefaultValue(
-        alterExecCol2.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "('r' || 'l')",
-          new GeneralScalarExpression(
-            "CONCAT",
-            Array(
-              LiteralValue(UTF8String.fromString("r"), StringType),
-              LiteralValue(UTF8String.fromString("l"), StringType)))))
-
-      val alterExecCol3 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN active SET DEFAULT CAST(0 AS BOOLEAN)")
+        sql(
+          s"""
+             |ALTER TABLE $tableName ALTER COLUMN
+             | salary SET DEFAULT (123 + 56),
+             | dep SET DEFAULT ('r' || 'l'),
+             | active SET DEFAULT CAST(0 AS BOOLEAN)
+             |""".stripMargin)
       }
-      checkDefaultValue(
-        alterExecCol3.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "CAST(0 AS BOOLEAN)",
-          new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType)))
+      checkDefaultValues(
+        alterExecCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(
+          new DefaultValue(
+            "(123 + 56)",
+            new GeneralScalarExpression(
+              "+",
+              Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))),
+          new DefaultValue(
+            "('r' || 'l')",
+            new GeneralScalarExpression(
+              "CONCAT",
+              Array(
+                LiteralValue(UTF8String.fromString("r"), StringType),
+                LiteralValue(UTF8String.fromString("l"), StringType)))),
+          new DefaultValue(
+            "CAST(0 AS BOOLEAN)",
+            new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType))))
     }
   }
 
@@ -666,13 +655,9 @@ class DataSourceV2DataFrameSuite
        sql(s"ALTER TABLE $tableName ALTER COLUMN cat SET DEFAULT current_catalog()")
      }
 
-      checkDefaultValue(
-        alterExec.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "current_catalog()",
-          null /* No V2 Expression */))
+      checkDefaultValues(
+        alterExec.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(new DefaultValue("current_catalog()", null /* No V2 Expression */)))
 
      val df1 = Seq(1).toDF("dummy")
      df1.writeTo(tableName).append()
@@ -683,6 +668,109 @@ class DataSourceV2DataFrameSuite
     }
   }
 
+  test("create/replace table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
+        sql(
+          s"""
+             |CREATE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        createExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            new LiteralValue(1542490413000000L, TimestampType),
+            new LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
+        sql(
+          s"""
+             |REPLACE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2022-02-23 05:55:55',
+             |  col3 double DEFAULT (1 + 1))
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        replaceExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType),
+            LiteralValue(1645624555000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType),
+            LiteralValue(2.0, DoubleType))))
+    }
+  }
+
+  test("alter table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      sql(s"CREATE TABLE $tableName (col1 int) using foo")
+      val alterExec = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ADD COLUMNS (
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+
+      checkDefaultValues(
+        alterExec.changes.map(_.asInstanceOf[AddColumn]).toArray,
+        Array(
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            LiteralValue(1542490413000000L, TimestampType),
+            LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ALTER COLUMN
+             | col2 SET DEFAULT '2022-02-23 05:55:55',
+             | col3 SET DEFAULT (1 + 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        alterCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(
+          new DefaultValue("'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType)),
+          new DefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType))))
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func
@@ -718,13 +806,18 @@ class DataSourceV2DataFrameSuite
     }
   }
 
-  private def checkDefaultValue(
-      column: UpdateColumnDefaultValue,
-      expectedDefault: DefaultValue): Unit = {
-    assert(
-      column.newCurrentDefault() == expectedDefault,
-      s"Default value mismatch for column '${column.toString}': " +
-        s"expected $expectedDefault but found ${column.newCurrentDefault()}")
+  private def checkDefaultValues(
+      columns: Array[UpdateColumnDefaultValue],
+      expectedDefaultValues: Array[DefaultValue]): Unit = {
+    assert(columns.length == expectedDefaultValues.length)
+
+    columns.zip(expectedDefaultValues).foreach {
+      case (column, expectedDefault) =>
+        assert(
+          column.newCurrentDefault() == expectedDefault,
+          s"Default value mismatch for column '${column.toString}': " +
+            s"expected $expectedDefault but found ${column.newCurrentDefault}")
+    }
   }
 
   private def checkDropDefaultValue(
```
