From da5d901d8df6e0137ee389c4f95895a6bf9649ad Mon Sep 17 00:00:00 2001
From: Peter Nguyen
Date: Sat, 12 Jul 2025 10:35:06 -0700
Subject: [PATCH 1/2] refactor: Refactor GetArrayItem, ElementAt,
 GetArrayStructFields out of QueryPlanSerde

---
 .../apache/comet/serde/QueryPlanSerde.scala   |  91 +---------------
 .../scala/org/apache/comet/serde/arrays.scala | 100 +++++++++++++++++-
 2 files changed, 103 insertions(+), 88 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 43c03d6b4f..2b6551f495 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -121,7 +121,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapKeys] -> CometMapKeys,
     classOf[MapValues] -> CometMapValues,
     classOf[MapFromArrays] -> CometMapFromArrays,
-    classOf[GetMapValue] -> CometMapExtract)
+    classOf[GetMapValue] -> CometMapExtract,
+    classOf[GetArrayItem] -> CometGetArrayItem,
+    classOf[ElementAt] -> CometElementAt,
+    classOf[GetArrayStructFields] -> CometGetArrayStructFields)

   def emitWarning(reason: String): Unit = {
     logWarning(s"Comet native execution is disabled due to: $reason")
@@ -1790,92 +1793,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .setGetStructField(getStructFieldBuilder)
             .build()
         }
-
-      case GetArrayItem(child, ordinal, failOnError) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
-
-        if (childExpr.isDefined && ordinalExpr.isDefined) {
-          val listExtractBuilder = ExprOuterClass.ListExtract
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinalExpr.get)
-            .setOneBased(false)
-            .setFailOnError(failOnError)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setListExtract(listExtractBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
-          None
-        }
-
-      case ElementAt(child, ordinal, defaultValue, failOnError)
-          if child.dataType.isInstanceOf[ArrayType] =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
-        val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
-
-        if (childExpr.isDefined && ordinalExpr.isDefined &&
-          defaultExpr.isDefined == defaultValue.isDefined) {
-          val arrayExtractBuilder = ExprOuterClass.ListExtract
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinalExpr.get)
-            .setOneBased(true)
-            .setFailOnError(failOnError)
-
-          defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setListExtract(arrayExtractBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
-          None
-        }
-
-      case GetArrayStructFields(child, _, ordinal, _, _) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-
-        if (childExpr.isDefined) {
-          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinal)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setGetArrayStructFields(arrayStructFieldsBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
-          None
-        }
-      case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
-        convert(CometArrayCompact)
-      case _: ArrayExcept =>
-        convert(CometArrayExcept)
-      case Rand(child, _) =>
-        createUnaryExpr(
-          expr,
-          child,
-          inputs,
-          binding,
-          (builder, unaryExpr) => builder.setRand(unaryExpr))
-      case expr =>
-        QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
-          case Some(handler) => convert(handler)
-          case _ =>
-            withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
-            None
-        }
     }
   }

diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 4dfc590455..fb9fc868c5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

 import scala.annotation.tailrec

-import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, ElementAt, Expression, GetArrayItem, GetArrayStructFields, Literal}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -378,3 +378,101 @@ object CometCreateArray extends CometExpressionSerde {
     }
   }
 }
+
+object CometGetArrayItem extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val getArrayItem = expr.asInstanceOf[GetArrayItem]
+    val child = getArrayItem.child
+    val ordinal = getArrayItem.ordinal
+    val failOnError = getArrayItem.failOnError
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+    val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
+
+    if (childExpr.isDefined && ordinalExpr.isDefined) {
+      val listExtractBuilder = ExprOuterClass.ListExtract
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(ordinalExpr.get)
+        .setOneBased(false)
+        .setFailOnError(failOnError)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setListExtract(listExtractBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
+      None
+    }
+  }
+}
+
+object CometElementAt extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val elementAt = expr.asInstanceOf[ElementAt]
+    val child = elementAt.left
+    val ordinal = elementAt.right
+    val defaultValue = elementAt.defaultValueOutOfBound
+    val failOnError = elementAt.failOnError
+
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+    val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
+    val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
+
+    if (childExpr.isDefined && ordinalExpr.isDefined &&
+      defaultExpr.isDefined == defaultValue.isDefined) {
+      val arrayExtractBuilder = ExprOuterClass.ListExtract
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(ordinalExpr.get)
+        .setOneBased(true)
+        .setFailOnError(failOnError)
+
+      defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setListExtract(arrayExtractBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
+      None
+    }
+  }
+}
+
+object CometGetArrayStructFields extends CometExpressionSerde with IncompatExpr {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val getArrayStructFields = expr.asInstanceOf[GetArrayStructFields]
+    val child = getArrayStructFields.child
+    val ordinal = getArrayStructFields.ordinal
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+
+    if (childExpr.isDefined) {
+      val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(ordinal)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setGetArrayStructFields(arrayStructFieldsBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
+      None
+    }
+  }
+}

From 091732cf295c0b469b2802b3a5f988051fbbf777 Mon Sep 17 00:00:00 2001
From: Peter Nguyen
Date: Sat, 12 Jul 2025 11:25:36 -0700
Subject: [PATCH 2/2] Add back cases improperly removed

---
 .../apache/comet/serde/QueryPlanSerde.scala | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2b6551f495..167da836b5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1793,6 +1793,24 @@ object QueryPlanSerde extends Logging with CometExprShim {
             .setGetStructField(getStructFieldBuilder)
             .build()
         }
+      case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
+        convert(CometArrayCompact)
+      case _: ArrayExcept =>
+        convert(CometArrayExcept)
+      case Rand(child, _) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setRand(unaryExpr))
+      case expr =>
+        QueryPlanSerde.exprSerdeMap.get(expr.getClass) match {
+          case Some(handler) => convert(handler)
+          case _ =>
+            withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
+            None
+        }
     }
   }
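
Reviewer note (not part of the patches above): the refactor depends on Comet's class-keyed serde dispatch, where each supported Catalyst expression class maps to a handler object and QueryPlanSerde falls back to a "not supported" message when no handler is registered. The sketch below is a minimal, self-contained illustration of that pattern using simplified, hypothetical names (Expr, ExprSerde, GetArrayItemSerde, SerdeDemo, exprToProto); it is not the actual Comet or Spark API, and it emits a placeholder string where the real code builds an ExprOuterClass.Expr protobuf.

// Simplified stand-ins for Catalyst expressions (hypothetical, not Spark's classes).
trait Expr { def prettyName: String }
case class GetArrayItem(child: String, ordinal: Int) extends Expr {
  override val prettyName: String = "get_array_item"
}
case class Rand(seed: Long) extends Expr {
  override val prettyName: String = "rand"
}

// Stand-in for CometExpressionSerde: convert an expression to a serialized form,
// returning None when the expression or its arguments are unsupported.
trait ExprSerde {
  def convert(e: Expr): Option[String]
}

object GetArrayItemSerde extends ExprSerde {
  // Emits a placeholder string where the real handler builds a ListExtract proto.
  override def convert(e: Expr): Option[String] = e match {
    case GetArrayItem(child, ordinal) =>
      Some(s"ListExtract(child=$child, ordinal=$ordinal, oneBased=false)")
    case _ => None
  }
}

object SerdeDemo {
  // Class-keyed handler map, analogous to QueryPlanSerde.exprSerdeMap: supporting a
  // new expression means adding one handler object plus one entry here.
  private val exprSerdeMap: Map[Class[_], ExprSerde] =
    Map(classOf[GetArrayItem] -> GetArrayItemSerde)

  def exprToProto(e: Expr): Option[String] =
    exprSerdeMap.get(e.getClass) match {
      case Some(handler) => handler.convert(e)
      case None =>
        // The real code records this via withInfo; a println stands in for it here.
        println(s"${e.prettyName} is not supported")
        None
    }

  def main(args: Array[String]): Unit = {
    println(exprToProto(GetArrayItem("arr_col", 2))) // Some(ListExtract(...))
    println(exprToProto(Rand(42L)))                  // prints "rand is not supported", then None
  }
}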