diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 43c03d6b4f..167da836b5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -121,7 +121,10 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapKeys] -> CometMapKeys, classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, - classOf[GetMapValue] -> CometMapExtract) + classOf[GetMapValue] -> CometMapExtract, + classOf[GetArrayItem] -> CometGetArrayItem, + classOf[ElementAt] -> CometElementAt, + classOf[GetArrayStructFields] -> CometGetArrayStructFields) def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") @@ -1790,74 +1793,6 @@ object QueryPlanSerde extends Logging with CometExprShim { .setGetStructField(getStructFieldBuilder) .build() } - - case GetArrayItem(child, ordinal, failOnError) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) - - if (childExpr.isDefined && ordinalExpr.isDefined) { - val listExtractBuilder = ExprOuterClass.ListExtract - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(false) - .setFailOnError(failOnError) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(listExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) - None - } - - case ElementAt(child, ordinal, defaultValue, failOnError) - if child.dataType.isInstanceOf[ArrayType] => - val childExpr = exprToProtoInternal(child, inputs, binding) - val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) - val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding)) - - if (childExpr.isDefined && ordinalExpr.isDefined && - defaultExpr.isDefined == defaultValue.isDefined) { - val arrayExtractBuilder = ExprOuterClass.ListExtract - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(true) - .setFailOnError(failOnError) - - defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(arrayExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) - None - } - - case GetArrayStructFields(child, _, ordinal, _, _) => - val childExpr = exprToProtoInternal(child, inputs, binding) - - if (childExpr.isDefined) { - val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinal) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setGetArrayStructFields(arrayStructFieldsBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayStructFields", child) - None - } case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] => convert(CometArrayCompact) case _: ArrayExcept => diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 4dfc590455..fb9fc868c5 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, ElementAt, Expression, GetArrayItem, GetArrayStructFields, Literal} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -378,3 +378,101 @@ object CometCreateArray extends CometExpressionSerde { } } } + +object CometGetArrayItem extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val getArrayItem = expr.asInstanceOf[GetArrayItem] + val child = getArrayItem.child + val ordinal = getArrayItem.ordinal + val failOnError = getArrayItem.failOnError + val childExpr = exprToProtoInternal(child, inputs, binding) + val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) + + if (childExpr.isDefined && ordinalExpr.isDefined) { + val listExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(false) + .setFailOnError(failOnError) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setListExtract(listExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) + None + } + } +} + +object CometElementAt extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val elementAt = expr.asInstanceOf[ElementAt] + val child = elementAt.left + val ordinal = elementAt.right + val defaultValue = elementAt.defaultValueOutOfBound + val failOnError = elementAt.failOnError + + val childExpr = exprToProtoInternal(child, inputs, binding) + val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) + val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding)) + + if (childExpr.isDefined && ordinalExpr.isDefined && + defaultExpr.isDefined == defaultValue.isDefined) { + val arrayExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(true) + .setFailOnError(failOnError) + + defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setListExtract(arrayExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) + None + } + } +} + +object CometGetArrayStructFields extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val getArrayStructFields = expr.asInstanceOf[GetArrayStructFields] + val child = getArrayStructFields.child + val ordinal = getArrayStructFields.ordinal + val childExpr = exprToProtoInternal(child, inputs, binding) + + if (childExpr.isDefined) { + val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinal) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setGetArrayStructFields(arrayStructFieldsBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayStructFields", child) + None + } + } +}