diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 327e57bcfd..c2091346d7 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -229,45 +229,78 @@ impl PhysicalPlanner { input_schema: SchemaRef, ) -> Result, ExecutionError> { match spark_expr.expr_struct.as_ref().unwrap() { - ExprStruct::Add(expr) => self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Plus, - input_schema, - ), - ExprStruct::Subtract(expr) => self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Minus, - input_schema, - ), - ExprStruct::Multiply(expr) => self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Multiply, - input_schema, - ), - ExprStruct::Divide(expr) => self.create_binary_expr( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - ), - ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options( - expr.left.as_ref().unwrap(), - expr.right.as_ref().unwrap(), - expr.return_type.as_ref(), - DataFusionOperator::Divide, - input_schema, - BinaryExprOptions { - is_integral_div: true, - }, - ), + ExprStruct::Add(expr) => { + // TODO respect eval mode + // https://github.com/apache/datafusion-comet/issues/2021 + // https://github.com/apache/datafusion-comet/issues/536 + let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + self.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Plus, + input_schema, + ) + } + ExprStruct::Subtract(expr) => { + // TODO respect eval mode + // https://github.com/apache/datafusion-comet/issues/2021 + // https://github.com/apache/datafusion-comet/issues/535 + let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + self.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Minus, + input_schema, + ) + } + ExprStruct::Multiply(expr) => { + // TODO respect eval mode + // https://github.com/apache/datafusion-comet/issues/2021 + // https://github.com/apache/datafusion-comet/issues/534 + let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + self.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Multiply, + input_schema, + ) + } + ExprStruct::Divide(expr) => { + // TODO respect eval mode + // https://github.com/apache/datafusion-comet/issues/2021 + // https://github.com/apache/datafusion-comet/issues/533 + let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + self.create_binary_expr( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + ) + } + ExprStruct::IntegralDivide(expr) => { + // TODO respect eval mode + // https://github.com/apache/datafusion-comet/issues/2021 + // https://github.com/apache/datafusion-comet/issues/533 + let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + self.create_binary_expr_with_options( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + BinaryExprOptions { + is_integral_div: true, + }, + ) + } ExprStruct::Remainder(expr) => { + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + // TODO add support for EvalMode::TRY + // https://github.com/apache/datafusion-comet/issues/2021 let left = self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; let right = @@ -278,7 +311,7 @@ impl PhysicalPlanner { right, expr.return_type.as_ref().map(to_arrow_datatype).unwrap(), input_schema, - expr.fail_on_error, + eval_mode == EvalMode::Ansi, &self.session_ctx.state(), ); result.map_err(|e| GeneralError(e.to_string())) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8f4c875eec..3daa304736 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -219,19 +219,19 @@ message Literal { bool is_null = 12; } -message MathExpr { - Expr left = 1; - Expr right = 2; - bool fail_on_error = 3; - DataType return_type = 4; -} - enum EvalMode { LEGACY = 0; TRY = 1; ANSI = 2; } +message MathExpr { + Expr left = 1; + Expr right = 2; + DataType return_type = 4; + EvalMode eval_mode = 5; +} + message Cast { Expr child = 1; DataType datatype = 2; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 970329b28f..a97df41a24 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,6 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -import scala.math.min import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ @@ -67,6 +66,12 @@ object QueryPlanSerde extends Logging with CometExprShim { * Mapping of Spark expression class to Comet expression handler. */ private val exprSerdeMap: Map[Class[_], CometExpressionSerde] = Map( + classOf[Add] -> CometAdd, + classOf[Subtract] -> CometSubtract, + classOf[Multiply] -> CometMultiply, + classOf[Divide] -> CometDivide, + classOf[IntegralDivide] -> CometIntegralDivide, + classOf[Remainder] -> CometRemainder, classOf[ArrayAppend] -> CometArrayAppend, classOf[ArrayContains] -> CometArrayContains, classOf[ArrayDistinct] -> CometArrayDistinct, @@ -630,141 +635,6 @@ object QueryPlanSerde extends Logging with CometExprShim { case c @ Cast(child, dt, timeZoneId, _) => handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c)) - case add @ Add(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - expr, - left, - right, - inputs, - binding, - add.dataType, - add.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setAdd(mathExpr)) - - case add @ Add(left, _, _) if !supportedDataType(left.dataType) => - withInfo(add, s"Unsupported datatype ${left.dataType}") - None - - case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - expr, - left, - right, - inputs, - binding, - sub.dataType, - sub.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setSubtract(mathExpr)) - - case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) => - withInfo(sub, s"Unsupported datatype ${left.dataType}") - None - - case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - expr, - left, - right, - inputs, - binding, - mul.dataType, - mul.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setMultiply(mathExpr)) - - case mul @ Multiply(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(mul, s"Unsupported datatype ${left.dataType}") - } - None - - case div @ Divide(left, right, _) if supportedDataType(left.dataType) => - // Datafusion now throws an exception for dividing by zero - // See https://github.com/apache/arrow-datafusion/pull/6792 - // For now, use NullIf to swap zeros with nulls. - val rightExpr = nullIfWhenPrimitive(right) - - createMathExpression( - expr, - left, - rightExpr, - inputs, - binding, - div.dataType, - div.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setDivide(mathExpr)) - - case div @ Divide(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(div, s"Unsupported datatype ${left.dataType}") - } - None - - case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) => - val rightExpr = nullIfWhenPrimitive(right) - - val dataType = (left.dataType, right.dataType) match { - case (l: DecimalType, r: DecimalType) => - // copy from IntegralDivide.resultDecimalType - val intDig = l.precision - l.scale + r.scale - DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0) - case _ => left.dataType - } - - val divideExpr = createMathExpression( - expr, - left, - rightExpr, - inputs, - binding, - dataType, - div.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setIntegralDivide(mathExpr)) - - if (divideExpr.isDefined) { - val childExpr = if (dataType.isInstanceOf[DecimalType]) { - // check overflow for decimal type - val builder = ExprOuterClass.CheckOverflow.newBuilder() - builder.setChild(divideExpr.get) - builder.setFailOnError(div.evalMode == EvalMode.ANSI) - builder.setDatatype(serializeDataType(dataType).get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setCheckOverflow(builder) - .build()) - } else { - divideExpr - } - - // cast result to long - castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY) - } else { - None - } - - case div @ IntegralDivide(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(div, s"Unsupported datatype ${left.dataType}") - } - None - - case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - expr, - left, - right, - inputs, - binding, - rem.dataType, - rem.evalMode == EvalMode.ANSI, - (builder, mathExpr) => builder.setRemainder(mathExpr)) - - case rem @ Remainder(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(rem, s"Unsupported datatype ${left.dataType}") - } - None - case EqualTo(left, right) => createBinaryExpr( expr, @@ -1947,42 +1817,6 @@ object QueryPlanSerde extends Logging with CometExprShim { } } - private def createMathExpression( - expr: Expression, - left: Expression, - right: Expression, - inputs: Seq[Attribute], - binding: Boolean, - dataType: DataType, - failOnError: Boolean, - f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) - : Option[ExprOuterClass.Expr] = { - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - - if (leftExpr.isDefined && rightExpr.isDefined) { - // create the generic MathExpr message - val builder = ExprOuterClass.MathExpr.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(failOnError) - serializeDataType(dataType).foreach { t => - builder.setReturnType(t) - } - val inner = builder.build() - // call the user-supplied function to wrap MathExpr in a top-level Expr - // such as Expr.Add or Expr.Divide - Some( - f( - ExprOuterClass.Expr - .newBuilder(), - inner).build()) - } else { - withInfo(expr, left, right) - None - } - } - def in( expr: Expression, value: Expression, @@ -2038,25 +1872,6 @@ object QueryPlanSerde extends Logging with CometExprShim { Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } - private def isPrimitive(expression: Expression): Boolean = expression.dataType match { - case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | - _: DoubleType | _: TimestampType | _: DateType | _: BooleanType | _: DecimalType => - true - case _ => false - } - - private def nullIfWhenPrimitive(expression: Expression): Expression = - if (isPrimitive(expression)) { - val zero = Literal.default(expression.dataType) - expression match { - case _: Literal if expression != zero => expression - case _ => - If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) - } - } else { - expression - } - private def nullIfNegative(expression: Expression): Expression = { val zero = Literal.default(expression.dataType) If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression) diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala new file mode 100644 index 0000000000..3a7a9f8fb5 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.math.min + +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, Divide, EqualTo, EvalMode, Expression, If, IntegralDivide, Literal, Multiply, Remainder, Subtract} +import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType} + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.expressions.CometEvalMode +import org.apache.comet.serde.QueryPlanSerde.{castToProto, evalModeToProto, exprToProtoInternal, serializeDataType} +import org.apache.comet.shims.CometEvalModeUtil + +trait MathBase { + def createMathExpression( + expr: Expression, + left: Expression, + right: Expression, + inputs: Seq[Attribute], + binding: Boolean, + dataType: DataType, + evalMode: EvalMode.Value, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + + if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic MathExpr message + val builder = ExprOuterClass.MathExpr.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode))) + serializeDataType(dataType).foreach { t => + builder.setReturnType(t) + } + val inner = builder.build() + // call the user-supplied function to wrap MathExpr in a top-level Expr + // such as Expr.Add or Expr.Divide + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, left, right) + None + } + } + + def nullIfWhenPrimitive(expression: Expression): Expression = { + val zero = Literal.default(expression.dataType) + expression match { + case _: Literal if expression != zero => expression + case _ => + If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) + } + } + + def supportedDataType(dt: DataType): Boolean = dt match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | + _: DoubleType | _: DecimalType => + true + case _ => + false + } + +} + +object CometAdd extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val add = expr.asInstanceOf[Add] + if (!supportedDataType(add.left.dataType)) { + withInfo(add, s"Unsupported datatype ${add.left.dataType}") + return None + } + if (add.evalMode == EvalMode.TRY) { + withInfo(add, s"Eval mode ${add.evalMode} is not supported") + return None + } + createMathExpression( + expr, + add.left, + add.right, + inputs, + binding, + add.dataType, + add.evalMode, + (builder, mathExpr) => builder.setAdd(mathExpr)) + } +} + +object CometSubtract extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val sub = expr.asInstanceOf[Subtract] + if (!supportedDataType(sub.left.dataType)) { + withInfo(sub, s"Unsupported datatype ${sub.left.dataType}") + return None + } + if (sub.evalMode == EvalMode.TRY) { + withInfo(sub, s"Eval mode ${sub.evalMode} is not supported") + return None + } + createMathExpression( + expr, + sub.left, + sub.right, + inputs, + binding, + sub.dataType, + sub.evalMode, + (builder, mathExpr) => builder.setSubtract(mathExpr)) + } +} + +object CometMultiply extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val mul = expr.asInstanceOf[Multiply] + if (!supportedDataType(mul.left.dataType)) { + withInfo(mul, s"Unsupported datatype ${mul.left.dataType}") + return None + } + if (mul.evalMode == EvalMode.TRY) { + withInfo(mul, s"Eval mode ${mul.evalMode} is not supported") + return None + } + createMathExpression( + expr, + mul.left, + mul.right, + inputs, + binding, + mul.dataType, + mul.evalMode, + (builder, mathExpr) => builder.setMultiply(mathExpr)) + } +} + +object CometDivide extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val div = expr.asInstanceOf[Divide] + + // Datafusion now throws an exception for dividing by zero + // See https://github.com/apache/arrow-datafusion/pull/6792 + // For now, use NullIf to swap zeros with nulls. + val rightExpr = nullIfWhenPrimitive(div.right) + + if (!supportedDataType(div.left.dataType)) { + withInfo(div, s"Unsupported datatype ${div.left.dataType}") + return None + } + if (div.evalMode == EvalMode.TRY) { + withInfo(div, s"Eval mode ${div.evalMode} is not supported") + return None + } + createMathExpression( + expr, + div.left, + rightExpr, + inputs, + binding, + div.dataType, + div.evalMode, + (builder, mathExpr) => builder.setDivide(mathExpr)) + } +} + +object CometIntegralDivide extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val div = expr.asInstanceOf[IntegralDivide] + val rightExpr = nullIfWhenPrimitive(div.right) + + if (!supportedDataType(div.left.dataType)) { + withInfo(div, s"Unsupported datatype ${div.left.dataType}") + return None + } + if (div.evalMode == EvalMode.TRY) { + withInfo(div, s"Eval mode ${div.evalMode} is not supported") + return None + } + + val dataType = (div.left.dataType, div.right.dataType) match { + case (l: DecimalType, r: DecimalType) => + // copy from IntegralDivide.resultDecimalType + val intDig = l.precision - l.scale + r.scale + DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0) + case _ => div.left.dataType + } + + val divideExpr = createMathExpression( + expr, + div.left, + rightExpr, + inputs, + binding, + dataType, + div.evalMode, + (builder, mathExpr) => builder.setIntegralDivide(mathExpr)) + + if (divideExpr.isDefined) { + val childExpr = if (dataType.isInstanceOf[DecimalType]) { + // check overflow for decimal type + val builder = ExprOuterClass.CheckOverflow.newBuilder() + builder.setChild(divideExpr.get) + builder.setFailOnError(div.evalMode == EvalMode.ANSI) + builder.setDatatype(serializeDataType(dataType).get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setCheckOverflow(builder) + .build()) + } else { + divideExpr + } + + // cast result to long + castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY) + } else { + None + } + } +} + +object CometRemainder extends CometExpressionSerde with MathBase { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val remainder = expr.asInstanceOf[Remainder] + if (!supportedDataType(remainder.left.dataType)) { + withInfo(remainder, s"Unsupported datatype ${remainder.left.dataType}") + return None + } + if (remainder.evalMode == EvalMode.TRY) { + withInfo(remainder, s"Eval mode ${remainder.evalMode} is not supported") + return None + } + + createMathExpression( + expr, + remainder.left, + remainder.right, + inputs, + binding, + remainder.dataType, + remainder.evalMode, + (builder, mathExpr) => builder.setRemainder(mathExpr)) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 6f29d48878..0b30575be1 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -302,6 +302,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("try_add") { + // TODO: we need to implement more comprehensive tests for all try_ arithmetic functions + // https://github.com/apache/datafusion-comet/issues/2021 + val data = Seq((Integer.MAX_VALUE, 1)) + withParquetTable(data, "tbl") { + checkSparkAnswer("SELECT try_add(_1, _2) FROM tbl") + } + } + test("dictionary arithmetic") { // TODO: test ANSI mode withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") {