Skip to content

chore: Refactor GetArrayItem, ElementAt, GetArrayStructFields out of QueryPlanSerde #2026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 4 additions & 69 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[MapKeys] -> CometMapKeys,
classOf[MapValues] -> CometMapValues,
classOf[MapFromArrays] -> CometMapFromArrays,
classOf[GetMapValue] -> CometMapExtract)
classOf[GetMapValue] -> CometMapExtract,
classOf[GetArrayItem] -> CometGetArrayItem,
classOf[ElementAt] -> CometElementAt,
classOf[GetArrayStructFields] -> CometGetArrayStructFields)

def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
Expand Down Expand Up @@ -1790,74 +1793,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
.setGetStructField(getStructFieldBuilder)
.build()
}

case GetArrayItem(child, ordinal, failOnError) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)

if (childExpr.isDefined && ordinalExpr.isDefined) {
val listExtractBuilder = ExprOuterClass.ListExtract
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinalExpr.get)
.setOneBased(false)
.setFailOnError(failOnError)

Some(
ExprOuterClass.Expr
.newBuilder()
.setListExtract(listExtractBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
None
}

case ElementAt(child, ordinal, defaultValue, failOnError)
if child.dataType.isInstanceOf[ArrayType] =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))

if (childExpr.isDefined && ordinalExpr.isDefined &&
defaultExpr.isDefined == defaultValue.isDefined) {
val arrayExtractBuilder = ExprOuterClass.ListExtract
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinalExpr.get)
.setOneBased(true)
.setFailOnError(failOnError)

defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))

Some(
ExprOuterClass.Expr
.newBuilder()
.setListExtract(arrayExtractBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
None
}

case GetArrayStructFields(child, _, ordinal, _, _) =>
val childExpr = exprToProtoInternal(child, inputs, binding)

if (childExpr.isDefined) {
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
.newBuilder()
.setChild(childExpr.get)
.setOrdinal(ordinal)

Some(
ExprOuterClass.Expr
.newBuilder()
.setGetArrayStructFields(arrayStructFieldsBuilder)
.build())
} else {
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
None
}
case _ @ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
convert(CometArrayCompact)
case _: ArrayExcept =>
Expand Down
100 changes: 99 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.{ArrayExcept, ArrayJoin, ArrayRemove, Attribute, ElementAt, Expression, GetArrayItem, GetArrayStructFields, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -378,3 +378,101 @@ object CometCreateArray extends CometExpressionSerde {
}
}
}

/**
 * Serializes a Spark `GetArrayItem` expression into a protobuf `ListExtract` message.
 *
 * The message is built with `oneBased = false` and carries the expression's own `failOnError`
 * flag.
 *
 * NOTE(review): this serde mixes in `IncompatExpr`; confirm that marking the expression as
 * incompatible is intended.
 */
object CometGetArrayItem extends CometExpressionSerde with IncompatExpr {

  /**
   * @param expr
   *   the expression to serialize; cast to `GetArrayItem` without a type check
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal` for the operands
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal` for the operands
   * @return
   *   the serialized expression, or `None` when either operand could not be converted
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val item = expr.asInstanceOf[GetArrayItem]

    // Convert both operands before inspecting either result, so the ordinal is always
    // processed even when the array operand turns out to be unsupported.
    val arrayProto = exprToProtoInternal(item.child, inputs, binding)
    val indexProto = exprToProtoInternal(item.ordinal, inputs, binding)

    (arrayProto, indexProto) match {
      case (Some(array), Some(index)) =>
        val listExtract = ExprOuterClass.ListExtract
          .newBuilder()
          .setChild(array)
          .setOrdinal(index)
          .setOneBased(false)
          .setFailOnError(item.failOnError)

        Some(ExprOuterClass.Expr.newBuilder().setListExtract(listExtract).build())

      case _ =>
        withInfo(expr, "unsupported arguments for GetArrayItem", item.child, item.ordinal)
        None
    }
  }
}

/**
 * Serializes a Spark `ElementAt` expression over an array into a protobuf `ListExtract`
 * message.
 *
 * The message is built with `oneBased = true`, carries the expression's `failOnError` flag,
 * and includes the out-of-bound default value when the expression defines one.
 *
 * Only array inputs are handled. `ElementAt` is registered by expression class, so this serde
 * is also reached for non-array inputs; those must not be encoded as a `ListExtract`, so the
 * input type is checked explicitly and `None` is returned for anything else.
 *
 * NOTE(review): this serde mixes in `IncompatExpr`; confirm that marking the expression as
 * incompatible is intended.
 */
object CometElementAt extends CometExpressionSerde with IncompatExpr {

  /**
   * @param expr
   *   the expression to serialize; cast to `ElementAt` without a type check
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal` for the operands
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal` for the operands
   * @return
   *   the serialized expression, or `None` when the input is not an array or when any operand
   *   (including a declared default value) could not be converted
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val elementAt = expr.asInstanceOf[ElementAt]
    val child = elementAt.left
    val ordinal = elementAt.right
    val defaultValue = elementAt.defaultValueOutOfBound

    if (!child.dataType.isInstanceOf[ArrayType]) {
      // ListExtract only describes list access; refuse other input types instead of emitting
      // a plan that misinterprets them.
      withInfo(expr, s"unsupported input type for ElementAt: ${child.dataType}", child)
      None
    } else {
      val childExpr = exprToProtoInternal(child, inputs, binding)
      val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
      val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))

      (childExpr, ordinalExpr) match {
        // A default value that exists on the Spark side but failed to convert must also
        // cause a fallback, hence the comparison of the two `isDefined` flags.
        case (Some(childProto), Some(ordinalProto))
            if defaultExpr.isDefined == defaultValue.isDefined =>
          val arrayExtractBuilder = ExprOuterClass.ListExtract
            .newBuilder()
            .setChild(childProto)
            .setOrdinal(ordinalProto)
            .setOneBased(true)
            .setFailOnError(elementAt.failOnError)

          defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))

          Some(
            ExprOuterClass.Expr
              .newBuilder()
              .setListExtract(arrayExtractBuilder)
              .build())

        case _ =>
          withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
          None
      }
    }
  }
}

/**
 * Serializes a Spark `GetArrayStructFields` expression into a protobuf `GetArrayStructFields`
 * message holding the converted child and the expression's `ordinal`.
 *
 * NOTE(review): this serde mixes in `IncompatExpr`; confirm that marking the expression as
 * incompatible is intended.
 */
object CometGetArrayStructFields extends CometExpressionSerde with IncompatExpr {

  /**
   * @param expr
   *   the expression to serialize; cast to `GetArrayStructFields` without a type check
   * @param inputs
   *   attributes forwarded unchanged to `exprToProtoInternal` for the child
   * @param binding
   *   flag forwarded unchanged to `exprToProtoInternal` for the child
   * @return
   *   the serialized expression, or `None` when the child could not be converted
   */
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val structFields = expr.asInstanceOf[GetArrayStructFields]
    val source = structFields.child

    exprToProtoInternal(source, inputs, binding) match {
      case Some(sourceProto) =>
        val fieldsMessage = ExprOuterClass.GetArrayStructFields
          .newBuilder()
          .setChild(sourceProto)
          .setOrdinal(structFields.ordinal)

        Some(ExprOuterClass.Expr.newBuilder().setGetArrayStructFields(fieldsMessage).build())

      case None =>
        withInfo(expr, "unsupported arguments for GetArrayStructFields", source)
        None
    }
  }
}
Loading