Commit a70e326

[SPARK-52235][SQL] Add implicit cast to DefaultValue V2 Expressions passed to DSV2
1 parent 4b3a653
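
In short, a sketch grounded in the new tests below (`testcat` is the suite's in-memory catalog): when a column's DEFAULT expression has a different type than the column, the V2 expression handed to the connector now carries an explicit cast instead of a mistyped literal.

    import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, LiteralValue}
    import org.apache.spark.sql.types.{DoubleType, IntegerType}

    // For: CREATE TABLE testcat.ns1.ns2.tbl (col3 double DEFAULT 1)
    // the connector-facing default expression is now
    val expected = new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType)
    // rather than the bare LiteralValue(1, IntegerType) it could receive before this change.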

5 files changed (+240 −57)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new AnsiCombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ object TypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new CombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala

Lines changed: 66 additions & 0 deletions
@@ -32,16 +32,22 @@ import org.apache.spark.sql.catalyst.expressions.{
   WindowSpecDefinition
 }
 import org.apache.spark.sql.catalyst.plans.logical.{
+  AddColumns,
+  AlterColumns,
   Call,
+  CreateTable,
   Except,
   Intersect,
   LogicalPlan,
   Project,
+  ReplaceTable,
   Union,
   Unpivot
 }
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
 import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
 import org.apache.spark.sql.types.DataType

@@ -81,6 +87,66 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
     }
   }
 
+  object DefaultValueExpressionCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case createTable @ CreateTable(_, cols, _, _, _) if createTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "CREATE TABLE",
+              c.name,
+              d.originalSQL,
+              castWiderOnlyLiterals = false))))
+        }
+        createTable.copy(columns = newCols)
+
+      case replaceTable @ ReplaceTable(_, cols, _, _, _) if replaceTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "REPLACE TABLE",
+              c.name,
+              d.originalSQL,
+              castWiderOnlyLiterals = false))))
+        }
+        replaceTable.copy(columns = newCols)
+
+      case addColumns @ AddColumns(_, cols) if addColumns.resolved &&
+          cols.exists(_.default.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(default = c.default.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "ALTER TABLE ADD COLUMNS",
+              c.colName,
+              d.originalSQL,
+              castWiderOnlyLiterals = false))))
+        }
+        addColumns.copy(columnsToAdd = newCols)
+      case alterColumns @ AlterColumns(_, specs) if alterColumns.resolved &&
+          specs.exists(_.newDefaultExpression.isDefined) =>
+        val newSpecs = specs.map { c =>
+          val dataType = c.column.asInstanceOf[ResolvedFieldName].field.dataType
+          c.copy(newDefaultExpression = c.newDefaultExpression.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              dataType,
+              "ALTER TABLE ALTER COLUMN",
+              c.column.name.quoted,
+              d.originalSQL,
+              castWiderOnlyLiterals = false))))
+        }
+        alterColumns.copy(specs = newSpecs)
+    }
+  }
+
   /**
    * Widens the data types of the [[Unpivot]] values.
    */
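
A minimal sketch of the per-column rewrite this rule performs (literal value and column name are illustrative; assumes an active Spark session so that the session time zone is available):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
    import org.apache.spark.sql.types.DoubleType

    // An int DEFAULT for a double column takes the Cast.canUpCast branch and gains an
    // explicit Cast, which is later converted into a V2 Cast expression for the connector.
    val coerced = ResolveDefaultColumns.coerceDefaultValue(
      Literal(1),        // analyzed DEFAULT expression
      DoubleType,        // target column type
      "CREATE TABLE",    // statement type, used only in error messages
      "col3",            // column name
      "1",               // original DEFAULT SQL text
      castWiderOnlyLiterals = false)
    // coerced == Cast(Literal(1), DoubleType, Some(<session time zone>))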

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala

Lines changed: 28 additions & 6 deletions
@@ -432,18 +432,34 @@ object ResolveDefaultColumns extends QueryErrorsBase
       targetType: DataType,
       colName: String): Option[Expression] = {
     expr match {
-      case l: Literal if !Seq(targetType, l.dataType).exists(_ match {
+      case l: Literal => defaultValueFromWiderType(l, targetType, colName)
+      case _ => None
+    }
+  }
+
+  /**
+   * If the provided default value is a literal of a wider type than the target column,
+   * but the literal value fits within the narrower type, just coerce it for convenience.
+   * Exclude boolean/array/struct/map types from consideration for this type coercion to
+   * avoid surprising behavior like interpreting "false" as integer zero.
+   */
+  def defaultValueFromWiderType(
+      expr: Expression,
+      targetType: DataType,
+      colName: String): Option[Expression] = {
+    expr match {
+      case e if !Seq(targetType, e.dataType).exists(_ match {
         case _: BooleanType | _: ArrayType | _: StructType | _: MapType => true
         case _ => false
       }) =>
-        val casted = Cast(l, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
+        val casted = Cast(e, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
         try {
           Option(casted.eval(EmptyRow)).map(Literal(_, targetType))
         } catch {
           case e @ ( _: SparkThrowable | _: RuntimeException) =>
-            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, l)}' " +
+            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, e)}' " +
               log"for column ${MDC(COLUMN_NAME, colName)} " +
-              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, l.dataType)} " +
+              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, expr.dataType)} " +
               log"to ${MDC(COLUMN_DATA_TYPE_TARGET, targetType)} " +
               log"due to ${MDC(ERROR, e.getMessage)}", e)
             None
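
The TRY-mode cast is what makes this widening safe; a minimal standalone sketch of the behavior the helper relies on:

    import org.apache.spark.sql.catalyst.expressions.{Cast, EmptyRow, EvalMode, Literal}
    import org.apache.spark.sql.types.IntegerType

    // TRY mode yields null instead of throwing when the value does not fit the target
    // type, which defaultValueFromWiderType then maps to None via Option(...):
    val fits = Cast(Literal(7L), IntegerType, None, EvalMode.TRY).eval(EmptyRow)
    // fits == 7
    val overflow = Cast(Literal(Long.MaxValue), IntegerType, None, EvalMode.TRY).eval(EmptyRow)
    // overflow == null, so no narrowed default literal is produced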
@@ -461,7 +477,8 @@ object ResolveDefaultColumns extends QueryErrorsBase
       dataType: DataType,
       statementType: String,
       colName: String,
-      defaultSQL: String): Expression = {
+      defaultSQL: String,
+      castWiderOnlyLiterals: Boolean = true): Expression = {
     val supplanted = CharVarcharUtils.replaceCharVarcharWithString(dataType)
     // Perform implicit coercion from the provided expression type to the required column type.
     val ret = analyzed match {
@@ -470,7 +487,12 @@ object ResolveDefaultColumns extends QueryErrorsBase
       case canUpCast if Cast.canUpCast(canUpCast.dataType, supplanted) =>
         Cast(analyzed, supplanted, Some(conf.sessionLocalTimeZone))
       case other =>
-        defaultValueFromWiderTypeLiteral(other, supplanted, colName).getOrElse(
+        val casted = if (castWiderOnlyLiterals) {
+          defaultValueFromWiderTypeLiteral(other, supplanted, colName)
+        } else {
+          defaultValueFromWiderType(other, supplanted, colName)
+        }
+        casted.getOrElse(
           throw QueryCompilationErrors.defaultValuesDataTypeError(
             statementType, colName, defaultSQL, dataType, other.dataType))
     }
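
The new castWiderOnlyLiterals flag selects between the existing literal-only fallback and the broader one used by DefaultValueExpressionCoercion; a hedged sketch of the difference (the expression is illustrative):

    import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
    import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
    import org.apache.spark.sql.types.IntegerType

    // A non-literal bigint DEFAULT for an int column: with the default
    // castWiderOnlyLiterals = true this raises defaultValuesDataTypeError, because
    // defaultValueFromWiderTypeLiteral only accepts a Literal; with false, the new
    // defaultValueFromWiderType TRY-casts and constant-folds it instead:
    val coerced = ResolveDefaultColumns.coerceDefaultValue(
      Add(Literal(1L), Literal(1L)),  // analyzed DEFAULT (1 + 1), typed bigint
      IntegerType,
      "CREATE TABLE", "col1", "(1 + 1)",
      castWiderOnlyLiterals = false)
    // coerced == Literal(2, IntegerType)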

sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala

Lines changed: 144 additions & 51 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
 import org.apache.spark.sql.execution.datasources.v2.{AlterTableExec, CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, IntegerType, StringType}
+import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, DoubleType, IntegerType, StringType, TimestampType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.unsafe.types.UTF8String

@@ -498,43 +498,32 @@ class DataSourceV2DataFrameSuite
           |""".stripMargin)

      val alterExecCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
-       sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT (123 + 56)")
-     }
-     checkDefaultValue(
-       alterExecCol1.changes.collect {
-         case u: UpdateColumnDefaultValue => u
-       }.head,
-       new DefaultValue(
-         "(123 + 56)",
-         new GeneralScalarExpression(
-           "+",
-           Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))))
-
-     val alterExecCol2 = executeAndKeepPhysicalPlan[AlterTableExec] {
-       sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT ('r' || 'l')")
-     }
-     checkDefaultValue(
-       alterExecCol2.changes.collect {
-         case u: UpdateColumnDefaultValue => u
-       }.head,
-       new DefaultValue(
-         "('r' || 'l')",
-         new GeneralScalarExpression(
-           "CONCAT",
-           Array(
-             LiteralValue(UTF8String.fromString("r"), StringType),
-             LiteralValue(UTF8String.fromString("l"), StringType)))))
-
-     val alterExecCol3 = executeAndKeepPhysicalPlan[AlterTableExec] {
-       sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT CAST(0 AS BOOLEAN)")
+       sql(
+         s"""
+            |ALTER TABLE $tableName ALTER COLUMN
+            |  salary SET DEFAULT (123 + 56),
+            |  dep SET DEFAULT ('r' || 'l'),
+            |  active SET DEFAULT CAST(0 AS BOOLEAN)
+            |""".stripMargin)
      }
-     checkDefaultValue(
-       alterExecCol3.changes.collect {
-         case u: UpdateColumnDefaultValue => u
-       }.head,
-       new DefaultValue(
-         "CAST(0 AS BOOLEAN)",
-         new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType)))
+     checkDefaultValues(
+       alterExecCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+       Array(
+         new DefaultValue(
+           "(123 + 56)",
+           new GeneralScalarExpression(
+             "+",
+             Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))),
+         new DefaultValue(
+           "('r' || 'l')",
+           new GeneralScalarExpression(
+             "CONCAT",
+             Array(
+               LiteralValue(UTF8String.fromString("r"), StringType),
+               LiteralValue(UTF8String.fromString("l"), StringType)))),
+         new DefaultValue(
+           "CAST(0 AS BOOLEAN)",
+           new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType))))
     }
   }
 }
@@ -666,13 +655,9 @@ class DataSourceV2DataFrameSuite
       sql(s"ALTER TABLE $tableName ALTER COLUMN cat SET DEFAULT current_catalog()")
     }

-     checkDefaultValue(
-       alterExec.changes.collect {
-         case u: UpdateColumnDefaultValue => u
-       }.head,
-       new DefaultValue(
-         "current_catalog()",
-         null /* No V2 Expression */))
+     checkDefaultValues(
+       alterExec.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+       Array(new DefaultValue("current_catalog()", null /* No V2 Expression */)))

      val df1 = Seq(1).toDF("dummy")
      df1.writeTo(tableName).append()
@@ -683,6 +668,109 @@ class DataSourceV2DataFrameSuite
     }
   }

+  test("create/replace table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
+        sql(
+          s"""
+             |CREATE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        createExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            new LiteralValue(1542490413000000L, TimestampType),
+            new LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
+        sql(
+          s"""
+             |REPLACE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2022-02-23 05:55:55',
+             |  col3 double DEFAULT (1 + 1))
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        replaceExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType),
+            LiteralValue(1645624555000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType),
+            LiteralValue(2.0, DoubleType))))
+    }
+  }
+
+  test("alter table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      sql(s"CREATE TABLE $tableName (col1 int) using foo")
+      val alterExec = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ADD COLUMNS (
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+
+      checkDefaultValues(
+        alterExec.changes.map(_.asInstanceOf[AddColumn]).toArray,
+        Array(
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            LiteralValue(1542490413000000L, TimestampType),
+            LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ALTER COLUMN
+             |  col2 SET DEFAULT '2022-02-23 05:55:55',
+             |  col3 SET DEFAULT (1 + 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        alterCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(
+          new DefaultValue("'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType)),
+          new DefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType))))
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func
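
The expected timestamp defaults in the tests above are epoch microseconds; a quick sketch of where 1542490413000000L comes from, assuming the America/Los_Angeles session time zone that Spark test JVMs conventionally run under:

    import java.time.{LocalDateTime, ZoneId}

    // '2018-11-17 13:33:33' interpreted in the session time zone, in microseconds:
    val micros = LocalDateTime.parse("2018-11-17T13:33:33")
      .atZone(ZoneId.of("America/Los_Angeles"))
      .toInstant.getEpochSecond * 1000000L
    assert(micros == 1542490413000000L)  // matches LiteralValue(..., TimestampType)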
@@ -718,13 +806,18 @@ class DataSourceV2DataFrameSuite
     }
   }

-  private def checkDefaultValue(
-      column: UpdateColumnDefaultValue,
-      expectedDefault: DefaultValue): Unit = {
-    assert(
-      column.newCurrentDefault() == expectedDefault,
-      s"Default value mismatch for column '${column.toString}': " +
-        s"expected $expectedDefault but found ${column.newCurrentDefault()}")
+  private def checkDefaultValues(
+      columns: Array[UpdateColumnDefaultValue],
+      expectedDefaultValues: Array[DefaultValue]): Unit = {
+    assert(columns.length == expectedDefaultValues.length)
+
+    columns.zip(expectedDefaultValues).foreach {
+      case (column, expectedDefault) =>
+        assert(
+          column.newCurrentDefault() == expectedDefault,
+          s"Default value mismatch for column '${column.toString}': " +
+            s"expected $expectedDefault but found ${column.newCurrentDefault}")
+    }
   }

   private def checkDropDefaultValue(
