diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index a98aaf702acef..13b554eb53d4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -78,6 +78,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new AnsiCombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index f68084803fe75..3e5f14810935b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -50,6 +50,7 @@ object TypeCoercion extends TypeCoercionBase {
     UnpivotCoercion ::
     WidenSetOperationTypes ::
     ProcedureArgumentCoercion ::
+    DefaultValueExpressionCoercion ::
     new CombinedTypeCoercionRule(
       CollationTypeCasts ::
       InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
index eae7d5a74dbc2..a8832aada0839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionBase.scala
@@ -32,16 +32,22 @@ import org.apache.spark.sql.catalyst.expressions.{
   WindowSpecDefinition
 }
 import org.apache.spark.sql.catalyst.plans.logical.{
+  AddColumns,
+  AlterColumns,
   Call,
+  CreateTable,
   Except,
   Intersect,
   LogicalPlan,
   Project,
+  ReplaceTable,
   Union,
   Unpivot
 }
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
 import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
 import org.apache.spark.sql.types.DataType
@@ -81,6 +87,67 @@ abstract class TypeCoercionBase extends TypeCoercionHelper {
     }
   }
 
+  /**
+   * A type coercion rule that implicitly casts default value expressions in DDL statements
+   * to the expected types.
+   */
+  object DefaultValueExpressionCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case createTable @ CreateTable(_, cols, _, _, _) if createTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "CREATE TABLE",
+              c.name,
+              d.originalSQL))))
+        }
+        createTable.copy(columns = newCols)
+
+      case replaceTable @ ReplaceTable(_, cols, _, _, _) if replaceTable.resolved &&
+          cols.exists(_.defaultValue.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(defaultValue = c.defaultValue.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "REPLACE TABLE",
+              c.name,
+              d.originalSQL))))
+        }
+        replaceTable.copy(columns = newCols)
+
+      case addColumns @ AddColumns(_, cols) if addColumns.resolved &&
+          cols.exists(_.default.isDefined) =>
+        val newCols = cols.map { c =>
+          c.copy(default = c.default.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              c.dataType,
+              "ALTER TABLE ADD COLUMNS",
+              c.colName,
+              d.originalSQL))))
+        }
+        addColumns.copy(columnsToAdd = newCols)
+
+      case alterColumns @ AlterColumns(_, specs) if alterColumns.resolved &&
+          specs.exists(_.newDefaultExpression.isDefined) =>
+        val newSpecs = specs.map { c =>
+          val dataType = c.column.asInstanceOf[ResolvedFieldName].field.dataType
+          c.copy(newDefaultExpression = c.newDefaultExpression.map(d =>
+            d.copy(child = ResolveDefaultColumns.coerceDefaultValue(
+              d.child,
+              dataType,
+              "ALTER TABLE ALTER COLUMN",
+              c.column.name.quoted,
+              d.originalSQL))))
+        }
+        alterColumns.copy(specs = newSpecs)
+    }
+  }
+
   /**
    * Widens the data types of the [[Unpivot]] values.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index cccf77d02b424..3cfd0676039d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.AnalysisException
@@ -432,20 +433,20 @@ object ResolveDefaultColumns extends QueryErrorsBase
       targetType: DataType,
       colName: String): Option[Expression] = {
     expr match {
-      case l: Literal if !Seq(targetType, l.dataType).exists(_ match {
+      case e if e.foldable && !Seq(targetType, e.dataType).exists(_ match {
         case _: BooleanType | _: ArrayType | _: StructType | _: MapType => true
         case _ => false
       }) =>
-        val casted = Cast(l, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
+        val casted = Cast(e, targetType, Some(conf.sessionLocalTimeZone), evalMode = EvalMode.TRY)
         try {
           Option(casted.eval(EmptyRow)).map(Literal(_, targetType))
         } catch {
-          case e @ ( _: SparkThrowable | _: RuntimeException) =>
-            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, l)}' " +
+          case NonFatal(ex) =>
+            logWarning(log"Failed to cast default value '${MDC(COLUMN_DEFAULT_VALUE, e)}' " +
               log"for column ${MDC(COLUMN_NAME, colName)} " +
-              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, l.dataType)} " +
+              log"from ${MDC(COLUMN_DATA_TYPE_SOURCE, e.dataType)} " +
               log"to ${MDC(COLUMN_DATA_TYPE_TARGET, targetType)} " +
-              log"due to ${MDC(ERROR, e.getMessage)}", e)
+              log"due to ${MDC(ERROR, ex.getMessage)}", ex)
             None
         }
       case _ => None
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 2316616729dcd..a262a5e79a1a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.ExplainUtils.stripAQEPlan
 import org.apache.spark.sql.execution.datasources.v2.{AlterTableExec, CreateTableExec, DataSourceV2Relation, ReplaceTableExec}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, IntegerType, StringType}
+import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, DoubleType, IntegerType, StringType, TimestampType}
 import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.unsafe.types.UTF8String
@@ -498,43 +498,32 @@ class DataSourceV2DataFrameSuite
            |""".stripMargin)
 
       val alterExecCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN salary SET DEFAULT (123 + 56)")
-      }
-      checkDefaultValue(
-        alterExecCol1.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "(123 + 56)",
-          new GeneralScalarExpression(
-            "+",
-            Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))))
-
-      val alterExecCol2 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT ('r' || 'l')")
-      }
-      checkDefaultValue(
-        alterExecCol2.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "('r' || 'l')",
-          new GeneralScalarExpression(
-            "CONCAT",
-            Array(
-              LiteralValue(UTF8String.fromString("r"), StringType),
-              LiteralValue(UTF8String.fromString("l"), StringType)))))
-
-      val alterExecCol3 = executeAndKeepPhysicalPlan[AlterTableExec] {
-        sql(s"ALTER TABLE $tableName ALTER COLUMN active SET DEFAULT CAST(0 AS BOOLEAN)")
+        sql(
+          s"""
+             |ALTER TABLE $tableName ALTER COLUMN
+             |  salary SET DEFAULT (123 + 56),
+             |  dep SET DEFAULT ('r' || 'l'),
+             |  active SET DEFAULT CAST(0 AS BOOLEAN)
+             |""".stripMargin)
       }
-      checkDefaultValue(
-        alterExecCol3.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "CAST(0 AS BOOLEAN)",
-          new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType)))
+      checkDefaultValues(
+        alterExecCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(
+          new DefaultValue(
+            "(123 + 56)",
+            new GeneralScalarExpression(
+              "+",
+              Array(LiteralValue(123, IntegerType), LiteralValue(56, IntegerType)))),
+          new DefaultValue(
+            "('r' || 'l')",
+            new GeneralScalarExpression(
+              "CONCAT",
+              Array(
+                LiteralValue(UTF8String.fromString("r"), StringType),
+                LiteralValue(UTF8String.fromString("l"), StringType)))),
+          new DefaultValue(
+            "CAST(0 AS BOOLEAN)",
+            new V2Cast(LiteralValue(0, IntegerType), IntegerType, BooleanType))))
     }
   }
 
@@ -666,13 +655,9 @@ class DataSourceV2DataFrameSuite
         sql(s"ALTER TABLE $tableName ALTER COLUMN cat SET DEFAULT current_catalog()")
       }
 
-      checkDefaultValue(
-        alterExec.changes.collect {
-          case u: UpdateColumnDefaultValue => u
-        }.head,
-        new DefaultValue(
-          "current_catalog()",
-          null /* No V2 Expression */))
+      checkDefaultValues(
+        alterExec.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(new DefaultValue("current_catalog()", null /* No V2 Expression */)))
 
       val df1 = Seq(1).toDF("dummy")
       df1.writeTo(tableName).append()
@@ -683,6 +668,109 @@
     }
   }
 
+  test("create/replace table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      val createExec = executeAndKeepPhysicalPlan[CreateTableExec] {
+        sql(
+          s"""
+             |CREATE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        createExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            new LiteralValue(1542490413000000L, TimestampType),
+            new LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val replaceExec = executeAndKeepPhysicalPlan[ReplaceTableExec] {
+        sql(
+          s"""
+             |REPLACE TABLE $tableName (
+             |  col1 int,
+             |  col2 timestamp DEFAULT '2022-02-23 05:55:55',
+             |  col3 double DEFAULT (1 + 1))
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        replaceExec.columns,
+        Array(
+          null,
+          new ColumnDefaultValue(
+            "'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType),
+            LiteralValue(1645624555000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType),
+            LiteralValue(2.0, DoubleType))))
+    }
+  }
+
+  test("alter table default value expression should have a cast") {
+    val tableName = "testcat.ns1.ns2.tbl"
+    withTable(tableName) {
+
+      sql(s"CREATE TABLE $tableName (col1 int) using foo")
+      val alterExec = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ADD COLUMNS (
+             |  col2 timestamp DEFAULT '2018-11-17 13:33:33',
+             |  col3 double DEFAULT 1)
+             |""".stripMargin)
+      }
+
+      checkDefaultValues(
+        alterExec.changes.map(_.asInstanceOf[AddColumn]).toArray,
+        Array(
+          new ColumnDefaultValue(
+            "'2018-11-17 13:33:33'",
+            LiteralValue(1542490413000000L, TimestampType),
+            LiteralValue(1542490413000000L, TimestampType)),
+          new ColumnDefaultValue(
+            "1",
+            new V2Cast(LiteralValue(1, IntegerType), IntegerType, DoubleType),
+            LiteralValue(1.0, DoubleType))))
+
+      val alterCol1 = executeAndKeepPhysicalPlan[AlterTableExec] {
+        sql(
+          s"""
+             |ALTER TABLE $tableName ALTER COLUMN
+             |  col2 SET DEFAULT '2022-02-23 05:55:55',
+             |  col3 SET DEFAULT (1 + 1)
+             |""".stripMargin)
+      }
+      checkDefaultValues(
+        alterCol1.changes.map(_.asInstanceOf[UpdateColumnDefaultValue]).toArray,
+        Array(
+          new DefaultValue("'2022-02-23 05:55:55'",
+            LiteralValue(1645624555000000L, TimestampType)),
+          new DefaultValue(
+            "(1 + 1)",
+            new V2Cast(
+              new GeneralScalarExpression("+", Array(LiteralValue(1, IntegerType),
+                LiteralValue(1, IntegerType))),
+              IntegerType,
+              DoubleType))))
+    }
+  }
+
   private def executeAndKeepPhysicalPlan[T <: SparkPlan](func: => Unit): T = {
     val qe = withQueryExecutionsCaptured(spark) {
       func
     }
@@ -718,13 +806,18 @@
     }
   }
 
-  private def checkDefaultValue(
-      column: UpdateColumnDefaultValue,
-      expectedDefault: DefaultValue): Unit = {
-    assert(
-      column.newCurrentDefault() == expectedDefault,
-      s"Default value mismatch for column '${column.toString}': " +
-        s"expected $expectedDefault but found ${column.newCurrentDefault()}")
+  private def checkDefaultValues(
+      columns: Array[UpdateColumnDefaultValue],
+      expectedDefaultValues: Array[DefaultValue]): Unit = {
+    assert(columns.length == expectedDefaultValues.length)
+
+    columns.zip(expectedDefaultValues).foreach {
+      case (column, expectedDefault) =>
+        assert(
+          column.newCurrentDefault() == expectedDefault,
+          s"Default value mismatch for column '${column.toString}': " +
+            s"expected $expectedDefault but found ${column.newCurrentDefault}")
+    }
   }
 
   private def checkDropDefaultValue(
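
Reviewer note (not part of the patch): the analyzer-level effect of DefaultValueExpressionCoercion can be reproduced end to end from spark-shell. A minimal sketch, assuming a session with a V2 catalog registered as `testcat` that supports column defaults (mirroring the suite's in-memory catalog; the table and namespace names are illustrative only):

  // Before this change the connector received the raw INT literal 1 for col3;
  // with the rule in place it receives CAST(1 AS DOUBLE) (a V2 Cast over the
  // INT literal), and the string default for col2 arrives as a TIMESTAMP literal.
  spark.sql(
    """CREATE TABLE testcat.ns1.ns2.tbl (
      |  col1 INT,
      |  col2 TIMESTAMP DEFAULT '2018-11-17 13:33:33',
      |  col3 DOUBLE DEFAULT 1)
      |""".stripMargin)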
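The ResolveDefaultColumnsUtil change above widens the old literal-only fallback (`case l: Literal`) to any foldable expression: the candidate default is TRY-cast to the target type, constant-folded back into a literal, and any non-fatal failure is logged and discarded. A rough standalone illustration of that code path (a sketch against catalyst internals, not the patched method itself; "UTC" stands in for conf.sessionLocalTimeZone):

  import org.apache.spark.sql.catalyst.expressions.{Add, Cast, EmptyRow, EvalMode, Literal}
  import org.apache.spark.sql.types.DoubleType

  // (1 + 1) is foldable but is not a Literal, so the old guard skipped it,
  // while the new `case e if e.foldable` guard accepts it.
  val expr = Add(Literal(1), Literal(1))
  val casted = Cast(expr, DoubleType, Some("UTC"), evalMode = EvalMode.TRY)
  // The TRY cast folds to 2.0, which is re-wrapped as a DOUBLE literal,
  // matching the LiteralValue(2.0, DoubleType) the new tests expect.
  assert(Option(casted.eval(EmptyRow)).map(Literal(_, DoubleType))
    .contains(Literal(2.0, DoubleType)))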