diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 000b1f524f207..560295c050220 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,6 +114,18 @@
     ],
     "sqlState" : "22004"
   },
+  "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : {
+    "message" : [
+      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
+    ],
+    "sqlState" : "42846"
+  },
+  "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
+    "message" : [
+      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
+    ],
+    "sqlState" : "42846"
+  },
   "ARITHMETIC_OVERFLOW" : {
     "message" : [
       "<message>.<alternative> If necessary set <config> to \"false\" to bypass this error."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b54ae9082f840..89a8e8a9a6a6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -528,6 +528,8 @@ object FunctionRegistry {
     expression[HllSketchAgg]("hll_sketch_agg"),
     expression[HllUnionAgg]("hll_union_agg"),
     expression[ApproxTopK]("approx_top_k"),
+    expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
+    expression[ApproxTopKCombine]("approx_top_k_combine"),

     // string functions
     expression[Ascii]("ascii"),
@@ -786,6 +788,7 @@
     expression[EqualNull]("equal_null"),
     expression[HllSketchEstimate]("hll_sketch_estimate"),
     expression[HllUnion]("hll_union"),
+    expression[ApproxTopKEstimate]("approx_top_k_estimate"),

     // grouping sets
     expression[Grouping]("grouping"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
new file mode 100644
index 0000000000000..342a261d77738
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.datasketches.frequencies.ItemsSketch
+import org.apache.datasketches.memory.Memory
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+
+/**
+ * An expression that estimates the top K items from a sketch.
+ *
+ * The input is a sketch state that is generated by the ApproxTopKAccumulate function.
+ * The output is an array of structs, each containing a frequent item and its estimated frequency.
+ * The items are sorted by their estimated frequency in descending order.
+ *
+ * @param state The sketch state, which is a struct containing the serialized sketch data,
+ *              the original data type and the max items tracked of the sketch.
+ * @param k     The number of top items to estimate.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(state, k) - Returns top k items with their frequency.
+      `state` The sketch state produced by approx_top_k_accumulate or approx_top_k_combine.
+      `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr);
+       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+
+      > SELECT _FUNC_(approx_top_k_accumulate(expr), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);
+       [{"item":"c","count":4},{"item":"d","count":2}]
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
+case class ApproxTopKEstimate(state: Expression, k: Expression)
+  extends BinaryExpression
+  with CodegenFallback
+  with ImplicitCastInputTypes {
+
+  def this(child: Expression, topK: Int) = this(child, Literal(topK))
+
+  def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_K))
+
+  private lazy val itemDataType: DataType = {
+    // itemDataType is the type of the "ItemTypeNull" field of the output of ACCUMULATE or COMBINE
+    state.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
+  }
+
+  override def left: Expression = state
+
+  override def right: Expression = k
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!k.foldable) {
+      TypeCheckFailure("K must be a constant literal")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)
+
+  override def eval(input: InternalRow): Any = {
+    // null check
+    ApproxTopK.checkExpressionNotNull(k, "k")
+    // eval
+    val stateEval = left.eval(input)
+    val kEval = right.eval(input)
+    val dataSketchBytes = stateEval.asInstanceOf[InternalRow].getBinary(0)
+    val maxItemsTrackedVal = stateEval.asInstanceOf[InternalRow].getInt(2)
+    val kVal = kEval.asInstanceOf[Int]
+    ApproxTopK.checkK(kVal)
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
+    val itemsSketch = ItemsSketch.getInstance(
+      Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
+
ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType) + } + + override protected def withNewChildrenInternal(newState: Expression, newK: Expression) + : Expression = copy(state = newState, k = newK) + + override def nullable: Boolean = false + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_estimate") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 974af917ef512..86d74461b63de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -22,11 +22,12 @@ import org.apache.datasketches.common._ import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch} import org.apache.datasketches.memory.Memory +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.{ArrayOfDecimalsSerDe, Expression, ExpressionDescription, ImplicitCastInputTypes, Literal} -import org.apache.spark.sql.catalyst.trees.TernaryLike +import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike} import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -50,12 +51,14 @@ import org.apache.spark.unsafe.types.UTF8String */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = """ + usage = + """ _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5. - `maxItemsTracked` An optional INTEGER literal greater than or equal to k. If maxItemsTracked is not specified, it defaults to 10000. - """, - examples = """ + `maxItemsTracked` An optional INTEGER literal greater than or equal to k and has upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. 
+ """, + examples = + """ Examples: > SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] @@ -173,40 +176,49 @@ case class ApproxTopK( object ApproxTopK { - private val DEFAULT_K: Int = 5 - private val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 + val DEFAULT_K: Int = 5 + val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 + val VOID_MAX_ITEMS_TRACKED = -1 + val SKETCH_SIZE_PLACEHOLDER = 8 - private def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { + def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { if (expr == null || expr.eval() == null) { throw QueryExecutionErrors.approxTopKNullArg(exprName) } } - private def checkK(k: Int): Unit = { + def checkK(k: Int): Unit = { if (k <= 0) { throw QueryExecutionErrors.approxTopKNonPositiveValue("k", k) } } - private def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = { + def checkMaxItemsTracked(maxItemsTracked: Int): Unit = { if (maxItemsTracked > MAX_ITEMS_TRACKED_LIMIT) { throw QueryExecutionErrors.approxTopKMaxItemsTrackedExceedsLimit( maxItemsTracked, MAX_ITEMS_TRACKED_LIMIT) } + if (maxItemsTracked <= 0) { + throw QueryExecutionErrors.approxTopKNonPositiveValue("maxItemsTracked", maxItemsTracked) + } + } + + def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = { + checkMaxItemsTracked(maxItemsTracked) if (maxItemsTracked < k) { throw QueryExecutionErrors.approxTopKMaxItemsTrackedLessThanK(maxItemsTracked, k) } } - private def getResultDataType(itemDataType: DataType): DataType = { + def getResultDataType(itemDataType: DataType): DataType = { val resultEntryType = StructType( StructField("item", itemDataType, nullable = false) :: StructField("count", LongType, nullable = false) :: Nil) ArrayType(resultEntryType, containsNull = false) } - private def isDataTypeSupported(itemType: DataType): Boolean = { + def isDataTypeSupported(itemType: DataType): Boolean = { itemType match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: DateType | @@ -216,13 +228,14 @@ object ApproxTopK { } } - private def calMaxMapSize(maxItemsTracked: Int): Int = { + def calMaxMapSize(maxItemsTracked: Int): Int = { // The maximum capacity of this internal hash map has maxMapCap = 0.75 * maxMapSize // Therefore, the maxMapSize must be at least ceil(maxItemsTracked / 0.75) // https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html val ceilMaxMapSize = math.ceil(maxItemsTracked / 0.75).toInt // The maxMapSize must be a power of 2 and greater than ceilMaxMapSize - math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt + val maxMapSize = math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt + maxMapSize } def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = { @@ -242,7 +255,7 @@ object ApproxTopK { } } - private def updateSketchBuffer( + def updateSketchBuffer( itemExpression: Expression, buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] = { @@ -268,7 +281,7 @@ object ApproxTopK { buffer } - private def genEvalResult( + def genEvalResult( itemsSketch: ItemsSketch[Any], k: Int, itemDataType: DataType): GenericArrayData = { @@ -290,7 +303,7 @@ object ApproxTopK { new GenericArrayData(result) } - private def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { + def 
genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { dataType match { case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]] case _: ByteType | _: ShortType | _: IntegerType | _: FloatType | _: DateType => @@ -305,4 +318,341 @@ object ApproxTopK { new ArrayOfDecimalsSerDe(dt).asInstanceOf[ArrayOfItemsSerDe[Any]] } } + + def getSketchStateDataType(itemDataType: DataType): StructType = + StructType( + StructField("Sketch", BinaryType, nullable = false) :: + StructField("ItemTypeNull", itemDataType) :: + StructField("MaxItemsTracked", IntegerType, nullable = false) :: + StructField("TypeCode", BinaryType, nullable = false) :: Nil) + + def dataTypeToBytes(dataType: DataType): Array[Byte] = { + dataType match { + case _: BooleanType => Array(0, 0, 0) + case _: ByteType => Array(1, 0, 0) + case _: ShortType => Array(2, 0, 0) + case _: IntegerType => Array(3, 0, 0) + case _: LongType => Array(4, 0, 0) + case _: FloatType => Array(5, 0, 0) + case _: DoubleType => Array(6, 0, 0) + case _: DateType => Array(7, 0, 0) + case _: TimestampType => Array(8, 0, 0) + case _: TimestampNTZType => Array(9, 0, 0) + case _: StringType => Array(10, 0, 0) + case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte) + } + } + + def bytesToDataType(bytes: Array[Byte]): DataType = { + bytes(0) match { + case 0 => BooleanType + case 1 => ByteType + case 2 => ShortType + case 3 => IntegerType + case 4 => LongType + case 5 => FloatType + case 6 => DoubleType + case 7 => DateType + case 8 => TimestampType + case 9 => TimestampNTZType + case 10 => StringType + case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt) + } + } +} + +/** + * An aggregate function that accumulates items into a sketch, which can then be used + * to combine with other sketches, via ApproxTopKCombine, + * or to estimate the top K items, via ApproxTopKEstimate. + * + * The output of this function is a struct containing the sketch in binary format, + * a null object indicating the type of items in the sketch, + * and the maximum number of items tracked by the sketch. + * + * @param expr the child expression to accumulate items from + * @param maxItemsTracked the maximum number of items to track in the sketch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = + """ + _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch. + `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. 
+ """, + examples = + """ + Examples: + > SELECT approx_top_k_accumulate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); + [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] + + > SELECT approx_top_k_accumulate(_FUNC_(expr, 100), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr); + [{"item":"c","count":4},{"item":"d","count":2}] + """, + group = "agg_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit +case class ApproxTopKAccumulate( + expr: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[ItemsSketch[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = this(child, maxItemsTracked, 0, 0) + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked), 0, 0) + + def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0) + + private lazy val itemDataType: DataType = expr.dataType + + private lazy val maxItemsTrackedVal: Int = { + ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") + val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int] + ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) + maxItemsTrackedVal + } + + override def left: Expression = expr + + override def right: Expression = maxItemsTracked + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) { + TypeCheckFailure(f"${itemDataType.typeName} columns are not supported") + } else if (!maxItemsTracked.foldable) { + TypeCheckFailure("Number of items tracked must be a constant literal") + } else { + TypeCheckSuccess + } + } + + override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType) + + override def createAggregationBuffer(): ItemsSketch[Any] = { + val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) + ApproxTopK.createAggregationBuffer(expr, maxMapSize) + } + + override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] = + ApproxTopK.updateSketchBuffer(expr, buffer, input) + + override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] = + buffer.merge(input) + + override def eval(buffer: ItemsSketch[Any]): Any = { + val sketchBytes = serialize(buffer) + val typeCode = ApproxTopK.dataTypeToBytes(itemDataType) + InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode) + } + + override def serialize(buffer: ItemsSketch[Any]): Array[Byte] = + buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType)) + + override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] = + ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType)) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = + copy(expr = newLeft, 
maxItemsTracked = newRight) + + override def nullable: Boolean = false + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") +} + +class CombineInternal[T]( + sketch: ItemsSketch[T], + var itemDataType: DataType, + var maxItemsTracked: Int) { + def getSketch: ItemsSketch[T] = sketch + + def getItemDataType: DataType = itemDataType + + def setItemDataType(dataType: DataType): Unit = { + if (this.itemDataType == null) { + this.itemDataType = dataType + } else if (this.itemDataType != dataType) { + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + messageParameters = Map( + "type1" -> this.itemDataType.typeName, + "type2" -> dataType.typeName)) + } + } + + def getMaxItemsTracked: Int = maxItemsTracked + + def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked +} + +case class ApproxTopKCombine( + expr: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[CombineInternal[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = { + this(child, maxItemsTracked, 0, 0) + ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") + ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) + } + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked)) + + def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) + + private lazy val uncheckedItemDataType: DataType = + expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType + private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int] + private lazy val combineSizeSpecified: Boolean = + maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED + + override def left: Expression = expr + + override def right: Expression = maxItemsTracked + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!maxItemsTracked.foldable) { + TypeCheckFailure("Number of items tracked must be a constant literal") + } else { + TypeCheckSuccess + } + } + + override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType) + + override def createAggregationBuffer(): CombineInternal[Any] = { + if (combineSizeSpecified) { + val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) + new CombineInternal[Any]( + new ItemsSketch[Any](maxMapSize), + null, + maxItemsTrackedVal) + } else { + new CombineInternal[Any]( + new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), + null, + ApproxTopK.VOID_MAX_ITEMS_TRACKED) + } + } + + override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { + val inputSketchBytes = expr.eval(input).asInstanceOf[InternalRow].getBinary(0) + val inputMaxItemsTracked = expr.eval(input).asInstanceOf[InternalRow].getInt(2) + val typeCode = expr.eval(input).asInstanceOf[InternalRow].getBinary(3) + val actualItemDataType = ApproxTopK.bytesToDataType(typeCode) + buffer.setItemDataType(actualItemDataType) + val inputSketch = ItemsSketch.getInstance( + Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) + buffer.getSketch.merge(inputSketch) + if 
(!combineSizeSpecified) {
+      buffer.setMaxItemsTracked(inputMaxItemsTracked)
+    }
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
+          messageParameters = Map(
+            "size1" -> buffer.getMaxItemsTracked.toString,
+            "size2" -> input.getMaxItemsTracked.toString))
+      }
+    }
+    // check item data type
+    if (buffer.getItemDataType != null && input.getItemDataType != null &&
+      buffer.getItemDataType != input.getItemDataType) {
+      throw new SparkUnsupportedOperationException(
+        errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+        messageParameters = Map(
+          "type1" -> buffer.getItemDataType.typeName,
+          "type2" -> input.getItemDataType.typeName))
+    } else if (buffer.getItemDataType == null) {
+      // If buffer is a placeholder sketch, set it to the input sketch's item data type
+      buffer.setItemDataType(input.getItemDataType)
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes = try {
+      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    } catch {
+      case _: ArrayStoreException =>
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
+        )
+    }
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+    val sketchBytes = buffer.getSketch.toByteArray(
+      ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val itemDataTypeBytes = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    // Layout: 4-byte maxItemsTracked, then the 3-byte type code, then the sketch bytes.
+    // maxItemsTracked can be as large as 1000000 (or VOID_MAX_ITEMS_TRACKED), so it must be
+    // written as a full int rather than a single byte.
+    val byteBuffer = java.nio.ByteBuffer.allocate(4 + itemDataTypeBytes.length + sketchBytes.length)
+    byteBuffer.putInt(buffer.getMaxItemsTracked)
+    byteBuffer.put(itemDataTypeBytes)
+    byteBuffer.put(sketchBytes)
+    byteBuffer.array()
+  }
+
+  override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+    val byteBuffer = java.nio.ByteBuffer.wrap(buffer)
+    val maxItemsTracked = byteBuffer.getInt
+    val itemDataTypeBytes = new Array[Byte](3)
+    byteBuffer.get(itemDataTypeBytes)
+    val actualItemDataType = ApproxTopK.bytesToDataType(itemDataTypeBytes)
+    val sketchBytes = new Array[Byte](byteBuffer.remaining())
+    byteBuffer.get(sketchBytes)
+    val sketch = ItemsSketch.getInstance(
+      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(actualItemDataType))
+    new CombineInternal[Any](sketch, actualItemDataType, maxItemsTracked)
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression = copy(expr = newLeft, maxItemsTracked = newRight)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def nullable: Boolean = false
+}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala new file mode 100644 index 0000000000000..e2ad477202f43 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, MapType, StringType, StructField, StructType} + +class ApproxTopKSuite extends SparkFunSuite { + + test("SPARK-52515: Accepts literal and foldable inputs") { + val agg = new ApproxTopK( + expr = BoundReference(0, IntegerType, nullable = true), + k = Abs(Literal(10)), + maxItemsTracked = Abs(Literal(-10)) + ) + assert(agg.checkInputDataTypes().isSuccess) + } + + test("SPARK-52515: Fail if parameters are not foldable") { + val badAgg = new ApproxTopK( + expr = BoundReference(0, IntegerType, nullable = true), + k = Sum(BoundReference(1, IntegerType, nullable = true)), + maxItemsTracked = Literal(10) + ) + assert(badAgg.checkInputDataTypes().isFailure) + + val badAgg2 = new ApproxTopK( + expr = BoundReference(0, IntegerType, nullable = true), + k = Literal(10), + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badAgg2.checkInputDataTypes().isFailure) + } + + test("SPARK-52515: invalid item types") { + Seq( + ArrayType(IntegerType), + MapType(StringType, IntegerType), + StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), + BinaryType + ).foreach { unsupportedType => + val agg = new ApproxTopK( + expr = BoundReference(0, unsupportedType, nullable = true), + k = Literal(10), + maxItemsTracked = Literal(10000) + ) + assert(agg.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52588: invalid accumulate if item type is not supported") { + Seq( + ArrayType(IntegerType), + MapType(StringType, IntegerType), + StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), + BinaryType + ).foreach { + unsupportedType => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, unsupportedType, nullable = true), + maxItemsTracked = Literal(10) + ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked are not foldable") { + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, IntegerType, nullable = true), + maxItemsTracked = 
Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked less than or equal to 0") { + Seq(0, -1).foreach { invalidInput => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, IntegerType, nullable = true), + maxItemsTracked = Literal(invalidInput) + ) + checkError( + exception = intercept[SparkRuntimeException] { + badAccumulate.createAggregationBuffer() + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> invalidInput.toString) + ) + } + } + + test("SPARK-52588: invalid estimate if k are not foldable") { + val badEstimate = ApproxTopKEstimate( + state = BoundReference(0, IntegerType, nullable = false), + k = Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badEstimate.checkInputDataTypes().isFailure) + } + + test("SPARK-combine: invalid combine if maxItemsTracked are not foldable") { + val badCombine = ApproxTopKCombine( + expr = BoundReference(0, BinaryType, nullable = false), + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badCombine.checkInputDataTypes().isFailure) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b7f129c35c5dd..10ff56cbcbbff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,10 +24,8 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the -import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} +import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{Abs, BoundReference, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, Sum} import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -2880,31 +2878,6 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2)))) } - test("SPARK-52515: Accepts literal and foldable inputs") { - val agg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Abs(Literal(10)), - maxItemsTracked = Abs(Literal(-10)) - ) - assert(agg.checkInputDataTypes().isSuccess) - } - - test("SPARK-52515: Fail if parameters are not foldable") { - val badAgg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Sum(BoundReference(1, LongType, nullable = true)), - maxItemsTracked = Literal(10) - ) - assert(badAgg.checkInputDataTypes().isFailure) - - val badAgg2 = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Literal(10), - maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) - ) - assert(badAgg2.checkInputDataTypes().isFailure) - } - test("SPARK-52626: Support group by Time column") { val ts1 = "15:00:00" val ts2 = "22:00:00" @@ -2931,6 +2904,339 @@ class DataFrameAggregateSuite extends QueryTest res, Row(LocalTime.of(22, 1, 0), LocalTime.of(3, 0, 0))) } + + test("SPARK-52588: accumulate and estimate of Integer with default parameters") { + val res = sql("SELECT 
approx_top_k_estimate(approx_top_k_accumulate(expr)) " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (3), (4) AS tab(expr);") + checkAnswer(res, Row(Seq(Row(0, 3), Row(1, 2), Row(4, 1), Row(2, 1), Row(3, 1)))) + } + + test("SPARK-52588: accumulate and estimate of String") { + val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2) " + + "FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);") + checkAnswer(res, Row(Seq(Row("c", 4), Row("d", 2)))) + } + + test("SPARK-52588: accumulate and estimate of Decimal(4, 1)") { + val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 10)) " + + "FROM VALUES CAST(0.0 AS DECIMAL(4, 1)), CAST(0.0 AS DECIMAL(4, 1)), " + + "CAST(0.0 AS DECIMAL(4, 1)), CAST(1.0 AS DECIMAL(4, 1)), " + + "CAST(1.0 AS DECIMAL(4, 1)), CAST(2.0 AS DECIMAL(4, 1)) AS tab(expr);") + checkAnswer(res, Row(Seq( + Row(new java.math.BigDecimal("0.0"), 3), + Row(new java.math.BigDecimal("1.0"), 2), + Row(new java.math.BigDecimal("2.0"), 1)))) + } + + test("SPARK-52588: accumulate and estimate of Decimal(20, 3)") { + val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 10), 2) " + + "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + + "CAST(0.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + + "CAST(1.0 AS DECIMAL(20, 3)), CAST(2.0 AS DECIMAL(20, 3)) AS tab(expr);") + checkAnswer(res, Row(Seq( + Row(new java.math.BigDecimal("0.000"), 3), + Row(new java.math.BigDecimal("1.000"), 2)))) + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked is null") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_accumulate(expr, NULL) FROM VALUES 0, 1, 2 AS tab(expr);") + .collect() + }, + condition = "APPROX_TOP_K_NULL_ARG", + parameters = Map("argName" -> "`maxItemsTracked`") + ) + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked > 1000000") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_accumulate(expr, 1000001) FROM VALUES (0) AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_EXCEEDS_LIMIT", + parameters = Map("maxItemsTracked" -> "1000001", "limit" -> "1000000") + ) + } + + test("SPARK-52588: invalid estimate if k is null") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), NULL) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_NULL_ARG", + parameters = Map("argName" -> "`k`") + ) + } + + test("SPARK-52588: invalid estimate if k is invalid") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 0) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`k`", "argValue" -> "0") + ) + } + + test("SPARK-52588: invalid estimate if k > Int.MaxValue") { + withSQLConf("spark.sql.ansi.enabled" -> true.toString) { + val k: Long = Int.MaxValue + 1L + checkError( + exception = intercept[SparkArithmeticException] { + sql(s"SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), $k) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "CAST_OVERFLOW", + parameters = Map( + "value" -> (k.toString + "L"), + "sourceType" -> "\"BIGINT\"", + "targetType" -> "\"INT\"", + "ansiConfig" -> "\"spark.sql.ansi.enabled\"" + ) + ) + } + } + + test("SPARK-52588: invalid estimate if k > 
maxItemsTracked") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 5), 10) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_LESS_THAN_K", + parameters = Map("maxItemsTracked" -> "5", "k" -> "10") + ) + } + + def setupAccumulations(size1: Int, size2: Int): Unit = { + sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + .createOrReplaceTempView("accumulation1") + + sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + .createOrReplaceTempView("accumulation2") + } + + test("SPARK-combine: same type, same size, specified combine size - success") { + setupAccumulations(10, 10) + + sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, same size, unspecified combine size - success") { + setupAccumulations(10, 10) + + sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, different size, specified combine size - success") { + setupAccumulations(10, 20) + + sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combination") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, different size, unspecified combine size - fail") { + setupAccumulations(10, 20) + + val comb = sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + + checkError( + exception = intercept[SparkUnsupportedOperationException] { + comb.collect() + }, + condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + parameters = Map("size1" -> "10", "size2" -> "20") + ) + } + + test("SPARK-combine: same type, different size, invalid combine size - fail") { + setupAccumulations(10, 20) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_combine(acc, -1) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "-1") + ) + } + + def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = { + sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + + s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);") + .createOrReplaceTempView("accumulation1") + + sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + + s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);") + .createOrReplaceTempView("accumulation2") + + sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2") + 
.createOrReplaceTempView("unioned") + } + + test("SPARK-combine: different types, same size, specified combine size - fail") { + val mixedTypeSeqs = Seq( + (IntegerType.typeName, Seq(0, 0, 0, 1, 1, 2, 2, 3)), + (StringType.typeName, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")), + // (BooleanType.typeName, Seq("(true)", "(true)", "(true)", "(false)", "(false)")), + (ByteType.typeName, Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")), + (ShortType.typeName, Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")), + (LongType.typeName, Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")), + (FloatType.typeName, Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")), + (DoubleType.typeName, Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")), + (DecimalType(4, 2).typeName, Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", + "cast(1 AS DECIMAL(4, 2))")), + (DecimalType(10, 2).typeName, Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", + "cast(1 AS DECIMAL(10, 2))")), + (DecimalType(20, 3).typeName, Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", + "cast(1 AS DECIMAL(20, 3))")) + // DateType.typeName -> Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'"), + // TimestampType.typeName -> Seq("TIMESTAMP'2025-01-01 00:00:00'", + // "TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-02 00:00:00'"), + // TimestampNTZType.typeName -> Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", + // "TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-02 00:00:00'") + ) + + for (i <- 0 until mixedTypeSeqs.size - 1) { + for (j <- i + 1 until mixedTypeSeqs.size) { + val (type1, seq1) = mixedTypeSeqs(i) + val (type2, seq2) = mixedTypeSeqs(j) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> type1, "type2" -> type2) + ) + } + } + } + + test("SPARK-combine: different types - fail on UNION") { + val seq1 = Seq(0, 0, 0, 1, 1, 2, 2, 3) + val seq2 = Seq("(true)", "(true)", "(false)", "(false)") + checkError( + exception = intercept[ExtendedAnalysisException] { + setupMixedTypeAccumulation(seq1, seq2) + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ), + queryContext = Array( + ExpectedContext( + "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) + ) + } + + test("SPARK-combine: different types Date vs Timestamp - fail") { + val (type1, seq1) = (DateType.typeName, + Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")) + val (type2, seq2) = (TimestampType.typeName, Seq("TIMESTAMP'2025-01-01 00:00:00'", + "TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-02 00:00:00'")) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> type1, "type2" -> type2) + ) + } + + test("SPARK-combine: different types Timestamp vs TimestampNTZ - fail") { + val (type1, seq1) = (TimestampType.typeName, + Seq("TIMESTAMP'2025-01-01 00:00:00'", 
"TIMESTAMP'2025-01-01 00:00:00'", + "TIMESTAMP'2025-01-02 00:00:00'")) + val (type2, seq2) = (TimestampNTZType.typeName, + Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'", + "TIMESTAMP_NTZ'2025-01-02 00:00:00'")) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> type1, "type2" -> type2) + ) + } + + test("SPARK-combine: different types Int vs Date - fail on UNION") { + val seq1 = Seq(0, 0, 0, 1, 1, 2, 2, 3) + val seq2 = Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'") + checkError( + exception = intercept[ExtendedAnalysisException] { + setupMixedTypeAccumulation(seq1, seq2) + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ), + queryContext = Array( + ExpectedContext( + "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) + ) + } + + test("SPARK-combine: different types Long vs Timestamp - fail on UNION") { + val seq1 = Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)") + val seq2 = Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'") + checkError( + exception = intercept[ExtendedAnalysisException] { + setupMixedTypeAccumulation(seq1, seq2) + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ), + queryContext = Array( + ExpectedContext( + "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) + ) + } } case class B(c: Option[Double])