
Commit 144ec30

aokolnychyi authored and yhuang-db committed
[SPARK-51987][SQL] DSv2 expressions in column defaults on write
### What changes were proposed in this pull request?
This PR allows connectors to expose expression-based defaults on write.

### Why are the changes needed?
These changes are needed so that connectors are not required to produce Spark SQL text for their default values.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This PR comes with tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#51002 from aokolnychyi/spark-51987.

Authored-by: Anton Okolnychyi <aokolnychyi@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent: a113403 · commit: 144ec30

File tree

12 files changed: +260 −27 lines
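Before the per-file changes, here is a minimal sketch of the connector-facing side this commit targets: a column default carried as a DSv2 expression rather than (or in addition to) Spark SQL text. The three-argument ColumnDefaultValue constructor used below is an assumption for illustration; the commit itself only relies on getSql, getExpression, and getValue.

import org.apache.spark.sql.connector.catalog.ColumnDefaultValue
import org.apache.spark.sql.connector.expressions.{Expression, GeneralScalarExpression, LiteralValue}
import org.apache.spark.sql.types.IntegerType

// DSv2 form of "1 + 41": the connector no longer has to render Spark SQL text itself.
val v2Expr = new GeneralScalarExpression(
  "+", Array[Expression](LiteralValue(1, IntegerType), LiteralValue(41, IntegerType)))

// Assumed (sql, expression, existsValue) constructor; the SQL text may be absent when
// only the expression form is supplied, which is the case this commit handles.
val defaultValue = new ColumnDefaultValue("1 + 41", v2Expr, LiteralValue(42, IntegerType))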

sql/api/src/main/scala/org/apache/spark/sql/types/Metadata.scala

Lines changed: 23 additions & 4 deletions
@@ -40,11 +40,13 @@ import org.apache.spark.util.ArrayImplicits._
  * @since 1.3.0
  */
 @Stable
-sealed class Metadata private[types] (private[types] val map: Map[String, Any])
+sealed class Metadata private[types] (
+    private[types] val map: Map[String, Any],
+    @transient private[types] val runtimeMap: Map[String, Any])
   extends Serializable {
 
   /** No-arg constructor for kryo. */
-  protected def this() = this(null)
+  protected def this() = this(null, null)
 
   /** Tests whether this Metadata contains a binding for a key. */
   def contains(key: String): Boolean = map.contains(key)
@@ -120,6 +122,12 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
     map(key).asInstanceOf[T]
   }
 
+  private[sql] def getExpression[E](key: String): (String, Option[E]) = {
+    val sql = getString(key)
+    val expr = Option(runtimeMap).flatMap(_.get(key).map(_.asInstanceOf[E]))
+    sql -> expr
+  }
+
   private[sql] def jsonValue: JValue = Metadata.toJsonValue(this)
 }
 
@@ -129,7 +137,7 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any])
 @Stable
 object Metadata {
 
-  private[this] val _empty = new Metadata(Map.empty)
+  private[this] val _empty = new Metadata(Map.empty, Map.empty)
 
   /** Returns an empty Metadata. */
   def empty: Metadata = _empty
@@ -248,13 +256,17 @@
 class MetadataBuilder {
 
   private val map: mutable.Map[String, Any] = mutable.Map.empty
+  private val runtimeMap: mutable.Map[String, Any] = mutable.Map.empty
 
   /** Returns the immutable version of this map. Used for java interop. */
   protected def getMap = map.toMap
 
   /** Include the content of an existing [[Metadata]] instance. */
   def withMetadata(metadata: Metadata): this.type = {
     map ++= metadata.map
+    if (metadata.runtimeMap != null) {
+      runtimeMap ++= metadata.runtimeMap
+    }
     this
   }
 
@@ -293,16 +305,23 @@
 
   /** Builds the [[Metadata]] instance. */
   def build(): Metadata = {
-    new Metadata(map.toMap)
+    new Metadata(map.toMap, runtimeMap.toMap)
   }
 
   private def put(key: String, value: Any): this.type = {
     map.put(key, value)
     this
   }
 
+  private[sql] def putExpression[E](key: String, sql: String, expr: Option[E]): this.type = {
+    map.put(key, sql)
+    expr.foreach(runtimeMap.put(key, _))
+    this
+  }
+
   def remove(key: String): this.type = {
     map.remove(key)
+    runtimeMap.remove(key)
     this
   }
 }

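The new runtimeMap keeps the Catalyst form of an expression alongside its SQL text: the SQL string lives in the regular map (and therefore in JSON, equals, and hashCode), while the expression is @transient and disappears on serialization. A minimal sketch of the round trip, assuming it runs in code that can see the private[sql] builder methods:

import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
import org.apache.spark.sql.types.MetadataBuilder

// The SQL text goes into the JSON-visible map, the Catalyst expression into the
// transient runtimeMap.
val meta = new MetadataBuilder()
  .putExpression("default", "1 + 2", Some(Add(Literal(1), Literal(2))))
  .build()

val (sql, expr) = meta.getExpression[Expression]("default")
// sql  == "1 + 2"
// expr == Some(Add(Literal(1), Literal(2))) here, but None after the Metadata has been
// serialized and deserialized, because runtimeMap is @transient.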
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala

Lines changed: 6 additions & 1 deletion
@@ -263,7 +263,12 @@ object ResolveDefaultColumns extends QueryErrorsBase
       field: StructField,
       statementType: String,
       metadataKey: String = CURRENT_DEFAULT_COLUMN_METADATA_KEY): Expression = {
-    analyze(field.name, field.dataType, field.metadata.getString(metadataKey), statementType)
+    field.metadata.getExpression[Expression](metadataKey) match {
+      case (sql, Some(expr)) =>
+        analyze(field.name, field.dataType, expr, sql, statementType)
+      case (sql, _) =>
+        analyze(field.name, field.dataType, sql, statementType)
+    }
   }
 
   /**

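Put together, default resolution now prefers the already-converted Catalyst expression and only falls back to parsing the stored SQL string. A rough caller-side sketch, again assuming visibility of the private[sql] API; the "CREATE TABLE" statement type is just an illustrative value:

import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StructField}

// A field whose current default carries both the SQL text and the converted expression.
val metadata = new MetadataBuilder()
  .putExpression(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "1 + 2", Some(Add(Literal(1), Literal(2))))
  .build()
val field = StructField("c", IntegerType, nullable = true, metadata)

// With the expression present, analyze() uses it directly; if only the SQL string were
// stored, the string would be parsed as before.
val resolved: Expression = ResolveDefaultColumns.analyze(field, "CREATE TABLE")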
sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala

Lines changed: 24 additions & 4 deletions
@@ -22,12 +22,12 @@ import java.util.{Collections, Locale}
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkException, SparkIllegalArgumentException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.CurrentUserContext
 import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, TimeTravelSpec}
 import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, V2ExpressionUtils}
 import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
 import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
@@ -597,11 +597,31 @@ private[sql] object CatalogV2Util {
         // Note: the back-fill here is a logical concept. The data source can keep the existing
         // data unchanged and let the data reader to return "exist default" for missing
         // columns.
-        val existingDefault = Literal(default.getValue.value(), default.getValue.dataType()).sql
-        f.withExistenceDefaultValue(existingDefault).withCurrentDefaultValue(default.getSql)
+        val existsDefault = extractExistsDefault(default)
+        val (sql, expr) = extractCurrentDefault(default)
+        val newMetadata = new MetadataBuilder()
+          .withMetadata(f.metadata)
+          .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, existsDefault)
+          .putExpression(CURRENT_DEFAULT_COLUMN_METADATA_KEY, sql, expr)
+          .build()
+        f.copy(metadata = newMetadata)
       }.getOrElse(f)
     }
 
+  private def extractExistsDefault(default: ColumnDefaultValue): String = {
+    Literal(default.getValue.value, default.getValue.dataType).sql
+  }
+
+  private def extractCurrentDefault(default: ColumnDefaultValue): (String, Option[Expression]) = {
+    val expr = Option(default.getExpression).flatMap(V2ExpressionUtils.toCatalyst)
+    val sql = Option(default.getSql).orElse(expr.map(_.sql)).getOrElse {
+      throw SparkException.internalError(
+        s"Can't generate SQL for $default. The connector expression couldn't be " +
+          "converted to Catalyst and there is no provided SQL representation.")
+    }
+    (sql, expr)
+  }
+
   /**
    * Converts a StructType to DS v2 columns, which decodes the StructField metadata to v2 column
    * comment and default value or generation expression. This is mainly used to generate DS v2

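extractCurrentDefault converts the connector expression to Catalyst via V2ExpressionUtils.toCatalyst when possible, uses the connector-provided SQL text when present, generates SQL from the converted expression otherwise, and fails only when neither form is usable (SQL-only defaults still work: they yield (sql, None) and are parsed later). The fallback chain for the SQL text, shown standalone with hypothetical stand-ins (connectorSql, catalystExpr) for default.getSql and the conversion result:

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.Expression

// Hypothetical standalone version of the SQL-text fallback used above.
def currentDefaultSql(connectorSql: Option[String], catalystExpr: Option[Expression]): String =
  connectorSql.orElse(catalystExpr.map(_.sql)).getOrElse {
    throw SparkException.internalError(
      "Neither SQL text nor a convertible DSv2 expression was provided for the default.")
  }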
sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala

Lines changed: 3 additions & 1 deletion
@@ -50,7 +50,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 abstract class InMemoryBaseTable(
     val name: String,
-    val schema: StructType,
+    override val columns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     override val constraints: Array[Constraint] = Array.empty,
@@ -114,6 +114,8 @@ abstract class InMemoryBaseTable(
     }
   }
 
+  override val schema: StructType = CatalogV2Util.v2ColumnsToStructType(columns)
+
   // purposely exposes a metadata column that conflicts with a data column in some tests
   override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
   private val metadataColumnNames = metadataColumns.map(_.name).toSet

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala

Lines changed: 7 additions & 2 deletions
@@ -38,8 +38,13 @@ class InMemoryRowLevelOperationTable(
     partitioning: Array[Transform],
     properties: util.Map[String, String],
     constraints: Array[Constraint] = Array.empty)
-  extends InMemoryTable(name, schema, partitioning, properties, constraints)
-  with SupportsRowLevelOperations {
+  extends InMemoryTable(
+    name,
+    CatalogV2Util.structTypeToV2Columns(schema),
+    partitioning,
+    properties,
+    constraints)
+  with SupportsRowLevelOperations {
 
   private final val PARTITION_COLUMN_REF = FieldReference(PartitionKeyColumn.name)
   private final val INDEX_COLUMN_REF = FieldReference(IndexColumn.name)

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala

Lines changed: 14 additions & 2 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.util.ArrayImplicits._
  */
 class InMemoryTable(
     name: String,
-    schema: StructType,
+    columns: Array[Column],
     override val partitioning: Array[Transform],
     override val properties: util.Map[String, String],
     override val constraints: Array[Constraint] = Array.empty,
@@ -43,10 +43,22 @@ class InMemoryTable(
     advisoryPartitionSize: Option[Long] = None,
     isDistributionStrictlyRequired: Boolean = true,
     override val numRowsPerSplit: Int = Int.MaxValue)
-  extends InMemoryBaseTable(name, schema, partitioning, properties, constraints, distribution,
+  extends InMemoryBaseTable(name, columns, partitioning, properties, constraints, distribution,
     ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired,
     numRowsPerSplit) with SupportsDelete {
 
+  def this(
+      name: String,
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: util.Map[String, String]) = {
+    this(
+      name,
+      CatalogV2Util.structTypeToV2Columns(schema),
+      partitioning,
+      properties)
+  }
+
   override def canDeleteWhere(filters: Array[Filter]): Boolean = {
     InMemoryTable.supportsFilters(filters)
   }

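The in-memory test tables now take DSv2 Columns instead of a StructType, since only the Column form can carry an expression-based default from the connector; callers that still hold a StructType convert at the boundary. A small sketch of that conversion (the CatalogV2Util helpers are private[sql], so this assumes code inside the Spark SQL packages):

import org.apache.spark.sql.connector.catalog.CatalogV2Util
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// StructType -> DSv2 columns at the boundary; the Column form is what preserves a
// connector-supplied default value or generation expression.
val schema = new StructType().add("id", IntegerType).add("name", StringType)
val columns = CatalogV2Util.structTypeToV2Columns(schema)

// And back, as InMemoryBaseTable.schema now does internally.
val roundTripped: StructType = CatalogV2Util.v2ColumnsToStructType(columns)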
sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala

Lines changed: 2 additions & 3 deletions
@@ -118,15 +118,14 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       distributionStrictlyRequired: Boolean = true,
       numRowsPerSplit: Int = Int.MaxValue): Table = {
     // scalastyle:on argcount
-    val schema = CatalogV2Util.v2ColumnsToStructType(columns)
     if (tables.containsKey(ident)) {
       throw new TableAlreadyExistsException(ident.asMultipartIdentifier)
     }
 
     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
 
     val tableName = s"$name.${ident.quoted}"
-    val table = new InMemoryTable(tableName, schema, partitions, properties, constraints,
+    val table = new InMemoryTable(tableName, columns, partitions, properties, constraints,
       distribution, ordering, requiredNumPartitions, advisoryPartitionSize,
       distributionStrictlyRequired, numRowsPerSplit)
     tables.put(ident, table)
@@ -154,7 +153,7 @@
     val currentVersion = table.currentVersion()
     val newTable = new InMemoryTable(
       name = table.name,
-      schema = schema,
+      columns = CatalogV2Util.structTypeToV2Columns(schema),
       partitioning = finalPartitioning,
       properties = properties,
       constraints = constraints)

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala

Lines changed: 2 additions & 2 deletions
@@ -31,10 +31,10 @@ import org.apache.spark.util.ArrayImplicits._
 
 class InMemoryTableWithV2Filter(
     name: String,
-    schema: StructType,
+    columns: Array[Column],
     partitioning: Array[Transform],
     properties: util.Map[String, String])
-  extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 {
+  extends InMemoryBaseTable(name, columns, partitioning, properties) with SupportsDeleteV2 {
 
   override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
     InMemoryTableWithV2Filter.supportsPredicates(predicates)

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2FilterCatalog.scala

Lines changed: 1 addition & 2 deletions
@@ -37,8 +37,7 @@ class InMemoryTableWithV2FilterCatalog extends InMemoryTableCatalog {
     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
 
     val tableName = s"$name.${ident.quoted}"
-    val schema = CatalogV2Util.v2ColumnsToStructType(columns)
-    val table = new InMemoryTableWithV2Filter(tableName, schema, partitions, properties)
+    val table = new InMemoryTableWithV2Filter(tableName, columns, partitions, properties)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table

sql/catalyst/src/test/scala/org/apache/spark/sql/types/MetadataSuite.scala

Lines changed: 78 additions & 1 deletion
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.types
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerInstance}
+import org.apache.spark.sql.catalyst.expressions.{Add, Expression, Literal}
 
 class MetadataSuite extends SparkFunSuite {
   test("String Metadata") {
@@ -76,4 +78,79 @@
     assert(meta === Metadata.fromJson(meta.json))
     intercept[NoSuchElementException](meta.getLong("no_such_key"))
   }
+
+  test("Kryo serialization for expressions") {
+    val conf = new SparkConf()
+    val serializer = new KryoSerializer(conf).newInstance()
+    checkMetadataExpressions(serializer)
+  }
+
+  test("Java serialization for expressions") {
+    val conf = new SparkConf()
+    val serializer = new JavaSerializer(conf).newInstance()
+    checkMetadataExpressions(serializer)
+  }
+
+  test("JSON representation with expressions") {
+    val meta = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("expr", "1 + 3", Some(Add(Literal(1), Literal(3))))
+      .build()
+    assert(meta.json == """{"expr":"1 + 3","key":"value"}""")
+  }
+
+  test("equals and hashCode with expressions") {
+    val meta1 = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("expr", "1 + 2", Some(Add(Literal(1), Literal(2))))
+      .build()
+
+    val meta2 = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("expr", "1 + 2", Some(Add(Literal(1), Literal(2))))
+      .build()
+
+    val meta3 = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("expr", "2 + 3", Some(Add(Literal(2), Literal(3))))
+      .build()
+
+    val meta4 = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("expr", "1 + 2", None)
+      .build()
+
+    // meta1 and meta2 are equivalent
+    assert(meta1 === meta2)
+    assert(meta1.hashCode === meta2.hashCode)
+
+    // meta1 and meta3 are different as they contain different expressions
+    assert(meta1 !== meta3)
+    assert(meta1.hashCode !== meta3.hashCode)
+
+    // meta1 and meta4 are equivalent even though meta4 only includes the SQL string
+    assert(meta1 == meta4)
+    assert(meta1.hashCode == meta4.hashCode)
+  }
+
+  private def checkMetadataExpressions(serializer: SerializerInstance): Unit = {
+    val meta = new MetadataBuilder()
+      .putString("key", "value")
+      .putExpression("tempKey", "1", Some(Literal(1)))
+      .build()
+    assert(meta.contains("key"))
+    assert(meta.getString("key") == "value")
+    assert(meta.contains("tempKey"))
+    assert(meta.getExpression[Expression]("tempKey")._1 == "1")
+    assert(meta.getExpression[Expression]("tempKey")._2.contains(Literal(1)))
+
+    val deserializedMeta = serializer.deserialize[Metadata](serializer.serialize(meta))
+    assert(deserializedMeta == meta)
+    assert(deserializedMeta.hashCode == meta.hashCode)
+    assert(deserializedMeta.contains("key"))
+    assert(deserializedMeta.getString("key") == "value")
+    assert(deserializedMeta.contains("tempKey"))
+    assert(deserializedMeta.getExpression[Expression]("tempKey")._1 == "1")
+    assert(deserializedMeta.getExpression[Expression]("tempKey")._2.isEmpty)
+  }
 }
