
Commit 5e6e8f1

eejbyfeldt authored and hvanhovell committed
[SPARK-52023][SQL] Fix data corruption/segfault returning Option[Product] from udaf
### What changes were proposed in this pull request?

This fixes udafs defined with an `Option[Product]` return type so that they produce correct results. Previously such a udaf would throw an exception, segfault, or produce incorrect results.

### Why are the changes needed?

Fixes a correctness issue.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes a correctness issue.

### How was this patch tested?

Existing and new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50827 from eejbyfeldt/SPARK-52023.

Authored-by: Emil Ejbyfeldt <emil.ejbyfeldt@choreograph.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
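For context, a minimal sketch of the kind of aggregator that hits this code path: one whose output type is an `Option` of a `Product`. Everything here (`FirstPair`, the table `t`, the column names) is illustrative rather than part of the patch; the actual regression test added below uses a generic `Reduce` aggregator.

```scala
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// Illustrative aggregator whose OUT type is Option[Product]: it keeps the
// first (Int, Int) value it sees in each group, if any.
object FirstPair extends Aggregator[(Int, Int), Option[(Int, Int)], Option[(Int, Int)]] {
  def zero: Option[(Int, Int)] = None
  def reduce(b: Option[(Int, Int)], a: (Int, Int)): Option[(Int, Int)] = b.orElse(Some(a))
  def merge(b1: Option[(Int, Int)], b2: Option[(Int, Int)]): Option[(Int, Int)] = b1.orElse(b2)
  def finish(r: Option[(Int, Int)]): Option[(Int, Int)] = r
  def bufferEncoder: Encoder[Option[(Int, Int)]] = ExpressionEncoder()
  def outputEncoder: Encoder[Option[(Int, Int)]] = ExpressionEncoder()
}

// Before this fix, evaluating the registered function could throw, segfault,
// or silently return corrupted data; with it, the Option surfaces as a
// nullable struct column.
// spark.udf.register("firstPair", udaf(FirstPair))
// spark.sql("SELECT firstPair(value) FROM t GROUP BY key")
```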
1 parent 646d96e · commit 5e6e8f1

File tree

2 files changed: +28 −1 lines changed
  • sql


sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   def eval(buffer: BUF): Any = {
     val row = outputSerializer(agg.finish(buffer))
-    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+    if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0, dataType)
   }
 
   @transient private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
```
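The one-word change is the whole fix. `isSerializedAsStruct` is true whenever the output's serialized form is a struct, which includes `Option[Product]`; but a top-level `Option[Product]` is encoded as a single nullable struct column, so `eval` must unwrap it with `row.get(0, dataType)` rather than return the whole row. For reference, the two predicates are defined roughly along these lines in `ExpressionEncoder` (paraphrased from memory; the exact code may differ across Spark versions):

```scala
// Inside ExpressionEncoder[T]:
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]

// False for a top-level Option[_], even when the element is a Product: the
// value stays a single nullable struct column instead of being flattened
// into the row's top-level fields.
def isSerializedAsStructForTopLevel: Boolean =
  isSerializedAsStruct && !classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
```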

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala

Lines changed: 27 additions & 0 deletions
```diff
@@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
   def outputEncoder: Encoder[jlLong] = Encoders.LONG
 }
 
+final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i: Encoder[Option[T]])
+    extends Aggregator[T, Option[T], T] {
+  def zero: Option[T] = None
+  def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))
+  def merge(b1: Option[T], b2: Option[T]): Option[T] =
+    (b1, b2) match {
+      case (Some(a), Some(b)) => Some(r(a, b))
+      case (Some(a), None) => Some(a)
+      case (None, Some(b)) => Some(b)
+      case (None, None) => None
+    }
+  def finish(reduction: Option[T]): T = reduction.get
+  def bufferEncoder: Encoder[Option[T]] = implicitly
+  def outputEncoder: Encoder[T] = implicitly
+}
+
 @SQLUserDefinedType(udt = classOf[CountSerDeUDT])
 case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
 
@@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
     data4.write.saveAsTable("agg4")
 
+    val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
+    data5.write.saveAsTable("agg5")
+
     val emptyDF = spark.createDataFrame(
       sparkContext.emptyRDD[Row],
       StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -190,6 +209,8 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
     spark.udf.register("arraysum", udaf(ArrayDataAgg))
+    spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
+      (opt1, opt2) => opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2) })))
   }
 
   override def afterAll(): Unit = {
@@ -371,6 +392,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(Seq(12.0, 15.0, 18.0)) :: Nil)
   }
 
+  test("SPARK-52023: Returning Option[Product] from udaf") {
+    checkAnswer(
+      spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"),
+      Row(Row(2, 3)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)
```
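A note on the expected answer in the new test: the outer `Row` is the query's result row, and the inner `Row(2, 3)` is the struct produced for the `Some((2, 3))` that the aggregator returns. A rough way to inspect the result shape interactively (a sketch; the printed column name and nullability flags are approximate):

```scala
// Assumes a session where the agg5 table and reduceOptionPair udaf from the
// test above have been set up.
val df = spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key")
df.printSchema()
// root
//  |-- reduceoptionpair(value): struct (nullable = true)
//  |    |-- _1: integer
//  |    |-- _2: integer
df.collect()  // Array([[2,3]]): one result row wrapping one struct value
```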
