From bbd8ae47dc2465489bca024d8a906cd9f70c6eff Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Thu, 26 Jun 2025 11:49:20 -0700
Subject: [PATCH 01/23] init ApproxTopKAccumulate

---
 .../aggregate/ApproxTopKAggregates.scala      | 115 ++++++++++++++++--
 1 file changed, 105 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 974af917ef512..d584c17e55b3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ArrayOfDecimalsSerDe, Expression, ExpressionDescription, ImplicitCastInputTypes, Literal}
-import org.apache.spark.sql.catalyst.trees.TernaryLike
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -50,12 +50,14 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 // scalastyle:off line.size.limit
 @ExpressionDescription(
-  usage = """
+  usage =
+    """
     _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency.
       `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5.
       `maxItemsTracked` An optional INTEGER literal greater than or equal to k. If maxItemsTracked is not specified, it defaults to 10000.
""", - examples = """ + examples = + """ Examples: > SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] @@ -174,10 +176,10 @@ case class ApproxTopK( object ApproxTopK { private val DEFAULT_K: Int = 5 - private val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 + val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 - private def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { + def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { if (expr == null || expr.eval() == null) { throw QueryExecutionErrors.approxTopKNullArg(exprName) } @@ -189,11 +191,15 @@ object ApproxTopK { } } - private def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = { + def checkMaxItemsTracked(maxItemsTracked: Int): Unit = { if (maxItemsTracked > MAX_ITEMS_TRACKED_LIMIT) { throw QueryExecutionErrors.approxTopKMaxItemsTrackedExceedsLimit( maxItemsTracked, MAX_ITEMS_TRACKED_LIMIT) } + } + + private def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = { + checkMaxItemsTracked(maxItemsTracked) if (maxItemsTracked < k) { throw QueryExecutionErrors.approxTopKMaxItemsTrackedLessThanK(maxItemsTracked, k) } @@ -206,7 +212,7 @@ object ApproxTopK { ArrayType(resultEntryType, containsNull = false) } - private def isDataTypeSupported(itemType: DataType): Boolean = { + def isDataTypeSupported(itemType: DataType): Boolean = { itemType match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: DateType | @@ -216,7 +222,7 @@ object ApproxTopK { } } - private def calMaxMapSize(maxItemsTracked: Int): Int = { + def calMaxMapSize(maxItemsTracked: Int): Int = { // The maximum capacity of this internal hash map has maxMapCap = 0.75 * maxMapSize // Therefore, the maxMapSize must be at least ceil(maxItemsTracked / 0.75) // https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html @@ -242,7 +248,7 @@ object ApproxTopK { } } - private def updateSketchBuffer( + def updateSketchBuffer( itemExpression: Expression, buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] = { @@ -290,7 +296,7 @@ object ApproxTopK { new GenericArrayData(result) } - private def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { + def genSketchSerDe(dataType: DataType): ArrayOfItemsSerDe[Any] = { dataType match { case _: BooleanType => new ArrayOfBooleansSerDe().asInstanceOf[ArrayOfItemsSerDe[Any]] case _: ByteType | _: ShortType | _: IntegerType | _: FloatType | _: DateType => @@ -305,4 +311,93 @@ object ApproxTopK { new ArrayOfDecimalsSerDe(dt).asInstanceOf[ArrayOfItemsSerDe[Any]] } } + + def getSketchStateDataType(itemDataType: DataType): StructType = + StructType( + StructField("Sketch", BinaryType, nullable = false) :: + StructField("ItemTypeNull", itemDataType) :: + StructField("MaxItemsTracked", IntegerType, nullable = false) :: Nil) +} + +case class ApproxTopKAccumulate( + expr: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[ItemsSketch[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = this(child, maxItemsTracked, 0, 0) + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked), 0, 0) + + def this(child: Expression) = 
+    this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0)
+
+  private lazy val itemDataType: DataType = expr.dataType
+  private lazy val maxItemsTrackedVal: Int = {
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int]
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+    maxItemsTrackedVal
+  }
+
+  override def left: Expression = expr
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
+      TypeCheckFailure(f"${itemDataType.typeName} columns are not supported")
+    } else if (!maxItemsTracked.foldable) {
+      TypeCheckFailure("Number of items tracked must be a constant literal")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)
+
+  override def createAggregationBuffer(): ItemsSketch[Any] = {
+    val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+    ApproxTopK.createAggregationBuffer(expr, maxMapSize)
+  }
+
+  override def update(buffer: ItemsSketch[Any], input: InternalRow): ItemsSketch[Any] =
+    ApproxTopK.updateSketchBuffer(expr, buffer, input)
+
+  override def merge(buffer: ItemsSketch[Any], input: ItemsSketch[Any]): ItemsSketch[Any] =
+    buffer.merge(input)
+
+  override def eval(buffer: ItemsSketch[Any]): Any = {
+    val sketchBytes = serialize(buffer)
+    InternalRow.apply(sketchBytes, null, maxItemsTrackedVal)
+  }
+
+  override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
+    buffer.toByteArray(ApproxTopK.genSketchSerDe(itemDataType))
+
+  override def deserialize(storageFormat: Array[Byte]): ItemsSketch[Any] =
+    ItemsSketch.getInstance(Memory.wrap(storageFormat), ApproxTopK.genSketchSerDe(itemDataType))
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression =
+    copy(expr = newLeft, maxItemsTracked = newRight)
+
+  override def nullable: Boolean = false
+
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
+}

From 9015bbbbcdf78ee50fc637b5aa29b5d8c87ba417 Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Fri, 27 Jun 2025 11:27:37 -0700
Subject: [PATCH 02/23] init ApproxTopKCombine, combineSizeSpecified undone

---
 .../resources/error/error-conditions.json    |  12 ++
 .../catalyst/analysis/FunctionRegistry.scala |   3 +
 .../aggregate/ApproxTopKAggregates.scala     | 152 +++++++++++++++++-
 3 files changed, 166 insertions(+), 1 deletion(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 000b1f524f207..12f6d9f7b0e03 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,6 +114,18 @@
     ],
     "sqlState" : "22004"
   },
+  "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
+    "message" : [
+      "Failed to deserialize from memory. Probably due to a mismatch in sketch types. Combining approx_top_k sketches of different types is not allowed."
+    ],
+    "sqlState": "42846"
+  },
+  "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : {
+    "message" : [
+      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
+    ],
+    "sqlState": "42846"
+  },
   "ARITHMETIC_OVERFLOW" : {
     "message" : [
       "<message>.<alternative> If necessary set <config> to \"false\" to bypass this error."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b54ae9082f840..89a8e8a9a6a6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -528,6 +528,8 @@ object FunctionRegistry {
     expression[HllSketchAgg]("hll_sketch_agg"),
     expression[HllUnionAgg]("hll_union_agg"),
     expression[ApproxTopK]("approx_top_k"),
+    expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
+    expression[ApproxTopKCombine]("approx_top_k_combine"),
 
     // string functions
     expression[Ascii]("ascii"),
@@ -786,6 +788,7 @@ object FunctionRegistry {
     expression[EqualNull]("equal_null"),
     expression[HllSketchEstimate]("hll_sketch_estimate"),
     expression[HllUnion]("hll_union"),
+    expression[ApproxTopKEstimate]("approx_top_k_estimate"),
 
     // grouping sets
     expression[Grouping]("grouping"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index d584c17e55b3d..eb8586384afe7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.datasketches.common._
 import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch}
 import org.apache.datasketches.memory.Memory
-
+import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
@@ -178,6 +178,8 @@ object ApproxTopK {
   private val DEFAULT_K: Int = 5
   val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
   private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+  val VOID_MAX_ITEMS_TRACKED: Int = -1
+  val SKETCH_SIZE_PLACEHOLDER: Int = 8
 
   def checkExpressionNotNull(expr: Expression, exprName: String): Unit = {
     if (expr == null || expr.eval() == null) {
@@ -196,6 +198,9 @@ object ApproxTopK {
       throw QueryExecutionErrors.approxTopKMaxItemsTrackedExceedsLimit(
         maxItemsTracked, MAX_ITEMS_TRACKED_LIMIT)
     }
+    if (maxItemsTracked <= 0) {
+      throw QueryExecutionErrors.approxTopKNonPositiveValue("maxItemsTracked", maxItemsTracked)
+    }
   }
 
   private def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = {
@@ -401,3 +406,148 @@ case class ApproxTopKAccumulate(
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
 }
+
+class CombineInternal[T](sketch: ItemsSketch[T], itemDataType: DataType, var maxItemsTracked: Int) {
+  def getSketch: ItemsSketch[T] = sketch
+
+  def getItemDataType: DataType = itemDataType
+
+  def getMaxItemsTracked: Int = maxItemsTracked
+
+  def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
+}
+
+case class ApproxTopKCombine(
+    expr: Expression,
+    maxItemsTracked: Expression,
+    combineSizeSpecified: Boolean, // not open to user, used to determine if the size is specified
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[CombineInternal[Any]]
+  with ImplicitCastInputTypes
+  with BinaryLike[Expression] {
+
+  def this(child: Expression, maxItemsTracked: Expression) =
+    this(child, maxItemsTracked, true, 0, 0)
+
+  def this(child: Expression, maxItemsTracked: Int) =
+    this(child, Literal(maxItemsTracked), true, 0, 0)
+
+  def this(child: Expression) =
+    this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), false, 0, 0)
+
+  private lazy val itemDataType: DataType =
+    expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
+
+  private lazy val maxItemsTrackedVal: Int = {
+    ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
+    val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int]
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
+    maxItemsTrackedVal
+  }
+
+  override def left: Expression = expr
+
+  override def right: Expression = maxItemsTracked
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType)
+
+  override def createAggregationBuffer(): CombineInternal[Any] = {
+    if (combineSizeSpecified) {
+      val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
+      new CombineInternal[Any](new ItemsSketch[Any](maxMapSize), itemDataType, maxItemsTrackedVal)
+    } else {
+      new CombineInternal[Any](
+        new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
+        itemDataType,
+        ApproxTopK.VOID_MAX_ITEMS_TRACKED)
+    }
+  }
+
+  override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = {
+    val inputSketchBytes = expr.eval(input).asInstanceOf[InternalRow].getBinary(0)
+    val inputMaxItemsTracked = expr.eval(input).asInstanceOf[InternalRow].getInt(2)
+    val inputSketch = try {
+      ItemsSketch.getInstance(
+        Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    } catch {
+      case _: SketchesArgumentException | _: NumberFormatException =>
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
+        )
+    }
+    buffer.getSketch.merge(inputSketch)
+    if (!combineSizeSpecified) {
+      buffer.setMaxItemsTracked(inputMaxItemsTracked)
+    }
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
+          messageParameters = Map(
+            "size1" -> buffer.getMaxItemsTracked.toString,
+            "size2" -> input.getMaxItemsTracked.toString))
+      }
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes = try {
+      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    } catch {
+      case _: ArrayStoreException =>
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
+        )
+    }
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    InternalRow.apply(sketchBytes, null, maxItemsTracked)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+    val sketchBytes = buffer.getSketch.toByteArray(
+      ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
+    val byteArray = new Array[Byte](sketchBytes.length + 1)
+    byteArray(0) = maxItemsTrackedByte
+    System.arraycopy(sketchBytes, 0, byteArray, 1, sketchBytes.length)
+    byteArray
+  }
+
+  override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+    val maxItemsTracked = buffer(0).toInt
+    val sketchBytes = buffer.slice(1, buffer.length)
+    val sketch = ItemsSketch.getInstance(
+      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
+    new CombineInternal[Any](sketch, itemDataType, maxItemsTracked)
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression = copy(expr = newLeft, maxItemsTracked = newRight)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def nullable: Boolean = false
+
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
+}
\ No newline at end of file

From f369de77fe8ae55e559c44a60d0427d96eea61b4 Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Fri, 27 Jun 2025 11:27:59 -0700
Subject: [PATCH 03/23] init ApproxTopKCombine, combineSizeSpecified undone

---
 .../catalyst/expressions/aggregate/ApproxTopKAggregates.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index eb8586384afe7..2a23c93d10c86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -21,6 +21,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.datasketches.common._
 import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch}
 import org.apache.datasketches.memory.Memory
+
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.catalyst.InternalRow

From 1f319cf6baf2881efe5e8eba5acaf8cb116817b2 Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Fri, 27 Jun 2025 11:32:04 -0700
Subject: [PATCH 04/23] init ApproxTopKEstimate

---
 .../expressions/ApproxTopKExpressions.scala   | 62 +++++++++++++++++++
 .../aggregate/ApproxTopKAggregates.scala      |  6 +-
 2 files changed, 65 insertions(+), 3 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
new file mode 100644
index 0000000000000..56b55c54daac8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.datasketches.frequencies.ItemsSketch
+import org.apache.datasketches.memory.Memory
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+
+case class ApproxTopKEstimate(left: Expression, right: Expression)
+  extends BinaryExpression
+  with CodegenFallback
+  with ImplicitCastInputTypes {
+
+  def this(child: Expression, topK: Int) = this(child, Literal(topK))
+
+  def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_K))
+
+  private lazy val itemDataType: DataType = {
+    // itemDataType is the type of the "ItemTypeNull" field of the output of ACCUMULATE or COMBINE
+    left.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
+
+  override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)
+
+  override def nullSafeEval(input1: Any, input2: Any): Any = {
+    val dataSketchBytes = input1.asInstanceOf[InternalRow].getBinary(0)
+    val topK = input2.asInstanceOf[Int]
+    val itemsSketch = ItemsSketch.getInstance(
+      Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
+    ApproxTopK.genEvalResult(itemsSketch, topK, itemDataType)
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression = copy(left = newLeft, right = newRight)
+
+  override def nullIntolerant: Boolean = true
+
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_estimate")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 2a23c93d10c86..68eb587e70f89 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -176,7 +176,7 @@ case class ApproxTopK(
 
 object ApproxTopK {
 
-  private val DEFAULT_K: Int = 5
+  val DEFAULT_K: Int = 5
   val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
   private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
   val VOID_MAX_ITEMS_TRACKED: Int = -1
@@ -211,7 +211,7 @@ object ApproxTopK {
     }
   }
 
-  private def getResultDataType(itemDataType: DataType): DataType = {
+  def getResultDataType(itemDataType: DataType): DataType = {
     val resultEntryType = StructType(
       StructField("item", itemDataType, nullable = false) ::
         StructField("count", LongType, nullable = false) :: Nil)
@@ -280,7 +280,7 @@ object ApproxTopK {
     buffer
   }
 
-  private def genEvalResult(
+  def genEvalResult(
       itemsSketch: ItemsSketch[Any],
       k: Int,
       itemDataType: DataType): GenericArrayData = {

From c09fd83f183ef5c2db319691c4acfcb95ff03684 Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Fri, 27 Jun 2025 16:00:13 -0700
Subject: [PATCH 05/23] init accumulate and estimate tests

---
 .../expressions/ApproxTopKExpressions.scala   | 21 ++++++---
 .../aggregate/ApproxTopKAggregates.scala      | 23 ++++++----
 .../spark/sql/DataFrameAggregateSuite.scala   | 44 +++++++++++++++++++
 3 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 56b55c54daac8..1fe2a768e401f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types._
 
-case class ApproxTopKEstimate(left: Expression, right: Expression)
+case class ApproxTopKEstimate(expr: Expression, k: Expression)
   extends BinaryExpression
   with CodegenFallback
   with ImplicitCastInputTypes {
@@ -36,24 +36,31 @@ case class ApproxTopKEstimate(expr: Expression, k: Expression)
 
   private lazy val itemDataType: DataType = {
     // itemDataType is the type of the "ItemTypeNull" field of the output of ACCUMULATE or COMBINE
-    left.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
+    expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
   }
 
+  override def left: Expression = expr
+
+  override def right: Expression = k
+
   override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)
 
   override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType)
 
   override def nullSafeEval(input1: Any, input2: Any): Any = {
     val dataSketchBytes = input1.asInstanceOf[InternalRow].getBinary(0)
-    val topK = input2.asInstanceOf[Int]
+    val maxItemsTrackedVal = input1.asInstanceOf[InternalRow].getInt(2)
+    ApproxTopK.checkExpressionNotNull(k, "k")
+    val kVal = input2.asInstanceOf[Int]
+    ApproxTopK.checkK(kVal)
+    ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal)
     val itemsSketch = ItemsSketch.getInstance(
       Memory.wrap(dataSketchBytes), ApproxTopK.genSketchSerDe(itemDataType))
-    ApproxTopK.genEvalResult(itemsSketch, topK, itemDataType)
+    ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType)
   }
 
-  override protected def withNewChildrenInternal(
-      newLeft: Expression,
-      newRight: Expression): Expression = copy(left = newLeft, right = newRight)
+  override protected def withNewChildrenInternal(newExpr: Expression, newK: Expression)
+  : Expression = copy(expr = newExpr, k = newK)
 
   override def nullIntolerant: Boolean = true
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index 68eb587e70f89..df9a9ad0760ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -188,7 +188,7 @@ object ApproxTopK {
     }
   }
 
-  private def checkK(k: Int): Unit = {
+  def checkK(k: Int): Unit = {
     if (k <= 0) {
       throw QueryExecutionErrors.approxTopKNonPositiveValue("k", k)
     }
@@ -204,7 +204,7 @@ object ApproxTopK {
     }
   }
 
-  private def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = {
+  def checkMaxItemsTracked(maxItemsTracked: Int, k: Int): Unit = {
     checkMaxItemsTracked(maxItemsTracked)
     if (maxItemsTracked < k) {
       throw QueryExecutionErrors.approxTopKMaxItemsTrackedLessThanK(maxItemsTracked, k)
     }
@@ -234,7 +234,8 @@ object ApproxTopK {
     // https://datasketches.apache.org/docs/Frequency/FrequentItemsOverview.html
     val ceilMaxMapSize = math.ceil(maxItemsTracked / 0.75).toInt
     // The maxMapSize must be a power of 2 and greater than ceilMaxMapSize
-    math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
+    val maxMapSize = math.pow(2, math.ceil(math.log(ceilMaxMapSize) / math.log(2))).toInt
+    maxMapSize
   }
 
   def createAggregationBuffer(itemExpression: Expression, maxMapSize: Int): ItemsSketch[Any] = {
@@ -341,6 +342,7 @@ case class ApproxTopKAccumulate(
   def this(child: Expression) = this(child, Literal(ApproxTopK.DEFAULT_MAX_ITEMS_TRACKED), 0, 0)
 
   private lazy val itemDataType: DataType = expr.dataType
+
   private lazy val maxItemsTrackedVal: Int = {
     ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
     val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int]
@@ -421,7 +423,6 @@ class CombineInternal[T](sketch: ItemsSketch[T], itemDataType: DataType, var max
 case class ApproxTopKCombine(
     expr: Expression,
     maxItemsTracked: Expression,
-    combineSizeSpecified: Boolean, // not open to user, used to determine if the size is specified
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
   extends TypedImperativeAggregate[CombineInternal[Any]]
@@ -429,13 +430,17 @@ case class ApproxTopKCombine(
   with BinaryLike[Expression] {
 
   def this(child: Expression, maxItemsTracked: Expression) =
-    this(child, maxItemsTracked, true, 0, 0)
+    this(child, maxItemsTracked, 0, 0)
 
   def this(child: Expression, maxItemsTracked: Int) =
-    this(child, Literal(maxItemsTracked), true, 0, 0)
+    this(child, Literal(maxItemsTracked), 0, 0)
 
-  def this(child: Expression) =
-    this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), false, 0, 0)
+  def this(child: Expression) = {
+    this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)
+    combineSizeSpecified = false
+  }
+
+  private var combineSizeSpecified: Boolean = true
 
   private lazy val itemDataType: DataType =
     expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType
@@ -551,4 +556,4 @@ case class ApproxTopKCombine(
 
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
-}
\ No newline at end of file
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b7f129c35c5dd..df6e494885594 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -2931,6 +2931,50 @@ class DataFrameAggregateSuite extends QueryTest
       res,
       Row(LocalTime.of(22, 1, 0), LocalTime.of(3, 0, 0)))
   }
+
+  test("SPARK-52588: accumulate and estimate of Integer with default parameters") {
+    val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr)) " +
+      "FROM VALUES (0), (0), (0), (1), (1), (2), (3), (4) AS tab(expr);")
+    checkAnswer(res, Row(Seq(Row(0, 3), Row(1, 2), Row(4, 1), Row(2, 1), Row(3, 1))))
+  }
+
+  test("SPARK-52588: accumulate and estimate of String") {
+    val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2) " +
+      "FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);")
+    checkAnswer(res, Row(Seq(Row("c", 4), Row("d", 2))))
+  }
+
+  test("SPARK-52588: accumulate and estimate of Decimal(4, 1)") {
+    val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 10)) " +
+      "FROM VALUES CAST(0.0 AS DECIMAL(4, 1)), CAST(0.0 AS DECIMAL(4, 1)), " +
+      "CAST(0.0 AS DECIMAL(4, 1)), CAST(1.0 AS DECIMAL(4, 1)), " +
+      "CAST(1.0 AS DECIMAL(4, 1)), CAST(2.0 AS DECIMAL(4, 1)) AS tab(expr);")
+    checkAnswer(res, Row(Seq(
+      Row(new java.math.BigDecimal("0.0"), 3),
+      Row(new java.math.BigDecimal("1.0"), 2),
+      Row(new java.math.BigDecimal("2.0"), 1))))
+  }
+
+  test("SPARK-52588: accumulate and estimate of Decimal(20, 3)") {
+    val res = sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 10), 2) " +
+      "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " +
+      "CAST(0.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " +
+      "CAST(1.0 AS DECIMAL(20, 3)), CAST(2.0 AS DECIMAL(20, 3)) AS tab(expr);")
+    checkAnswer(res, Row(Seq(
+      Row(new java.math.BigDecimal("0.000"), 3),
+      Row(new java.math.BigDecimal("1.000"), 2))))
+  }
+
+  test("SPARK-52588: invalid estimate if k > maxItemsTracked") {
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 5), 10) " +
+          "FROM VALUES 0, 1, 2 AS tab(expr);").collect()
+      },
+      condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_LESS_THAN_K",
+      parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
+    )
+  }
 }
 
 case class B(c: Option[Double])

From 612f8d8e8990de6808ed8566ea1560c37bd6583b Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Fri, 27 Jun 2025 16:24:21 -0700
Subject: [PATCH 06/23] unfinished estimate null check

---
 .../expressions/ApproxTopKExpressions.scala   |  2 +-
 .../spark/sql/DataFrameAggregateSuite.scala   | 25 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 1fe2a768e401f..886b2e2c5a253 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -62,7 +62,7 @@ case class ApproxTopKEstimate(expr: Expression, k: Expression)
   override protected def withNewChildrenInternal(newExpr: Expression, newK: Expression)
   : Expression = copy(expr = newExpr, k = newK)
 
-  override def nullIntolerant: Boolean = true
+  override def nullIntolerant: Boolean = false
 
   override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_estimate") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index df6e494885594..6177d260a3b9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2975,6 +2975,31 @@ class DataFrameAggregateSuite extends QueryTest parameters = Map("maxItemsTracked" -> "5", "k" -> "10") ) } + + test("SPARK-52588: invalid accumulate if maxItemsTracked is null") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_accumulate(expr, NULL) FROM VALUES 0, 1, 2 AS tab(expr);") + .collect() + }, + condition = "APPROX_TOP_K_NULL_ARG", + parameters = Map("argName" -> "`maxItemsTracked`") + ) + } + + test("SPARK-52588null: invalid estimate if k is null") { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), NULL) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").show(false) + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), NULL) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_NULL_ARG", + parameters = Map("argName" -> "`k`") + ) + } } case class B(c: Option[Double]) From 03e432def78c6baec39f4477f7f88a0646f3d533 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 11:02:34 -0700 Subject: [PATCH 07/23] fix estimate null check --- .../expressions/ApproxTopKExpressions.scala | 37 ++- .../spark/sql/DataFrameAggregateSuite.scala | 210 +++++++++++++++++- 2 files changed, 232 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala index 886b2e2c5a253..f73b5f6c11798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala @@ -21,11 +21,13 @@ import org.apache.datasketches.memory.Memory import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ -case class ApproxTopKEstimate(expr: Expression, k: Expression) +case class ApproxTopKEstimate(state: Expression, k: Expression) extends BinaryExpression with CodegenFallback with ImplicitCastInputTypes { @@ -36,22 +38,37 @@ case class ApproxTopKEstimate(expr: Expression, k: Expression) private lazy val itemDataType: DataType = { // itemDataType is the type of the "ItemTypeNull" field of the output of ACCUMULATE or COMBINE - expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType + state.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType } - override def left: Expression = expr + override def left: Expression = state override def right: Expression = k override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) + override def checkInputDataTypes(): 
TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!k.foldable) { + TypeCheckFailure("K must be a constant literal") + } else { + TypeCheckSuccess + } + } + override def dataType: DataType = ApproxTopK.getResultDataType(itemDataType) - override def nullSafeEval(input1: Any, input2: Any): Any = { - val dataSketchBytes = input1.asInstanceOf[InternalRow].getBinary(0) - val maxItemsTrackedVal = input1.asInstanceOf[InternalRow].getInt(2) + override def eval(input: InternalRow): Any = { + // null check ApproxTopK.checkExpressionNotNull(k, "k") - val kVal = input2.asInstanceOf[Int] + // eval + val stateEval = left.eval(input) + val kEval = right.eval(input) + val dataSketchBytes = stateEval.asInstanceOf[InternalRow].getBinary(0) + val maxItemsTrackedVal = stateEval.asInstanceOf[InternalRow].getInt(2) + val kVal = kEval.asInstanceOf[Int] ApproxTopK.checkK(kVal) ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal, kVal) val itemsSketch = ItemsSketch.getInstance( @@ -59,10 +76,10 @@ case class ApproxTopKEstimate(expr: Expression, k: Expression) ApproxTopK.genEvalResult(itemsSketch, kVal, itemDataType) } - override protected def withNewChildrenInternal(newExpr: Expression, newK: Expression) - : Expression = copy(expr = newExpr, k = newK) + override protected def withNewChildrenInternal(newState: Expression, newK: Expression) + : Expression = copy(state = newState, k = newK) - override def nullIntolerant: Boolean = false + override def nullable: Boolean = false override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_estimate") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 6177d260a3b9b..51e5e9365e717 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the -import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} +import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions.{Abs, BoundReference, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, Sum} @@ -2987,10 +2987,7 @@ class DataFrameAggregateSuite extends QueryTest ) } - test("SPARK-52588null: invalid estimate if k is null") { - sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), NULL) " + - "FROM VALUES 0, 1, 2 AS tab(expr);").show(false) - + test("SPARK-52588: invalid estimate if k is null") { checkError( exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), NULL) " + @@ -3000,6 +2997,209 @@ class DataFrameAggregateSuite extends QueryTest parameters = Map("argName" -> "`k`") ) } + + test("SPARK-52588: same type, same size, specified combine size - success") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + 
+ val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combinedebug: same type, same size, unspecified combine size - success") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, different size, specified combine size - success") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combination") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, different size, unspecified combine size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combination") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") + + checkError( + exception = intercept[SparkUnsupportedOperationException] { + est.collect() + }, + condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + parameters = Map("size1" -> "10", "size2" -> "20") + ) + } + + test("SPARK-combine: different type (int VS string), same size, specified combine size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM 
accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkError( + exception = intercept[SparkUnsupportedOperationException] { + est.collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") + } + + test("SPARK-debug1: different type (int VS date), same size, specified combine size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES cast('2023-01-01' AS DATE), cast('2023-01-01' AS DATE), " + + "cast('2023-01-02' AS DATE), cast('2023-01-02' AS DATE), " + + "cast('2023-01-03' AS DATE), cast('2023-01-04' AS DATE), " + + "cast('2023-01-05' AS DATE), cast('2023-01-05' AS DATE) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2;") + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ) + // TODO: what is the query context for the error? + ) + } + + test("SPARK-combine: different type (int VS float), same size, specified combine size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES cast(0.0 AS FLOAT), cast(0.0 AS FLOAT), " + + "cast(1.0 AS FLOAT), cast(1.0 AS FLOAT), " + + "cast(2.0 AS FLOAT), cast(3.0 AS FLOAT), " + + "cast(4.0 AS FLOAT), cast(4.0 AS FLOAT) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkError( + exception = intercept[SparkUnsupportedOperationException] { + est.collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") + } + + test("SPARK-combine: different type (byte VS short), same size, specified combine size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES cast(0 AS BYTE), cast(0 AS BYTE), cast(1 AS BYTE), " + + "cast(1 AS BYTE), cast(2 AS BYTE), cast(3 AS BYTE), " + + "cast(4 AS BYTE), cast(4 AS BYTE) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES cast(0 AS SHORT), cast(0 AS SHORT), cast(1 AS SHORT), " + + "cast(1 AS SHORT), cast(2 AS SHORT), cast(3 AS SHORT), " + + "cast(4 AS SHORT), cast(4 AS SHORT) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkError( + exception = intercept[SparkUnsupportedOperationException] { + est.collect() + }, + condition = 
"APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") + } + + test("SPARK-combine: different type (decimal(10, 2) VS decimal(20, 3)), same size - fail") { + val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES CAST(0.0 AS DECIMAL(10, 2)), CAST(0.0 AS DECIMAL(10, 2)), " + + "CAST(1.0 AS DECIMAL(10, 2)), CAST(1.0 AS DECIMAL(10, 2)), " + + "CAST(2.0 AS DECIMAL(10, 2)), CAST(3.0 AS DECIMAL(10, 2)), " + + "CAST(4.0 AS DECIMAL(10, 2)), CAST(4.0 AS DECIMAL(10, 2)) AS tab(expr);") + acc1.createOrReplaceTempView("accumulation1") + + val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + + "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + + "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " + + "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);") + acc2.createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + comb.createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkError( + exception = intercept[SparkUnsupportedOperationException] { + est.collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") + } } case class B(c: Option[Double]) From 88cae204e455509a037f50530f685adfdded254c Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 11:43:13 -0700 Subject: [PATCH 08/23] estimate and accumulate invalid parameter test --- .../spark/sql/DataFrameAggregateSuite.scala | 68 ++++++++++++++++--- 1 file changed, 58 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 51e5e9365e717..4f001e310a3f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{Abs, BoundReference, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, Sum} +import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAccumulate, Sum} import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -2965,15 +2965,44 @@ class DataFrameAggregateSuite extends QueryTest Row(new java.math.BigDecimal("1.000"), 2)))) } - test("SPARK-52588: invalid estimate if k > maxItemsTracked") { - checkError( - exception = intercept[SparkRuntimeException] { - sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 5), 10) " + - "FROM VALUES 0, 1, 2 AS tab(expr);").collect() - }, - condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_LESS_THAN_K", - parameters = Map("maxItemsTracked" -> "5", "k" -> "10") + test("SPARK-52588type: invalid accumulate if item type is not supported") { + Seq( + ArrayType(IntegerType), + MapType(StringType, IntegerType), + StructType(Seq(StructField("a", IntegerType), 
StructField("b", StringType))), + BinaryType + ).foreach { + unsupportedType => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, unsupportedType, nullable = true), + maxItemsTracked = Literal(10) + ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked are not foldable") { + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, LongType, nullable = true), + maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked less than or equal to 0") { + Seq(0, -1).foreach { invalidInput => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, LongType, nullable = true), + maxItemsTracked = Literal(invalidInput) + ) + checkError( + exception = intercept[SparkRuntimeException] { + badAccumulate.createAggregationBuffer() + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> invalidInput.toString) + ) + } } test("SPARK-52588: invalid accumulate if maxItemsTracked is null") { @@ -2987,6 +3016,14 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-52588: invalid estimate if k are not foldable") { + val badEstimate = ApproxTopKEstimate( + state = BoundReference(0, LongType, nullable = false), + k = Sum(BoundReference(1, LongType, nullable = true)) + ) + assert(badEstimate.checkInputDataTypes().isFailure) + } + test("SPARK-52588: invalid estimate if k is null") { checkError( exception = intercept[SparkRuntimeException] { @@ -2998,6 +3035,17 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-52588: invalid estimate if k > maxItemsTracked") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr, 5), 10) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_LESS_THAN_K", + parameters = Map("maxItemsTracked" -> "5", "k" -> "10") + ) + } + test("SPARK-52588: same type, same size, specified combine size - success") { val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") From ccfc661681d6715f5527cf70bf687179d5046327 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 12:46:11 -0700 Subject: [PATCH 09/23] remove combine for PR --- .../catalyst/analysis/FunctionRegistry.scala | 1 - .../expressions/ApproxTopKExpressions.scala | 2 +- .../aggregate/ApproxTopKAggregates.scala | 151 ------------- .../spark/sql/DataFrameAggregateSuite.scala | 207 +----------------- 4 files changed, 3 insertions(+), 358 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 89a8e8a9a6a6a..76c3b1d80b294 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -529,7 +529,6 @@ object FunctionRegistry { expression[HllUnionAgg]("hll_union_agg"), expression[ApproxTopK]("approx_top_k"), expression[ApproxTopKAccumulate]("approx_top_k_accumulate"), - expression[ApproxTopKCombine]("approx_top_k_combine"), // string functions expression[Ascii]("ascii"), diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala index f73b5f6c11798..bccc47711ff13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala @@ -77,7 +77,7 @@ case class ApproxTopKEstimate(state: Expression, k: Expression) } override protected def withNewChildrenInternal(newState: Expression, newK: Expression) - : Expression = copy(state = newState, k = newK) + : Expression = copy(state = newState, k = newK) override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index df9a9ad0760ac..261907bf38635 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -22,7 +22,6 @@ import org.apache.datasketches.common._ import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch} import org.apache.datasketches.memory.Memory -import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -179,8 +178,6 @@ object ApproxTopK { val DEFAULT_K: Int = 5 val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 - val VOID_MAX_ITEMS_TRACKED: Int = -1 - val SKETCH_SIZE_PLACEHOLDER: Int = 8 def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { if (expr == null || expr.eval() == null) { @@ -409,151 +406,3 @@ case class ApproxTopKAccumulate( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") } - -class CombineInternal[T](sketch: ItemsSketch[T], itemDataType: DataType, var maxItemsTracked: Int) { - def getSketch: ItemsSketch[T] = sketch - - def getItemDataType: DataType = itemDataType - - def getMaxItemsTracked: Int = maxItemsTracked - - def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked -} - -case class ApproxTopKCombine( - expr: Expression, - maxItemsTracked: Expression, - mutableAggBufferOffset: Int = 0, - inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[CombineInternal[Any]] - with ImplicitCastInputTypes - with BinaryLike[Expression] { - - def this(child: Expression, maxItemsTracked: Expression) = - this(child, maxItemsTracked, 0, 0) - - def this(child: Expression, maxItemsTracked: Int) = - this(child, Literal(maxItemsTracked), 0, 0) - - def this(child: Expression) = { - this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) - combineSizeSpecified = false - } - - private var combineSizeSpecified: Boolean = true - - private lazy val itemDataType: DataType = - expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType - - private lazy val maxItemsTrackedVal: Int = { - ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") - val maxItemsTrackedVal = maxItemsTracked.eval().asInstanceOf[Int] - 
ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) - maxItemsTrackedVal - } - - override def left: Expression = expr - - override def right: Expression = maxItemsTracked - - override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) - - override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType) - - override def createAggregationBuffer(): CombineInternal[Any] = { - if (combineSizeSpecified) { - val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) - new CombineInternal[Any](new ItemsSketch[Any](maxMapSize), itemDataType, maxItemsTrackedVal) - } else { - new CombineInternal[Any]( - new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), - itemDataType, - ApproxTopK.VOID_MAX_ITEMS_TRACKED) - } - } - - override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { - val inputSketchBytes = expr.eval(input).asInstanceOf[InternalRow].getBinary(0) - val inputMaxItemsTracked = expr.eval(input).asInstanceOf[InternalRow].getInt(2) - val inputSketch = try { - ItemsSketch.getInstance( - Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) - } catch { - case _: SketchesArgumentException | _: NumberFormatException => - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" - ) - } - buffer.getSketch.merge(inputSketch) - if (!combineSizeSpecified) { - buffer.setMaxItemsTracked(inputMaxItemsTracked) - } - buffer - } - - override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any]) - : CombineInternal[Any] = { - if (!combineSizeSpecified) { - // check size - if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) { - // If buffer is a placeholder sketch, set it to the input sketch's max items tracked - buffer.setMaxItemsTracked(input.getMaxItemsTracked) - } - if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) { - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", - messageParameters = Map( - "size1" -> buffer.getMaxItemsTracked.toString, - "size2" -> input.getMaxItemsTracked.toString)) - } - } - buffer.getSketch.merge(input.getSketch) - buffer - } - - override def eval(buffer: CombineInternal[Any]): Any = { - val sketchBytes = try { - buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType)) - } catch { - case _: ArrayStoreException => - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" - ) - } - val maxItemsTracked = buffer.getMaxItemsTracked - InternalRow.apply(sketchBytes, null, maxItemsTracked) - } - - override def serialize(buffer: CombineInternal[Any]): Array[Byte] = { - val sketchBytes = buffer.getSketch.toByteArray( - ApproxTopK.genSketchSerDe(buffer.getItemDataType)) - val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte - val byteArray = new Array[Byte](sketchBytes.length + 1) - byteArray(0) = maxItemsTrackedByte - System.arraycopy(sketchBytes, 0, byteArray, 1, sketchBytes.length) - byteArray - } - - override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = { - val maxItemsTracked = buffer(0).toInt - val sketchBytes = buffer.slice(1, buffer.length) - val sketch = ItemsSketch.getInstance( - Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType)) - new CombineInternal[Any](sketch, itemDataType, maxItemsTracked) - } - - override protected def withNewChildrenInternal( - newLeft: Expression, - newRight: Expression): Expression = copy(expr = newLeft, maxItemsTracked = 
newRight) - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) - - override def nullable: Boolean = false - - override def prettyName: String = - getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine") -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4f001e310a3f6..03f3293747121 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the -import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAccumulate, Sum} @@ -2965,7 +2965,7 @@ class DataFrameAggregateSuite extends QueryTest Row(new java.math.BigDecimal("1.000"), 2)))) } - test("SPARK-52588type: invalid accumulate if item type is not supported") { + test("SPARK-52588: invalid accumulate if item type is not supported") { Seq( ArrayType(IntegerType), MapType(StringType, IntegerType), @@ -3045,209 +3045,6 @@ class DataFrameAggregateSuite extends QueryTest parameters = Map("maxItemsTracked" -> "5", "k" -> "10") ) } - - test("SPARK-52588: same type, same size, specified combine size - success") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) - } - - test("SPARK-combinedebug: same type, same size, unspecified combine size - success") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) - } - - test("SPARK-combine: same type, different size, specified 
combine size - success") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combination") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") - checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) - } - - test("SPARK-combine: same type, different size, unspecified combine size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combination") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") - - checkError( - exception = intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", - parameters = Map("size1" -> "10", "size2" -> "20") - ) - } - - test("SPARK-combine: different type (int VS string), same size, specified combine size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkError( - exception = intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") - } - - test("SPARK-debug1: different type (int VS date), same size, specified combine size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES cast('2023-01-01' AS DATE), cast('2023-01-01' AS DATE), " + - "cast('2023-01-02' AS DATE), cast('2023-01-02' AS DATE), " + - "cast('2023-01-03' AS DATE), cast('2023-01-04' AS DATE), " + - "cast('2023-01-05' AS DATE), cast('2023-01-05' AS DATE) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - checkError( - exception = intercept[AnalysisException] { - sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2;") - }, - condition = "INCOMPATIBLE_COLUMN_TYPE", - parameters = Map( - "tableOrdinalNumber" -> "second", - "columnOrdinalNumber" 
-> "first", - "dataType2" -> ("\"STRUCT\""), - "operator" -> "UNION", - "hint" -> "", - "dataType1" -> ("\"STRUCT\"") - ) - // TODO: what is the query context for the error? - ) - } - - test("SPARK-combine: different type (int VS float), same size, specified combine size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES cast(0.0 AS FLOAT), cast(0.0 AS FLOAT), " + - "cast(1.0 AS FLOAT), cast(1.0 AS FLOAT), " + - "cast(2.0 AS FLOAT), cast(3.0 AS FLOAT), " + - "cast(4.0 AS FLOAT), cast(4.0 AS FLOAT) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkError( - exception = intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") - } - - test("SPARK-combine: different type (byte VS short), same size, specified combine size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES cast(0 AS BYTE), cast(0 AS BYTE), cast(1 AS BYTE), " + - "cast(1 AS BYTE), cast(2 AS BYTE), cast(3 AS BYTE), " + - "cast(4 AS BYTE), cast(4 AS BYTE) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES cast(0 AS SHORT), cast(0 AS SHORT), cast(1 AS SHORT), " + - "cast(1 AS SHORT), cast(2 AS SHORT), cast(3 AS SHORT), " + - "cast(4 AS SHORT), cast(4 AS SHORT) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkError( - exception = intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") - } - - test("SPARK-combine: different type (decimal(10, 2) VS decimal(20, 3)), same size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES CAST(0.0 AS DECIMAL(10, 2)), CAST(0.0 AS DECIMAL(10, 2)), " + - "CAST(1.0 AS DECIMAL(10, 2)), CAST(1.0 AS DECIMAL(10, 2)), " + - "CAST(2.0 AS DECIMAL(10, 2)), CAST(3.0 AS DECIMAL(10, 2)), " + - "CAST(4.0 AS DECIMAL(10, 2)), CAST(4.0 AS DECIMAL(10, 2)) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + - "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + - "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " + - "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkError( - exception = 
intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") - } } case class B(c: Option[Double]) From d89c12131b00f4540982066a5d5981d6c95e7467 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 12:53:01 -0700 Subject: [PATCH 10/23] remove combine for PR --- .../src/main/resources/error/error-conditions.json | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 12f6d9f7b0e03..000b1f524f207 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -114,18 +114,6 @@ ], "sqlState" : "22004" }, - "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : { - "message" : [ - "Failed to deserialize from memory. Probably due to a mismatch in sketch types. Combining approx_top_k sketches of different types is not allowed." - ], - "sqlState": "42846" - }, - "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : { - "message" : [ - "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>." - ], - "sqlState": "42846" - }, "ARITHMETIC_OVERFLOW" : { "message" : [ "<message>.<alternative> If necessary set <config> to \"false\" to bypass this error." ], From b95ff7a820993059e901a6d25babf3a938f1e56c Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 13:18:58 -0700 Subject: [PATCH 11/23] separate expression suite and query suite --- .../aggregate/ApproxTopKSuite.scala | 98 +++++++++++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 75 -------------- 2 files changed, 98 insertions(+), 75 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala new file mode 100644 index 0000000000000..70d9aa69aa3f5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, LongType, MapType, StringType, StructField, StructType} + +class ApproxTopKSuite extends SparkFunSuite { + + test("SPARK-52515: Accepts literal and foldable inputs") { + val agg = new ApproxTopK( + expr = BoundReference(0, LongType, nullable = true), + k = Abs(Literal(10)), + maxItemsTracked = Abs(Literal(-10)) + ) + assert(agg.checkInputDataTypes().isSuccess) + } + + test("SPARK-52515: Fail if parameters are not foldable") { + val badAgg = new ApproxTopK( + expr = BoundReference(0, LongType, nullable = true), + k = Sum(BoundReference(1, LongType, nullable = true)), + maxItemsTracked = Literal(10) + ) + assert(badAgg.checkInputDataTypes().isFailure) + + val badAgg2 = new ApproxTopK( + expr = BoundReference(0, LongType, nullable = true), + k = Literal(10), + maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) + ) + assert(badAgg2.checkInputDataTypes().isFailure) + } + + test("SPARK-52588: invalid accumulate if item type is not supported") { + Seq( + ArrayType(IntegerType), + MapType(StringType, IntegerType), + StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), + BinaryType + ).foreach { + unsupportedType => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, unsupportedType, nullable = true), + maxItemsTracked = Literal(10) + ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked are not foldable") { + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, LongType, nullable = true), + maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) + ) + assert(badAccumulate.checkInputDataTypes().isFailure) + } + + test("SPARK-52588: invalid accumulate if maxItemsTracked less than or equal to 0") { + Seq(0, -1).foreach { invalidInput => + val badAccumulate = ApproxTopKAccumulate( + expr = BoundReference(0, LongType, nullable = true), + maxItemsTracked = Literal(invalidInput) + ) + checkError( + exception = intercept[SparkRuntimeException] { + badAccumulate.createAggregationBuffer() + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> invalidInput.toString) + ) + } + } + + test("SPARK-52588: invalid estimate if k are not foldable") { + val badEstimate = ApproxTopKEstimate( + state = BoundReference(0, LongType, nullable = false), + k = Sum(BoundReference(1, LongType, nullable = true)) + ) + assert(badEstimate.checkInputDataTypes().isFailure) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 03f3293747121..9002e20428dcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -26,8 +26,6 @@ import org.scalatest.matchers.must.Matchers.the import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} -import 
org.apache.spark.sql.catalyst.expressions.aggregate.{ApproxTopK, ApproxTopKAccumulate, Sum} import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -2880,31 +2878,6 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(res, Row(Seq(Row("b", 3), Row("a", 2)))) } - test("SPARK-52515: Accepts literal and foldable inputs") { - val agg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Abs(Literal(10)), - maxItemsTracked = Abs(Literal(-10)) - ) - assert(agg.checkInputDataTypes().isSuccess) - } - - test("SPARK-52515: Fail if parameters are not foldable") { - val badAgg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Sum(BoundReference(1, LongType, nullable = true)), - maxItemsTracked = Literal(10) - ) - assert(badAgg.checkInputDataTypes().isFailure) - - val badAgg2 = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Literal(10), - maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) - ) - assert(badAgg2.checkInputDataTypes().isFailure) - } - test("SPARK-52626: Support group by Time column") { val ts1 = "15:00:00" val ts2 = "22:00:00" @@ -2965,46 +2938,6 @@ class DataFrameAggregateSuite extends QueryTest Row(new java.math.BigDecimal("1.000"), 2)))) } - test("SPARK-52588: invalid accumulate if item type is not supported") { - Seq( - ArrayType(IntegerType), - MapType(StringType, IntegerType), - StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), - BinaryType - ).foreach { - unsupportedType => - val badAccumulate = ApproxTopKAccumulate( - expr = BoundReference(0, unsupportedType, nullable = true), - maxItemsTracked = Literal(10) - ) - assert(badAccumulate.checkInputDataTypes().isFailure) - } - } - - test("SPARK-52588: invalid accumulate if maxItemsTracked are not foldable") { - val badAccumulate = ApproxTopKAccumulate( - expr = BoundReference(0, LongType, nullable = true), - maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) - ) - assert(badAccumulate.checkInputDataTypes().isFailure) - } - - test("SPARK-52588: invalid accumulate if maxItemsTracked less than or equal to 0") { - Seq(0, -1).foreach { invalidInput => - val badAccumulate = ApproxTopKAccumulate( - expr = BoundReference(0, LongType, nullable = true), - maxItemsTracked = Literal(invalidInput) - ) - checkError( - exception = intercept[SparkRuntimeException] { - badAccumulate.createAggregationBuffer() - }, - condition = "APPROX_TOP_K_NON_POSITIVE_ARG", - parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> invalidInput.toString) - ) - } - } - test("SPARK-52588: invalid accumulate if maxItemsTracked is null") { checkError( exception = intercept[SparkRuntimeException] { @@ -3016,14 +2949,6 @@ class DataFrameAggregateSuite extends QueryTest ) } - test("SPARK-52588: invalid estimate if k are not foldable") { - val badEstimate = ApproxTopKEstimate( - state = BoundReference(0, LongType, nullable = false), - k = Sum(BoundReference(1, LongType, nullable = true)) - ) - assert(badEstimate.checkInputDataTypes().isFailure) - } - test("SPARK-52588: invalid estimate if k is null") { checkError( exception = intercept[SparkRuntimeException] { From edaf18e8d7632a89c5c1eac6d1e35c6a48c4c9a1 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 13:32:29 -0700 Subject: [PATCH 12/23] add expression doc --- .../expressions/ApproxTopKExpressions.scala | 30 
++++++- .../aggregate/ApproxTopKAggregates.scala | 8 ++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala index bccc47711ff13..342a261d77738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala @@ -27,6 +27,34 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ +/** + * An expression that estimates the top K items from a sketch. + * + * The input is a sketch state that is generated by the ApproxTopKAccumulate function. + * The output is an array of structs, each containing a frequent item and its estimated frequency. + * The items are sorted by their estimated frequency in descending order. + * + * @param state The sketch state, which is a struct containing the serialized sketch data, + * the original data type and the max items tracked of the sketch. + * @param k The number of top items to estimate. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(state, k) - Returns top k items with their frequency. + `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5. + """, + examples = """ + Examples: + > SELECT _FUNC_(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); + [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] + + > SELECT _FUNC_(approx_top_k_accumulate(expr), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr); + [{"item":"c","count":4},{"item":"d","count":2}] + """, + group = "misc_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit case class ApproxTopKEstimate(state: Expression, k: Expression) extends BinaryExpression with CodegenFallback @@ -77,7 +105,7 @@ case class ApproxTopKEstimate(state: Expression, k: Expression) } override protected def withNewChildrenInternal(newState: Expression, newK: Expression) - : Expression = copy(state = newState, k = newK) + : Expression = copy(state = newState, k = newK) override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 261907bf38635..377df75455537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -50,14 +50,12 @@ import org.apache.spark.unsafe.types.UTF8String */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = - """ + usage = """ _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5. `maxItemsTracked` An optional INTEGER literal greater than or equal to k. If maxItemsTracked is not specified, it defaults to 10000. - """, - examples = - """ + """, + examples = """ Examples: > SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
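A note on the result type documented in this patch: the estimate output is always an ARRAY of STRUCT<item, count>, where count is BIGINT regardless of the item type. A quick way to confirm the schema end to end (a suite-style sketch; assumes the functions are registered under the names above):

    // Expected: array<struct<item:int,count:bigint>> for an INT input column.
    sql("SELECT typeof(approx_top_k_estimate(approx_top_k_accumulate(expr))) " +
      "FROM VALUES (0), (1), (2) AS tab(expr);").show(false)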
- """, - examples = - """ + """, + examples = """ Examples: > SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] From b3811a7df6c93ef6a9686fd7d13a2ceb183cdb1f Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 13:47:24 -0700 Subject: [PATCH 13/23] add accumulation doc --- .../aggregate/ApproxTopKAggregates.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 377df75455537..fa2f7114062c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -321,6 +321,36 @@ object ApproxTopK { StructField("MaxItemsTracked", IntegerType, nullable = false) :: Nil) } +/** + * An aggregate function that accumulates items into a sketch, which can then be used + * to combine with other sketches, via ApproxTopKCombine, + * or to estimate the top K items, via ApproxTopKEstimate. + * + * The output of this function is a struct containing the sketch in binary format, + * a null object indicating the type of items in the sketch, + * and the maximum number of items tracked by the sketch. + * + * @param expr the child expression to accumulate items from + * @param maxItemsTracked the maximum number of items to track in the sketch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch. + `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. + If maxItemsTracked is not specified, it defaults to 10000. + """, + examples = """ + Examples: + > SELECT approx_top_k_accumulate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); + [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] + + > SELECT approx_top_k_accumulate(_FUNC_(expr, 100), 2) FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' AS tab(expr); + [{"item":"c","count":4},{"item":"d","count":2}] + """, + group = "agg_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit case class ApproxTopKAccumulate( expr: Expression, maxItemsTracked: Expression, From 7e1e5198b5ecd588f94d58cf7fa38b2c6de7bbc9 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 13:49:37 -0700 Subject: [PATCH 14/23] nit doc --- .../expressions/aggregate/ApproxTopKAggregates.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index fa2f7114062c1..4ebc194f354d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -53,7 +53,7 @@ import org.apache.spark.unsafe.types.UTF8String usage = """ _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. 
From 7e1e5198b5ecd588f94d58cf7fa38b2c6de7bbc9 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 13:49:37 -0700 Subject: [PATCH 14/23] nit doc --- .../expressions/aggregate/ApproxTopKAggregates.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index fa2f7114062c1..4ebc194f354d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -53,7 +53,7 @@ import org.apache.spark.unsafe.types.UTF8String usage = """ _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. `k` An optional INTEGER literal greater than 0. If k is not specified, it defaults to 5. - `maxItemsTracked` An optional INTEGER literal greater than or equal to k. If maxItemsTracked is not specified, it defaults to 10000. + `maxItemsTracked` An optional INTEGER literal greater than or equal to k and has upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. """, examples = """ Examples: @@ -337,8 +337,7 @@ object ApproxTopK { @ExpressionDescription( usage = """ _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch. - `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. - If maxItemsTracked is not specified, it defaults to 10000. + `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. """, examples = """ Examples: > SELECT approx_top_k_estimate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] From a9153aad84853d70320f20b01a3586e7026858ba Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 17:25:17 -0700 Subject: [PATCH 15/23] update expression type check test --- .../aggregate/ApproxTopKSuite.scala | 64 +++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala index 70d9aa69aa3f5..00838026f4f9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.catalyst.expressions.{Abs, ApproxTopKEstimate, BoundReference, Literal} -import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, LongType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, MapType, StringType, StructField, StructType} class ApproxTopKSuite extends SparkFunSuite { test("SPARK-52515: Accepts literal and foldable inputs") { val agg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), + expr = BoundReference(0, IntegerType, nullable = true), k = Abs(Literal(10)), maxItemsTracked = Abs(Literal(-10)) ) @@ -34,20 +34,62 @@ class ApproxTopKSuite extends SparkFunSuite { test("SPARK-52515: Fail if parameters are not foldable") { val badAgg = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), - k = Sum(BoundReference(1, LongType, nullable = true)), + expr = BoundReference(0, IntegerType, nullable = true), + k = Sum(BoundReference(1, IntegerType, nullable = true)), maxItemsTracked = Literal(10) ) assert(badAgg.checkInputDataTypes().isFailure) val badAgg2 = new ApproxTopK( - expr = BoundReference(0, LongType, nullable = true), + expr = BoundReference(0, IntegerType, nullable = true), k = Literal(10), - maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) ) assert(badAgg2.checkInputDataTypes().isFailure) } + test("SPARK-52515: invalid item types") { + Seq( + ArrayType(IntegerType), + MapType(StringType, IntegerType), + StructType(Seq(StructField("a", IntegerType), StructField("b", StringType))), + BinaryType + ).foreach { unsupportedType => + val agg = new ApproxTopK( + expr = 
BoundReference(0, unsupportedType, nullable = true), + k = Literal(10), + maxItemsTracked = Literal(10000) + ) + assert(agg.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52515: invalid k types") { + val invalidNumberTypes: Seq[Any] = Seq( + 10.0, 10.0f, BigDecimal("10.0"), 10.toByte, 10.toShort, 10L, 2147483648L, true, "10") + invalidNumberTypes.foreach { invalidK => + val agg = new ApproxTopK( + expr = BoundReference(0, IntegerType, nullable = true), + k = Literal(invalidK), + maxItemsTracked = Literal(10000) + ) + assert(agg.checkInputDataTypes().isFailure) + } + } + + test("SPARK-52515: invalid maxItemsTracked types") { + val invalidNumberTypes: Seq[Any] = Seq( + 10.0, 10.0f, BigDecimal("10.0"), 10.toByte, 10.toShort, 10L, 2147483648L, true, "10") + invalidNumberTypes.foreach { invalidMaxItems => + val agg = new ApproxTopK( + expr = BoundReference(0, IntegerType, nullable = true), + k = Literal(10), + maxItemsTracked = Literal(invalidMaxItems) + ) + assert(agg.checkInputDataTypes().isFailure) + } + } + test("SPARK-52588: invalid accumulate if item type is not supported") { Seq( ArrayType(IntegerType), @@ -66,8 +108,8 @@ class ApproxTopKSuite extends SparkFunSuite { test("SPARK-52588: invalid accumulate if maxItemsTracked are not foldable") { val badAccumulate = ApproxTopKAccumulate( - expr = BoundReference(0, LongType, nullable = true), - maxItemsTracked = Sum(BoundReference(1, LongType, nullable = true)) + expr = BoundReference(0, IntegerType, nullable = true), + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) ) assert(badAccumulate.checkInputDataTypes().isFailure) } @@ -75,7 +117,7 @@ class ApproxTopKSuite extends SparkFunSuite { test("SPARK-52588: invalid accumulate if maxItemsTracked less than or equal to 0") { Seq(0, -1).foreach { invalidInput => val badAccumulate = ApproxTopKAccumulate( - expr = BoundReference(0, LongType, nullable = true), + expr = BoundReference(0, IntegerType, nullable = true), maxItemsTracked = Literal(invalidInput) ) checkError( @@ -90,8 +132,8 @@ class ApproxTopKSuite extends SparkFunSuite { test("SPARK-52588: invalid estimate if k are not foldable") { val badEstimate = ApproxTopKEstimate( - state = BoundReference(0, LongType, nullable = false), - k = Sum(BoundReference(1, LongType, nullable = true)) + state = BoundReference(0, IntegerType, nullable = false), + k = Sum(BoundReference(1, IntegerType, nullable = true)) ) assert(badEstimate.checkInputDataTypes().isFailure) } From ae6cc81308526ed13a84400813a2aee124dc178c Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 17:34:50 -0700 Subject: [PATCH 16/23] remove k and max type check test --- .../aggregate/ApproxTopKSuite.scala | 26 ------------------- 1 file changed, 26 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala index 00838026f4f9b..14672dbc7f207 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -64,32 +64,6 @@ class ApproxTopKSuite extends SparkFunSuite { } } - test("SPARK-52515: invalid k types") { - val invalidNumberTypes: Seq[Any] = Seq( - 10.0, 10.0f, BigDecimal("10.0"), 10.toByte, 10.toShort, 10L, 2147483648L, true, "10") - invalidNumberTypes.foreach { invalidK => - val agg = new ApproxTopK( - expr = 
BoundReference(0, IntegerType, nullable = true), - k = Literal(invalidK), - maxItemsTracked = Literal(10000) - ) - assert(agg.checkInputDataTypes().isFailure) - } - } - - test("SPARK-52515: invalid maxItemsTracked types") { - val invalidNumberTypes: Seq[Any] = Seq( - 10.0, 10.0f, BigDecimal("10.0"), 10.toByte, 10.toShort, 10L, 2147483648L, true, "10") - invalidNumberTypes.foreach { invalidMaxItems => - val agg = new ApproxTopK( - expr = BoundReference(0, IntegerType, nullable = true), - k = Literal(10), - maxItemsTracked = Literal(invalidMaxItems) - ) - assert(agg.checkInputDataTypes().isFailure) - } - } - test("SPARK-52588: invalid accumulate if item type is not supported") { Seq( ArrayType(IntegerType), From d60702d02e43bc2c2e761ecf8552138eb553a215 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 17:57:04 -0700 Subject: [PATCH 17/23] add upper limit test for accumulate --- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 9002e20428dcc..22f402c27f62e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2949,6 +2949,16 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-52588: invalid accumulate if maxItemsTracked > 1000000") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_accumulate(expr, 1000001) FROM VALUES (0) AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_MAX_ITEMS_TRACKED_EXCEEDS_LIMIT", + parameters = Map("maxItemsTracked" -> "1000001", "limit" -> "1000000") + ) + } + test("SPARK-52588: invalid estimate if k is null") { checkError( exception = intercept[SparkRuntimeException] { From b616e7c479e99d58ad980b6fc9c58543efb576f8 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 18:02:05 -0700 Subject: [PATCH 18/23] add invalid value tests --- .../spark/sql/DataFrameAggregateSuite.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 22f402c27f62e..0f4e4be1bd960 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2970,6 +2970,36 @@ class DataFrameAggregateSuite extends QueryTest ) } + test("SPARK-52588: invalid estimate if k is invalid") { + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 0) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`k`", "argValue" -> "0") + ) + } + + test("SPARK-52588: invalid estimate if k > Int.MaxValue") { + withSQLConf("spark.sql.ansi.enabled" -> true.toString) { + val k: Long = Int.MaxValue + 1L + checkError( + exception = intercept[SparkArithmeticException] { + sql(s"SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), $k) " + + "FROM VALUES 0, 1, 2 AS tab(expr);").collect() + }, + condition = "CAST_OVERFLOW", + parameters = Map( + "value" -> (k.toString + "L"), + "sourceType" -> "\"BIGINT\"", + "targetType" -> "\"INT\"", + "ansiConfig" -> 
"\"spark.sql.ansi.enabled\"" + ) + ) + } + } + test("SPARK-52588: invalid estimate if k > maxItemsTracked") { checkError( exception = intercept[SparkRuntimeException] { From 1c006051f0737120c33956db2aa17e039466161e Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 16:51:11 -0700 Subject: [PATCH 19/23] combine implement and test --- .../resources/error/error-conditions.json | 12 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../aggregate/ApproxTopKAggregates.scala | 164 +++++++++++++++++- .../aggregate/ApproxTopKSuite.scala | 8 + .../spark/sql/DataFrameAggregateSuite.scala | 76 +++++++- 5 files changed, 256 insertions(+), 5 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 000b1f524f207..12f6d9f7b0e03 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -114,6 +114,18 @@ ], "sqlState" : "22004" }, + "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : { + "message" : [ + "Failed to deserialize from memory. Probably due to a mismatch in sketch types. Combining approx_top_k sketches of different types is not allowed." + ], + "sqlState": "42846" + }, + "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : { + "message" : [ + "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size and ." + ], + "sqlState": "42846" + }, "ARITHMETIC_OVERFLOW" : { "message" : [ ". If necessary set to \"false\" to bypass this error." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 76c3b1d80b294..89a8e8a9a6a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -529,6 +529,7 @@ object FunctionRegistry { expression[HllUnionAgg]("hll_union_agg"), expression[ApproxTopK]("approx_top_k"), expression[ApproxTopKAccumulate]("approx_top_k_accumulate"), + expression[ApproxTopKCombine]("approx_top_k_combine"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 4ebc194f354d2..05bd79f92c577 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -22,6 +22,7 @@ import org.apache.datasketches.common._ import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch} import org.apache.datasketches.memory.Memory +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -50,12 +51,14 @@ import org.apache.spark.unsafe.types.UTF8String */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = """ + usage = + """ _FUNC_(expr, k, maxItemsTracked) - Returns top k items with their frequency. `k` An optional INTEGER literal greater than 0. 
If k is not specified, it defaults to 5. `maxItemsTracked` An optional INTEGER literal greater than or equal to k and has upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. """, - examples = """ + examples = + """ Examples: > SELECT _FUNC_(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] @@ -176,6 +179,8 @@ object ApproxTopK { val DEFAULT_K: Int = 5 val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 + val VOID_MAX_ITEMS_TRACKED: Int = -1 + val SKETCH_SIZE_PLACEHOLDER: Int = 8 def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { if (expr == null || expr.eval() == null) { @@ -335,11 +340,13 @@ object ApproxTopK { */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = """ + usage = + """ _FUNC_(expr, maxItemsTracked) - Accumulates items into a sketch. `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. """, - examples = """ + examples = + """ Examples: > SELECT approx_top_k_estimate(_FUNC_(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] @@ -433,3 +440,152 @@ case class ApproxTopKAccumulate( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") } + +class CombineInternal[T](sketch: ItemsSketch[T], itemDataType: DataType, var maxItemsTracked: Int) { + def getSketch: ItemsSketch[T] = sketch + + def getItemDataType: DataType = itemDataType + + def getMaxItemsTracked: Int = maxItemsTracked + + def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked +} + +case class ApproxTopKCombine( + expr: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[CombineInternal[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = { + this(child, maxItemsTracked, 0, 0) + ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") + ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) + } + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked)) + + def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) + + private lazy val itemDataType: DataType = + expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType + private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int] + private lazy val combineSizeSpecified: Boolean = + maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED + + override def left: Expression = expr + + override def right: Expression = maxItemsTracked + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!maxItemsTracked.foldable) { + TypeCheckFailure("Number of items tracked must be a constant literal") + } else { + TypeCheckSuccess + } + } + + override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType) + + override def createAggregationBuffer(): CombineInternal[Any] = { 
+ if (combineSizeSpecified) { + val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) + new CombineInternal[Any](new ItemsSketch[Any](maxMapSize), itemDataType, maxItemsTrackedVal) + } else { + new CombineInternal[Any]( + new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), + itemDataType, + ApproxTopK.VOID_MAX_ITEMS_TRACKED) + } + } + + override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { + // Evaluate the input state struct once and read both fields from it. + val inputState = expr.eval(input).asInstanceOf[InternalRow] + val inputSketchBytes = inputState.getBinary(0) + val inputMaxItemsTracked = inputState.getInt(2) + val inputSketch = try { + ItemsSketch.getInstance( + Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) + } catch { + case _: SketchesArgumentException | _: NumberFormatException => + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" + ) + } + buffer.getSketch.merge(inputSketch) + if (!combineSizeSpecified) { + buffer.setMaxItemsTracked(inputMaxItemsTracked) + } + buffer + } + + override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any]) + : CombineInternal[Any] = { + if (!combineSizeSpecified) { + // check size + if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) { + // If buffer is a placeholder sketch, set it to the input sketch's max items tracked + buffer.setMaxItemsTracked(input.getMaxItemsTracked) + } + if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) { + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + messageParameters = Map( + "size1" -> buffer.getMaxItemsTracked.toString, + "size2" -> input.getMaxItemsTracked.toString)) + } + } + buffer.getSketch.merge(input.getSketch) + buffer + } + + override def eval(buffer: CombineInternal[Any]): Any = { + val sketchBytes = try { + buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType)) + } catch { + case _: ArrayStoreException => + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" + ) + } + val maxItemsTracked = buffer.getMaxItemsTracked + InternalRow.apply(sketchBytes, null, maxItemsTracked) + } + + override def serialize(buffer: CombineInternal[Any]): Array[Byte] = { + val sketchBytes = buffer.getSketch.toByteArray( + ApproxTopK.genSketchSerDe(buffer.getItemDataType)) + // Prepend maxItemsTracked as 4 big-endian bytes; a single byte would silently + // truncate any value above Byte.MaxValue, including the default of 10000. + val maxItemsTracked = buffer.getMaxItemsTracked + val byteArray = new Array[Byte](sketchBytes.length + 4) + byteArray(0) = (maxItemsTracked >>> 24).toByte + byteArray(1) = (maxItemsTracked >>> 16).toByte + byteArray(2) = (maxItemsTracked >>> 8).toByte + byteArray(3) = maxItemsTracked.toByte + System.arraycopy(sketchBytes, 0, byteArray, 4, sketchBytes.length) + byteArray + } + + override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = { + // Read back the 4-byte big-endian maxItemsTracked header written by serialize. + val maxItemsTracked = ((buffer(0) & 0xFF) << 24) | ((buffer(1) & 0xFF) << 16) | + ((buffer(2) & 0xFF) << 8) | (buffer(3) & 0xFF) + val sketchBytes = buffer.slice(4, buffer.length) + val sketch = ItemsSketch.getInstance( + Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType)) + new CombineInternal[Any](sketch, itemDataType, maxItemsTracked) + } + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): Expression = copy(expr = newLeft, maxItemsTracked = newRight) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = false +}
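The serialize/deserialize pair above prepends maxItemsTracked to the sketch bytes so that merge-time size checks survive a shuffle. A self-contained check of the 4-byte big-endian header it uses (plain Scala, no Spark required; names and values are illustrative):

    // Encode an Int as 4 big-endian bytes and decode it back.
    def encodeHeader(n: Int): Array[Byte] =
      Array((n >>> 24).toByte, (n >>> 16).toByte, (n >>> 8).toByte, n.toByte)
    def decodeHeader(b: Array[Byte]): Int =
      ((b(0) & 0xFF) << 24) | ((b(1) & 0xFF) << 16) | ((b(2) & 0xFF) << 8) | (b(3) & 0xFF)
    assert(decodeHeader(encodeHeader(10000)) == 10000) // DEFAULT_MAX_ITEMS_TRACKED round-trips
    assert(decodeHeader(encodeHeader(-1)) == -1)       // VOID_MAX_ITEMS_TRACKED round-trips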
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala index 14672dbc7f207..e2ad477202f43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -111,4 +111,12 @@ class ApproxTopKSuite extends SparkFunSuite { ) assert(badEstimate.checkInputDataTypes().isFailure) } + + test("SPARK-combine: invalid combine if maxItemsTracked are not foldable") { + val badCombine = ApproxTopKCombine( + expr = BoundReference(0, BinaryType, nullable = false), + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badCombine.checkInputDataTypes().isFailure) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0f4e4be1bd960..86808fa46bfba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.scalatest.matchers.must.Matchers.the -import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} +import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.catalyst.plans.logical.Expand import org.apache.spark.sql.catalyst.util.AUTO_GENERATED_ALIAS @@ -3010,6 +3010,80 @@ class DataFrameAggregateSuite extends QueryTest parameters = Map("maxItemsTracked" -> "5", "k" -> "10") ) } + + test("SPARK-combine: same type, same size, unspecified combine size - success") { + sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + .createOrReplaceTempView("accumulation1") + + sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + .createOrReplaceTempView("accumulation2") + + sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, different size, specified combine size - success") { + sql("SELECT approx_top_k_accumulate(expr, 10) as acc " 
+ + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + .createOrReplaceTempView("accumulation1") + + sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + .createOrReplaceTempView("accumulation2") + + val comb = sql("SELECT approx_top_k_combine(acc) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + + checkError( + exception = intercept[SparkUnsupportedOperationException] { + comb.collect() + }, + condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + parameters = Map("size1" -> "10", "size2" -> "20") + ) + } + + test("SPARK-combine: same type, different size, invalid combine size - fail") { + sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") + .createOrReplaceTempView("accumulation1") + + sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") + .createOrReplaceTempView("accumulation2") + + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_combine(acc, -1) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + }, + condition = "APPROX_TOP_K_NON_POSITIVE_ARG", + parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "-1") + ) + } } case class B(c: Option[Double]) From 149ce9056bd3fa6d2f948f9c15a16646d8413edc Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Thu, 3 Jul 2025 17:49:42 -0700 Subject: [PATCH 20/23] update combine tests --- .../spark/sql/DataFrameAggregateSuite.scala | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 86808fa46bfba..a8efaa5eb82eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -3011,14 +3011,29 @@ class DataFrameAggregateSuite extends QueryTest ) } - test("SPARK-combine: same type, same size, unspecified combine size - success") { - sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + def setupAccumulations(size1: Int, size2: Int): Unit = { + sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") .createOrReplaceTempView("accumulation1") - sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + + sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") .createOrReplaceTempView("accumulation2") + } + + test("SPARK-combine: same type, same size, specified combine size - success") { + setupAccumulations(10, 10) + + sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2)))) + } + + test("SPARK-combine: same type, same size, unspecified combine size - success") { + setupAccumulations(10, 10) sql("SELECT approx_top_k_combine(acc) as com " + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") @@ -3029,13 +3044,7 @@ class DataFrameAggregateSuite 
extends QueryTest } test("SPARK-combine: same type, different size, specified combine size - success") { - sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - .createOrReplaceTempView("accumulation1") - - sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - .createOrReplaceTempView("accumulation2") + setupAccumulations(10, 20) sql("SELECT approx_top_k_combine(acc, 30) as com " + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") @@ -3046,13 +3055,7 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-combine: same type, different size, unspecified combine size - fail") { - sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - .createOrReplaceTempView("accumulation1") - - sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - .createOrReplaceTempView("accumulation2") + setupAccumulations(10, 20) val comb = sql("SELECT approx_top_k_combine(acc) as com " + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") @@ -3067,13 +3070,7 @@ class DataFrameAggregateSuite extends QueryTest } test("SPARK-combine: same type, different size, invalid combine size - fail") { - sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") - .createOrReplaceTempView("accumulation1") - - sql("SELECT approx_top_k_accumulate(expr, 20) as acc " + - "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") - .createOrReplaceTempView("accumulation2") + setupAccumulations(10, 20) checkError( exception = intercept[SparkRuntimeException] { From 874fc7d5dbf4962b86990f0d0ad516480b3a4468 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Sun, 6 Jul 2025 22:28:38 -0700 Subject: [PATCH 21/23] test combine with mixed types --- .../aggregate/ApproxTopKAggregates.scala | 4 +- .../spark/sql/DataFrameAggregateSuite.scala | 113 ++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 05bd79f92c577..18596b3880777 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -514,7 +514,9 @@ case class ApproxTopKCombine( ItemsSketch.getInstance( Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) } catch { - case _: SketchesArgumentException | _: NumberFormatException => + case _: SketchesArgumentException | + _: NumberFormatException | + _: java.lang.ClassCastException => throw new SparkUnsupportedOperationException( errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index a8efaa5eb82eb..a9e933c1b772c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -3081,6 +3081,119 @@ 
class DataFrameAggregateSuite extends QueryTest parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "-1") ) } + + def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = { + sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + + s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);").createOrReplaceTempView("accumulation1") + + sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + + s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);").createOrReplaceTempView("accumulation2") + + sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2;").show(false) + } + + test("SPARK-combine: different types, same size, specified combine size - fail") { +// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + +// "FROM VALUES cast('2023-01-01' AS DATE), cast('2023-01-01' AS DATE), " + +// "cast('2023-01-02' AS DATE), cast('2023-01-02' AS DATE), " + +// "cast('2023-01-03' AS DATE), cast('2023-01-04' AS DATE), " + +// "cast('2023-01-05' AS DATE), cast('2023-01-05' AS DATE) AS tab(expr);") +// acc2.createOrReplaceTempView("accumulation2") + +// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + +// "FROM VALUES cast(0.0 AS FLOAT), cast(0.0 AS FLOAT), " + +// "cast(1.0 AS FLOAT), cast(1.0 AS FLOAT), " + +// "cast(2.0 AS FLOAT), cast(3.0 AS FLOAT), " + +// "cast(4.0 AS FLOAT), cast(4.0 AS FLOAT) AS tab(expr);") +// acc2.createOrReplaceTempView("accumulation2") + +// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + +// "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + +// "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + +// "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " + +// "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);") +// acc2.createOrReplaceTempView("accumulation2") + // + // val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + + // "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + // comb.createOrReplaceTempView("combined") + +// Seq("DATE'2023-01-01'", "DATE'2023-01-01'", "DATE'2023-01-02'", "DATE'2023-01-03'") + + val mixedTypeSeqs = Seq( + Seq(0, 0, 0, 1, 1, 2, 2, 3), + Seq("\'a\'", "\'b\'", "\'c\'", "\'c\'", "\'c\'", "\'c\'", "\'d\'", "\'d\'"), + Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)", "cast(2 AS BYTE)"), + Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)", "cast(2 AS SHORT)"), + Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)", "cast(2 AS FLOAT)"), + Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)", "cast(2 AS DOUBLE)"), + Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", + "cast(1 AS DECIMAL(4, 2))", "cast(2 AS DECIMAL(4, 2))"), + Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", + "cast(1 AS DECIMAL(10, 2))", "cast(2 AS DECIMAL(10, 2))"), + Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", + "cast(1 AS DECIMAL(20, 3))", "cast(2 AS DECIMAL(20, 3))") + ) + + // for any two sequences of different types, the combine should fail + mixedTypeSeqs.foreach { seq1 => + mixedTypeSeqs.foreach { seq2 => + if (seq1 != seq2) { + // scalastyle:off + println(seq1.mkString(", ")) + println(seq2.mkString(", ")) + // scalastyle:on + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = 
intercept[SparkUnsupportedOperationException] {
+            sql("SELECT approx_top_k_estimate(com) FROM combined;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
+        )
+      }
+    }
+  }
+  }
+
+  test("SPARK-decimal: different type (decimal(10, 2) VS decimal(20, 3)), same size - fail") {
+    val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      "FROM VALUES CAST(0.0 AS DECIMAL(10, 2)), CAST(0.0 AS DECIMAL(10, 2)), " +
+      "CAST(1.0 AS DECIMAL(10, 2)), CAST(1.0 AS DECIMAL(10, 2)), " +
+      "CAST(2.0 AS DECIMAL(10, 2)), CAST(3.0 AS DECIMAL(10, 2)), " +
+      "CAST(4.0 AS DECIMAL(10, 2)), CAST(4.0 AS DECIMAL(10, 2)) AS tab(expr);")
+    acc1.createOrReplaceTempView("accumulation1")
+
+    val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " +
+      "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " +
+      "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " +
+      "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);")
+    acc2.createOrReplaceTempView("accumulation2")
+
+    val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " +
+      "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+    comb.createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkError(
+      exception = intercept[SparkUnsupportedOperationException] {
+        est.collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED")
+  }
+
+  test("SPARK-dd: different type (double VS decimal(10, 2)), same size - fail") {
+    val ddouble = Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)",
+      "cast(1 AS DOUBLE)", "cast(2 AS DOUBLE)")
+    val ddecimal = Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))",
+      "cast(1 AS DECIMAL(10, 2))", "cast(2 AS DECIMAL(10, 2))")
+
+    setupMixedTypeAccumulation(ddouble, ddecimal)
+    sql("SELECT approx_top_k_estimate(com) FROM combined;").show(false)
+  }
 }

 case class B(c: Option[Double])

From e5e40c21f44ac94548cb46712ba82644ab225d3e Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Mon, 7 Jul 2025 10:33:49 -0700
Subject: [PATCH 22/23] combine with type code

---
 .../resources/error/error-conditions.json    |   2 +-
 .../aggregate/ApproxTopKAggregates.scala     | 113 +++++++++++---
 .../spark/sql/DataFrameAggregateSuite.scala  | 139 +++++-------------
 3 files changed, 129 insertions(+), 125 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 12f6d9f7b0e03..560295c050220 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -116,7 +116,7 @@
   },
   "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
     "message" : [
-      "Failed to deserialize from memory. Probably due to a mismatch in sketch types. Combining approx_top_k sketches of different types is not allowed."
+      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
], "sqlState": "42846" }, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 18596b3880777..86d74461b63de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -323,7 +323,42 @@ object ApproxTopK { StructType( StructField("Sketch", BinaryType, nullable = false) :: StructField("ItemTypeNull", itemDataType) :: - StructField("MaxItemsTracked", IntegerType, nullable = false) :: Nil) + StructField("MaxItemsTracked", IntegerType, nullable = false) :: + StructField("TypeCode", BinaryType, nullable = false) :: Nil) + + def dataTypeToBytes(dataType: DataType): Array[Byte] = { + dataType match { + case _: BooleanType => Array(0, 0, 0) + case _: ByteType => Array(1, 0, 0) + case _: ShortType => Array(2, 0, 0) + case _: IntegerType => Array(3, 0, 0) + case _: LongType => Array(4, 0, 0) + case _: FloatType => Array(5, 0, 0) + case _: DoubleType => Array(6, 0, 0) + case _: DateType => Array(7, 0, 0) + case _: TimestampType => Array(8, 0, 0) + case _: TimestampNTZType => Array(9, 0, 0) + case _: StringType => Array(10, 0, 0) + case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte) + } + } + + def bytesToDataType(bytes: Array[Byte]): DataType = { + bytes(0) match { + case 0 => BooleanType + case 1 => ByteType + case 2 => ShortType + case 3 => IntegerType + case 4 => LongType + case 5 => FloatType + case 6 => DoubleType + case 7 => DateType + case 8 => TimestampType + case 9 => TimestampNTZType + case 10 => StringType + case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt) + } + } } /** @@ -415,7 +450,8 @@ case class ApproxTopKAccumulate( override def eval(buffer: ItemsSketch[Any]): Any = { val sketchBytes = serialize(buffer) - InternalRow.apply(sketchBytes, null, maxItemsTrackedVal) + val typeCode = ApproxTopK.dataTypeToBytes(itemDataType) + InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode) } override def serialize(buffer: ItemsSketch[Any]): Array[Byte] = @@ -441,11 +477,26 @@ case class ApproxTopKAccumulate( getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") } -class CombineInternal[T](sketch: ItemsSketch[T], itemDataType: DataType, var maxItemsTracked: Int) { +class CombineInternal[T]( + sketch: ItemsSketch[T], + var itemDataType: DataType, + var maxItemsTracked: Int) { def getSketch: ItemsSketch[T] = sketch def getItemDataType: DataType = itemDataType + def setItemDataType(dataType: DataType): Unit = { + if (this.itemDataType == null) { + this.itemDataType = dataType + } else if (this.itemDataType != dataType) { + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + messageParameters = Map( + "type1" -> this.itemDataType.typeName, + "type2" -> dataType.typeName)) + } + } + def getMaxItemsTracked: Int = maxItemsTracked def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked @@ -470,7 +521,7 @@ case class ApproxTopKCombine( def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) - private lazy val itemDataType: DataType = + private lazy val uncheckedItemDataType: DataType = expr.dataType.asInstanceOf[StructType]("ItemTypeNull").dataType private lazy val 
maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int] private lazy val combineSizeSpecified: Boolean = @@ -493,16 +544,19 @@ case class ApproxTopKCombine( } } - override def dataType: DataType = ApproxTopK.getSketchStateDataType(itemDataType) + override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType) override def createAggregationBuffer(): CombineInternal[Any] = { if (combineSizeSpecified) { val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) - new CombineInternal[Any](new ItemsSketch[Any](maxMapSize), itemDataType, maxItemsTrackedVal) + new CombineInternal[Any]( + new ItemsSketch[Any](maxMapSize), + null, + maxItemsTrackedVal) } else { new CombineInternal[Any]( new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), - itemDataType, + null, ApproxTopK.VOID_MAX_ITEMS_TRACKED) } } @@ -510,17 +564,11 @@ case class ApproxTopKCombine( override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { val inputSketchBytes = expr.eval(input).asInstanceOf[InternalRow].getBinary(0) val inputMaxItemsTracked = expr.eval(input).asInstanceOf[InternalRow].getInt(2) - val inputSketch = try { - ItemsSketch.getInstance( - Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) - } catch { - case _: SketchesArgumentException | - _: NumberFormatException | - _: java.lang.ClassCastException => - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" - ) - } + val typeCode = expr.eval(input).asInstanceOf[InternalRow].getBinary(3) + val actualItemDataType = ApproxTopK.bytesToDataType(typeCode) + buffer.setItemDataType(actualItemDataType) + val inputSketch = ItemsSketch.getInstance( + Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) buffer.getSketch.merge(inputSketch) if (!combineSizeSpecified) { buffer.setMaxItemsTracked(inputMaxItemsTracked) @@ -544,6 +592,18 @@ case class ApproxTopKCombine( "size2" -> input.getMaxItemsTracked.toString)) } } + // check item data type + if (buffer.getItemDataType != null && input.getItemDataType != null && + buffer.getItemDataType != input.getItemDataType) { + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + messageParameters = Map( + "type1" -> buffer.getItemDataType.typeName, + "type2" -> input.getItemDataType.typeName)) + } else if (buffer.getItemDataType == null) { + // If buffer is a placeholder sketch, set it to the input sketch's item data type + buffer.setItemDataType(input.getItemDataType) + } buffer.getSketch.merge(input.getSketch) buffer } @@ -558,25 +618,30 @@ case class ApproxTopKCombine( ) } val maxItemsTracked = buffer.getMaxItemsTracked - InternalRow.apply(sketchBytes, null, maxItemsTracked) + val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType) + InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode) } override def serialize(buffer: CombineInternal[Any]): Array[Byte] = { val sketchBytes = buffer.getSketch.toByteArray( ApproxTopK.genSketchSerDe(buffer.getItemDataType)) val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte - val byteArray = new Array[Byte](sketchBytes.length + 1) + val itemDataTypeBytes = ApproxTopK.dataTypeToBytes(buffer.getItemDataType) + val byteArray = new Array[Byte](sketchBytes.length + 4) byteArray(0) = maxItemsTrackedByte - System.arraycopy(sketchBytes, 0, byteArray, 1, sketchBytes.length) + System.arraycopy(itemDataTypeBytes, 0, byteArray, 1, itemDataTypeBytes.length) + 
System.arraycopy(sketchBytes, 0, byteArray, 4, sketchBytes.length) byteArray } override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = { val maxItemsTracked = buffer(0).toInt - val sketchBytes = buffer.slice(1, buffer.length) + val itemDataTypeBytes = buffer.slice(1, 4) + val actualItemDataType = ApproxTopK.bytesToDataType(itemDataTypeBytes) + val sketchBytes = buffer.slice(4, buffer.length) val sketch = ItemsSketch.getInstance( - Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(itemDataType)) - new CombineInternal[Any](sketch, itemDataType, maxItemsTracked) + Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(actualItemDataType)) + new CombineInternal[Any](sketch, actualItemDataType, maxItemsTracked) } override protected def withNewChildrenInternal( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index a9e933c1b772c..5fe20e823c78f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -3084,116 +3084,55 @@ class DataFrameAggregateSuite extends QueryTest def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = { sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + - s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);").createOrReplaceTempView("accumulation1") + s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);") + .createOrReplaceTempView("accumulation1") sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + - s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);").createOrReplaceTempView("accumulation2") + s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);") + .createOrReplaceTempView("accumulation2") - sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - .createOrReplaceTempView("combined") - sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2;").show(false) + sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2") + .createOrReplaceTempView("unioned") } test("SPARK-combine: different types, same size, specified combine size - fail") { -// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + -// "FROM VALUES cast('2023-01-01' AS DATE), cast('2023-01-01' AS DATE), " + -// "cast('2023-01-02' AS DATE), cast('2023-01-02' AS DATE), " + -// "cast('2023-01-03' AS DATE), cast('2023-01-04' AS DATE), " + -// "cast('2023-01-05' AS DATE), cast('2023-01-05' AS DATE) AS tab(expr);") -// acc2.createOrReplaceTempView("accumulation2") - -// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + -// "FROM VALUES cast(0.0 AS FLOAT), cast(0.0 AS FLOAT), " + -// "cast(1.0 AS FLOAT), cast(1.0 AS FLOAT), " + -// "cast(2.0 AS FLOAT), cast(3.0 AS FLOAT), " + -// "cast(4.0 AS FLOAT), cast(4.0 AS FLOAT) AS tab(expr);") -// acc2.createOrReplaceTempView("accumulation2") - -// val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + -// "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + -// "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + -// "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " + -// "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);") -// acc2.createOrReplaceTempView("accumulation2") - // - // val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - // "FROM (SELECT acc from accumulation1 
UNION ALL SELECT acc FROM accumulation2);") - // comb.createOrReplaceTempView("combined") - -// Seq("DATE'2023-01-01'", "DATE'2023-01-01'", "DATE'2023-01-02'", "DATE'2023-01-03'") - val mixedTypeSeqs = Seq( - Seq(0, 0, 0, 1, 1, 2, 2, 3), - Seq("\'a\'", "\'b\'", "\'c\'", "\'c\'", "\'c\'", "\'c\'", "\'d\'", "\'d\'"), - Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)", "cast(2 AS BYTE)"), - Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)", "cast(2 AS SHORT)"), - Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)", "cast(2 AS FLOAT)"), - Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)", "cast(2 AS DOUBLE)"), - Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", - "cast(1 AS DECIMAL(4, 2))", "cast(2 AS DECIMAL(4, 2))"), - Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", - "cast(1 AS DECIMAL(10, 2))", "cast(2 AS DECIMAL(10, 2))"), - Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", - "cast(1 AS DECIMAL(20, 3))", "cast(2 AS DECIMAL(20, 3))") - ) - - // for any two sequences of different types, the combine should fail - mixedTypeSeqs.foreach { seq1 => - mixedTypeSeqs.foreach { seq2 => - if (seq1 != seq2) { - // scalastyle:off - println(seq1.mkString(", ")) - println(seq2.mkString(", ")) - // scalastyle:on - setupMixedTypeAccumulation(seq1, seq2) - checkError( - exception = intercept[SparkUnsupportedOperationException] { - sql("SELECT approx_top_k_estimate(com) FROM combined;").collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" - ) - } + (IntegerType.typeName, Seq(0, 0, 0, 1, 1, 2, 2, 3)), + (StringType.typeName, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")), + // (BooleanType.typeName, Seq("(true)", "(true)", "(true)", "(false)", "(false)")), + (ByteType.typeName, Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")), + (ShortType.typeName, Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")), + (LongType.typeName, Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")), + (FloatType.typeName, Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")), + (DoubleType.typeName, Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")), + (DecimalType(4, 2).typeName, Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", + "cast(1 AS DECIMAL(4, 2))")), + (DecimalType(10, 2).typeName, Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", + "cast(1 AS DECIMAL(10, 2))")), + (DecimalType(20, 3).typeName, Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", + "cast(1 AS DECIMAL(20, 3))")) + // DateType.typeName -> Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'"), + // TimestampType.typeName -> Seq("TIMESTAMP'2025-01-01 00:00:00'", + // "TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-02 00:00:00'"), + // TimestampNTZType.typeName -> Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", + // "TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-02 00:00:00'") + ) + + for (i <- 0 until mixedTypeSeqs.size - 1) { + for (j <- i + 1 until mixedTypeSeqs.size) { + val (type1, seq1) = mixedTypeSeqs(i) + val (type2, seq2) = mixedTypeSeqs(j) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> type1, "type2" -> type2) + ) } } } - - test("SPARK-decimal: different type (decimal(10, 2) VS decimal(20, 3)), 
same size - fail") { - val acc1 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES CAST(0.0 AS DECIMAL(10, 2)), CAST(0.0 AS DECIMAL(10, 2)), " + - "CAST(1.0 AS DECIMAL(10, 2)), CAST(1.0 AS DECIMAL(10, 2)), " + - "CAST(2.0 AS DECIMAL(10, 2)), CAST(3.0 AS DECIMAL(10, 2)), " + - "CAST(4.0 AS DECIMAL(10, 2)), CAST(4.0 AS DECIMAL(10, 2)) AS tab(expr);") - acc1.createOrReplaceTempView("accumulation1") - - val acc2 = sql("SELECT approx_top_k_accumulate(expr, 10) as acc " + - "FROM VALUES CAST(0.0 AS DECIMAL(20, 3)), CAST(0.0 AS DECIMAL(20, 3)), " + - "CAST(1.0 AS DECIMAL(20, 3)), CAST(1.0 AS DECIMAL(20, 3)), " + - "CAST(2.0 AS DECIMAL(20, 3)), CAST(3.0 AS DECIMAL(20, 3)), " + - "CAST(4.0 AS DECIMAL(20, 3)), CAST(4.0 AS DECIMAL(20, 3)) AS tab(expr);") - acc2.createOrReplaceTempView("accumulation2") - - val comb = sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - comb.createOrReplaceTempView("combined") - - val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") - checkError( - exception = intercept[SparkUnsupportedOperationException] { - est.collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED") - } - - test("SPARK-dd: different type (decimal(10, 2) VS decimal(20, 3)), same size - fail") { - val ddouble = Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", - "cast(1 AS DOUBLE)", "cast(2 AS DOUBLE)") - val ddecimal = Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", - "cast(1 AS DECIMAL(10, 2))", "cast(2 AS DECIMAL(10, 2))") - - setupMixedTypeAccumulation(ddouble, ddecimal) - sql("SELECT approx_top_k_estimate(com) FROM combined;").show(false) - } } case class B(c: Option[Double]) From de82104c60b6ac9671595d080779f91f39aa4d97 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Mon, 7 Jul 2025 11:11:29 -0700 Subject: [PATCH 23/23] finish combine test with fixed types --- .../spark/sql/DataFrameAggregateSuite.scala | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5fe20e823c78f..10ff56cbcbbff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -3133,6 +3133,110 @@ class DataFrameAggregateSuite extends QueryTest } } } + + test("SPARK-combine: different types - fail on UNION") { + val seq1 = Seq(0, 0, 0, 1, 1, 2, 2, 3) + val seq2 = Seq("(true)", "(true)", "(false)", "(false)") + checkError( + exception = intercept[ExtendedAnalysisException] { + setupMixedTypeAccumulation(seq1, seq2) + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ), + queryContext = Array( + ExpectedContext( + "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) + ) + } + + test("SPARK-combine: different types Date vs Timestamp - fail") { + val (type1, seq1) = (DateType.typeName, + Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")) + val (type2, seq2) = (TimestampType.typeName, Seq("TIMESTAMP'2025-01-01 00:00:00'", + "TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-02 00:00:00'")) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = 
intercept[SparkUnsupportedOperationException] {
+        sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+      parameters = Map("type1" -> type1, "type2" -> type2)
+    )
+  }
+
+  test("SPARK-combine: different types Timestamp vs TimestampNTZ - fail") {
+    val (type1, seq1) = (TimestampType.typeName,
+      Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'",
+        "TIMESTAMP'2025-01-02 00:00:00'"))
+    val (type2, seq2) = (TimestampNTZType.typeName,
+      Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'",
+        "TIMESTAMP_NTZ'2025-01-02 00:00:00'"))
+    setupMixedTypeAccumulation(seq1, seq2)
+    checkError(
+      exception = intercept[SparkUnsupportedOperationException] {
+        sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+      parameters = Map("type1" -> type1, "type2" -> type2)
+    )
+  }
+
+  test("SPARK-combine: different types Int vs Date - fail on UNION") {
+    val seq1 = Seq(0, 0, 0, 1, 1, 2, 2, 3)
+    val seq2 = Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        setupMixedTypeAccumulation(seq1, seq2)
+      },
+      condition = "INCOMPATIBLE_COLUMN_TYPE",
+      parameters = Map(
+        "tableOrdinalNumber" -> "second",
+        "columnOrdinalNumber" -> "first",
+        "dataType2" -> ("\"STRUCT<Sketch: BINARY NOT NULL, ItemTypeNull: INT, " +
+          "MaxItemsTracked: INT NOT NULL, TypeCode: BINARY NOT NULL>\""),
+        "operator" -> "UNION",
+        "hint" -> "",
+        "dataType1" -> ("\"STRUCT<Sketch: BINARY NOT NULL, ItemTypeNull: DATE, " +
+          "MaxItemsTracked: INT NOT NULL, TypeCode: BINARY NOT NULL>\"")
+      ),
+      queryContext = Array(
+        ExpectedContext(
+          "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+    )
+  }
+
+  test("SPARK-combine: different types Long vs Timestamp - fail on UNION") {
+    val seq1 = Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")
+    val seq2 = Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")
+    checkError(
+      exception = intercept[ExtendedAnalysisException] {
+        setupMixedTypeAccumulation(seq1, seq2)
+      },
+      condition = "INCOMPATIBLE_COLUMN_TYPE",
+      parameters = Map(
+        "tableOrdinalNumber" -> "second",
+        "columnOrdinalNumber" -> "first",
+        "dataType2" -> ("\"STRUCT<Sketch: BINARY NOT NULL, ItemTypeNull: BIGINT, " +
+          "MaxItemsTracked: INT NOT NULL, TypeCode: BINARY NOT NULL>\""),
+        "operator" -> "UNION",
+        "hint" -> "",
+        "dataType1" -> ("\"STRUCT<Sketch: BINARY NOT NULL, ItemTypeNull: TIMESTAMP, " +
+          "MaxItemsTracked: INT NOT NULL, TypeCode: BINARY NOT NULL>\"")
+      ),
+      queryContext = Array(
+        ExpectedContext(
+          "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+    )
+  }
 }

 case class B(c: Option[Double])
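
A quick illustration of the type-code scheme that PATCH 22 introduces and PATCH 23 exercises: the earlier approach caught type mismatches only when deserialization happened to throw, and some pairs (see the double vs DECIMAL(10, 2) exploration test that PATCH 22 deletes) deserialize without error. The fix stores a fixed three-byte code next to each sketch: byte 0 is a type tag, bytes 1 and 2 carry a decimal's precision and scale (zero for every other type), so DECIMAL(10, 2) and DECIMAL(20, 3) get distinct codes and the combine step can compare codes byte-for-byte. The standalone sketch below shows that round trip; the object name TypeCodeRoundTrip and the reduced type subset are illustrative only, while the tag values (3 = INT, 10 = STRING, 11 = DECIMAL) and the byte layout mirror ApproxTopK.dataTypeToBytes / bytesToDataType from the patch.

import org.apache.spark.sql.types._

// Standalone sketch, not part of the patch: encode/decode for a subset of
// the types handled by ApproxTopK.dataTypeToBytes / bytesToDataType.
object TypeCodeRoundTrip {
  // Byte 0: type tag; bytes 1-2: decimal precision and scale (0 otherwise).
  def encode(dt: DataType): Array[Byte] = dt match {
    case _: IntegerType => Array(3, 0, 0)
    case _: StringType => Array(10, 0, 0)
    case d: DecimalType => Array(11, d.precision.toByte, d.scale.toByte)
    case other => sys.error(s"type not covered by this sketch: $other")
  }

  def decode(bytes: Array[Byte]): DataType = bytes(0) match {
    case 3 => IntegerType
    case 10 => StringType
    case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt)
    case tag => sys.error(s"tag not covered by this sketch: $tag")
  }

  def main(args: Array[String]): Unit = {
    // The round trip preserves precision and scale, so mismatched decimal
    // sketches are detected by comparing codes rather than by waiting for a
    // deserialization failure that some type pairs never produce.
    assert(decode(encode(DecimalType(20, 3))) == DecimalType(20, 3))
    assert(encode(DecimalType(20, 3)).toSeq != encode(DecimalType(10, 2)).toSeq)
    assert(decode(encode(IntegerType)) == IntegerType)
  }
}

One byte each is enough for precision and scale because Spark caps decimal precision at 38, which is why the code can stay at a fixed three bytes.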