From 202556d1690de6e47d4bf93bd8ea6286e65d440b Mon Sep 17 00:00:00 2001 From: Yu-Chuan Hung Date: Wed, 16 Jul 2025 00:20:33 +0800 Subject: [PATCH 1/2] chore: refactor GreaterThan out of QueryPlanSerde - Create comparisons.scala following the pattern from math/array expressions. - Implements CometGreaterThan as proof of concept for issue #2019. --- .../apache/comet/serde/QueryPlanSerde.scala | 12 +- .../org/apache/comet/serde/comparisons.scala | 116 ++++++++++++++++++ 2 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/serde/comparisons.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 970329b28..702336452 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -120,7 +120,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapKeys] -> CometMapKeys, classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, - classOf[GetMapValue] -> CometMapExtract) + classOf[GetMapValue] -> CometMapExtract, + classOf[GreaterThan] -> CometGreaterThan) def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") @@ -801,15 +802,6 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) - case GreaterThan(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setGt(binaryExpr)) - case GreaterThanOrEqual(left, right) => createBinaryExpr( expr, diff --git a/spark/src/main/scala/org/apache/comet/serde/comparisons.scala b/spark/src/main/scala/org/apache/comet/serde/comparisons.scala new file mode 100644 index 000000000..018b013cd --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/comparisons.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan} + +import scala.collection.JavaConverters._ + +object CometGreaterThan extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val greaterThan = expr.asInstanceOf[GreaterThan] + + createBinaryExpr( + expr, + greaterThan.left, + greaterThan.right, + inputs, + binding, + (builder, binaryExpr) => builder.setGt(binaryExpr)) + } +} + +sealed trait ComparisonBase { + + /** + * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then + * invokes the supplied function to wrap this UnaryExpr in a top-level Expr. + * + * @param child + * Spark expression + * @param inputs + * Inputs to the expression + * @param f + * Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level + * Expr + * @return + * Some(Expr) or None if not supported + */ + def createUnaryExpr( + expr: Expression, + child: Expression, + inputs: Seq[Attribute], + binding: Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(child, inputs, binding) // TODO review + if (childExpr.isDefined) { + // create the generic UnaryExpr message + val inner = ExprOuterClass.UnaryExpr + .newBuilder() + .setChild(childExpr.get) + .build() + // call the user-supplied function to wrap UnaryExpr in a top-level Expr + // such as Expr.IsNull or Expr.IsNotNull + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, child) + None + } + } + + def createBinaryExpr( + expr: Expression, + left: Expression, + right: Expression, + inputs: Seq[Attribute], + binding: Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic BinaryExpr message + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(leftExpr.get) + .setRight(rightExpr.get) + .build() + // call the user-supplied function to wrap BinaryExpr in a top-level Expr + // such as Expr.And or Expr.Or + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, left, right) + None + } + } +} From 2606d4fc20bb4333b88286361e24fc3a823db707 Mon Sep 17 00:00:00 2001 From: Yu-Chuan Hung Date: Thu, 17 Jul 2025 22:42:00 +0800 Subject: [PATCH 2/2] Refactor: extract comparison expressions from QueryPlanSerde - Add ComparisonBase trait with reusable helper methods - Move GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual to CometComparison* - Move IsNull, IsNotNull, IsNaN to CometIs* classes - Move In expression to CometIn class - Update exprSerdeMap to use new comparison classes - Remove corresponding case statements from exprToProtoInternal --- .../apache/comet/serde/QueryPlanSerde.scala | 62 +------ .../org/apache/comet/serde/comparisons.scala | 155 +++++++++++++++++- 2 files changed, 160 insertions(+), 57 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 702336452..ed3aa8f20 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -121,7 +121,14 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[MapValues] -> CometMapValues, classOf[MapFromArrays] -> CometMapFromArrays, classOf[GetMapValue] -> CometMapExtract, - classOf[GreaterThan] -> CometGreaterThan) + classOf[GreaterThan] -> CometGreaterThan, + classOf[GreaterThanOrEqual] -> CometGreaterThanOrEqual, + classOf[LessThan] -> CometLessThan, + classOf[LessThanOrEqual] -> CometLessThanOrEqual, + classOf[IsNull] -> CometIsNull, + classOf[IsNotNull] -> CometIsNotNull, + classOf[IsNaN] -> CometIsNaN, + classOf[In] -> CometIn) def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") @@ -802,33 +809,6 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) - case GreaterThanOrEqual(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setGtEq(binaryExpr)) - - case LessThan(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setLt(binaryExpr)) - - case LessThanOrEqual(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setLtEq(binaryExpr)) - case Literal(value, dataType) if supportedDataType(dataType, allowComplex = value == null) => val exprBuilder = ExprOuterClass.Literal.newBuilder() @@ -1183,29 +1163,6 @@ object QueryPlanSerde extends Logging with CometExprShim { }) optExprWithInfo(optExpr, expr, child) - case IsNull(child) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setIsNull(unaryExpr)) - - case IsNotNull(child) => - createUnaryExpr( - expr, - child, - inputs, - binding, - (builder, unaryExpr) => builder.setIsNotNull(unaryExpr)) - - case IsNaN(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val optExpr = - scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr) - - optExprWithInfo(optExpr, expr, child) - case SortOrder(child, direction, nullOrdering, _) => val childExpr = exprToProtoInternal(child, inputs, binding) @@ -1575,9 +1532,6 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr)) - case In(value, list) => - in(expr, value, list, inputs, binding, negate = false) - case InSet(value, hset) => val valueDataType = value.dataType val list = hset.map { setVal => diff --git a/spark/src/main/scala/org/apache/comet/serde/comparisons.scala b/spark/src/main/scala/org/apache/comet/serde/comparisons.scala index 018b013cd..650e39e47 100644 --- a/spark/src/main/scala/org/apache/comet/serde/comparisons.scala +++ b/spark/src/main/scala/org/apache/comet/serde/comparisons.scala @@ -18,12 +18,16 @@ */ package org.apache.comet.serde -import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan} import scala.collection.JavaConverters._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GreaterThan, GreaterThanOrEqual, In, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual} +import org.apache.spark.sql.types.BooleanType + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarFunctionExprToProtoWithReturnType} + object CometGreaterThan extends CometExpressionSerde with ComparisonBase { override def convert( expr: Expression, @@ -41,6 +45,112 @@ object CometGreaterThan extends CometExpressionSerde with ComparisonBase { } } +object CometGreaterThanOrEqual extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val greaterThanOrEqual = expr.asInstanceOf[GreaterThanOrEqual] + + createBinaryExpr( + expr, + greaterThanOrEqual.left, + greaterThanOrEqual.right, + inputs, + binding, + (builder, binaryExpr) => builder.setGtEq(binaryExpr)) + } +} + +object CometLessThan extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val lessThan = expr.asInstanceOf[LessThan] + + createBinaryExpr( + expr, + lessThan.left, + lessThan.right, + inputs, + binding, + (builder, binaryExpr) => builder.setLt(binaryExpr)) + } +} + +object CometLessThanOrEqual extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val lessThanOrEqual = expr.asInstanceOf[LessThanOrEqual] + + createBinaryExpr( + expr, + lessThanOrEqual.left, + lessThanOrEqual.right, + inputs, + binding, + (builder, binaryExpr) => builder.setLtEq(binaryExpr)) + } +} + +object CometIsNull extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val isNull = expr.asInstanceOf[IsNull] + + createUnaryExpr( + expr, + isNull.child, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNull(unaryExpr)) + } +} + +object CometIsNotNull extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val isNotNull = expr.asInstanceOf[IsNotNull] + + createUnaryExpr( + expr, + isNotNull.child, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNotNull(unaryExpr)) + } +} + +object CometIsNaN extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val isNaN = expr.asInstanceOf[IsNaN] + val childExpr = exprToProtoInternal(isNaN.child, inputs, binding) + val optExpr = scalarFunctionExprToProtoWithReturnType("isnan", BooleanType, childExpr) + + optExprWithInfo(optExpr, expr, isNaN.child) + } +} + +object CometIn extends CometExpressionSerde with ComparisonBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val inExpr = expr.asInstanceOf[In] + in(expr, inExpr.value, inExpr.list, inputs, binding, negate = false) + } +} + sealed trait ComparisonBase { /** @@ -113,4 +223,43 @@ sealed trait ComparisonBase { None } } + + // Utility method. Adds explain info if the result of calling exprToProto is None + def optExprWithInfo( + optExpr: Option[Expr], + expr: Expression, + childExpr: Expression*): Option[Expr] = { + optExpr match { + case None => + withInfo(expr, childExpr: _*) + None + case o => o + } + } + + def in( + expr: Expression, + value: Expression, + list: Seq[Expression], + inputs: Seq[Attribute], + binding: Boolean, + negate: Boolean): Option[Expr] = { + val valueExpr = exprToProtoInternal(value, inputs, binding) + val listExprs = list.map(exprToProtoInternal(_, inputs, binding)) + if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.In.newBuilder() + builder.setInValue(valueExpr.get) + builder.addAllLists(listExprs.map(_.get).asJava) + builder.setNegated(negate) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()) + } else { + val allExprs = list ++ Seq(value) + withInfo(expr, allExprs: _*) + None + } + } }