@@ -23,7 +23,6 @@ import java.util.Locale
23
23
24
24
import scala .collection .JavaConverters ._
25
25
import scala .collection .mutable .ListBuffer
26
- import scala .math .min
27
26
28
27
import org .apache .spark .internal .Logging
29
28
import org .apache .spark .sql .catalyst .expressions ._
@@ -67,6 +66,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
67
66
* Mapping of Spark expression class to Comet expression handler.
68
67
*/
69
68
private val exprSerdeMap : Map [Class [_], CometExpressionSerde ] = Map (
69
+ classOf [Add ] -> CometAdd ,
70
+ classOf [Subtract ] -> CometSubtract ,
71
+ classOf [Multiply ] -> CometMultiply ,
72
+ classOf [Divide ] -> CometDivide ,
73
+ classOf [IntegralDivide ] -> CometIntegralDivide ,
74
+ classOf [Remainder ] -> CometRemainder ,
70
75
classOf [ArrayAppend ] -> CometArrayAppend ,
71
76
classOf [ArrayContains ] -> CometArrayContains ,
72
77
classOf [ArrayDistinct ] -> CometArrayDistinct ,
@@ -630,141 +635,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
630
635
case c @ Cast (child, dt, timeZoneId, _) =>
631
636
handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
632
637
633
- case add @ Add (left, right, _) if supportedDataType(left.dataType) =>
634
- createMathExpression(
635
- expr,
636
- left,
637
- right,
638
- inputs,
639
- binding,
640
- add.dataType,
641
- add.evalMode == EvalMode .ANSI ,
642
- (builder, mathExpr) => builder.setAdd(mathExpr))
643
-
644
- case add @ Add (left, _, _) if ! supportedDataType(left.dataType) =>
645
- withInfo(add, s " Unsupported datatype ${left.dataType}" )
646
- None
647
-
648
- case sub @ Subtract (left, right, _) if supportedDataType(left.dataType) =>
649
- createMathExpression(
650
- expr,
651
- left,
652
- right,
653
- inputs,
654
- binding,
655
- sub.dataType,
656
- sub.evalMode == EvalMode .ANSI ,
657
- (builder, mathExpr) => builder.setSubtract(mathExpr))
658
-
659
- case sub @ Subtract (left, _, _) if ! supportedDataType(left.dataType) =>
660
- withInfo(sub, s " Unsupported datatype ${left.dataType}" )
661
- None
662
-
663
- case mul @ Multiply (left, right, _) if supportedDataType(left.dataType) =>
664
- createMathExpression(
665
- expr,
666
- left,
667
- right,
668
- inputs,
669
- binding,
670
- mul.dataType,
671
- mul.evalMode == EvalMode .ANSI ,
672
- (builder, mathExpr) => builder.setMultiply(mathExpr))
673
-
674
- case mul @ Multiply (left, _, _) =>
675
- if (! supportedDataType(left.dataType)) {
676
- withInfo(mul, s " Unsupported datatype ${left.dataType}" )
677
- }
678
- None
679
-
680
- case div @ Divide (left, right, _) if supportedDataType(left.dataType) =>
681
- // Datafusion now throws an exception for dividing by zero
682
- // See https://github.com/apache/arrow-datafusion/pull/6792
683
- // For now, use NullIf to swap zeros with nulls.
684
- val rightExpr = nullIfWhenPrimitive(right)
685
-
686
- createMathExpression(
687
- expr,
688
- left,
689
- rightExpr,
690
- inputs,
691
- binding,
692
- div.dataType,
693
- div.evalMode == EvalMode .ANSI ,
694
- (builder, mathExpr) => builder.setDivide(mathExpr))
695
-
696
- case div @ Divide (left, _, _) =>
697
- if (! supportedDataType(left.dataType)) {
698
- withInfo(div, s " Unsupported datatype ${left.dataType}" )
699
- }
700
- None
701
-
702
- case div @ IntegralDivide (left, right, _) if supportedDataType(left.dataType) =>
703
- val rightExpr = nullIfWhenPrimitive(right)
704
-
705
- val dataType = (left.dataType, right.dataType) match {
706
- case (l : DecimalType , r : DecimalType ) =>
707
- // copy from IntegralDivide.resultDecimalType
708
- val intDig = l.precision - l.scale + r.scale
709
- DecimalType (min(if (intDig == 0 ) 1 else intDig, DecimalType .MAX_PRECISION ), 0 )
710
- case _ => left.dataType
711
- }
712
-
713
- val divideExpr = createMathExpression(
714
- expr,
715
- left,
716
- rightExpr,
717
- inputs,
718
- binding,
719
- dataType,
720
- div.evalMode == EvalMode .ANSI ,
721
- (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
722
-
723
- if (divideExpr.isDefined) {
724
- val childExpr = if (dataType.isInstanceOf [DecimalType ]) {
725
- // check overflow for decimal type
726
- val builder = ExprOuterClass .CheckOverflow .newBuilder()
727
- builder.setChild(divideExpr.get)
728
- builder.setFailOnError(div.evalMode == EvalMode .ANSI )
729
- builder.setDatatype(serializeDataType(dataType).get)
730
- Some (
731
- ExprOuterClass .Expr
732
- .newBuilder()
733
- .setCheckOverflow(builder)
734
- .build())
735
- } else {
736
- divideExpr
737
- }
738
-
739
- // cast result to long
740
- castToProto(expr, None , LongType , childExpr.get, CometEvalMode .LEGACY )
741
- } else {
742
- None
743
- }
744
-
745
- case div @ IntegralDivide (left, _, _) =>
746
- if (! supportedDataType(left.dataType)) {
747
- withInfo(div, s " Unsupported datatype ${left.dataType}" )
748
- }
749
- None
750
-
751
- case rem @ Remainder (left, right, _) if supportedDataType(left.dataType) =>
752
- createMathExpression(
753
- expr,
754
- left,
755
- right,
756
- inputs,
757
- binding,
758
- rem.dataType,
759
- rem.evalMode == EvalMode .ANSI ,
760
- (builder, mathExpr) => builder.setRemainder(mathExpr))
761
-
762
- case rem @ Remainder (left, _, _) =>
763
- if (! supportedDataType(left.dataType)) {
764
- withInfo(rem, s " Unsupported datatype ${left.dataType}" )
765
- }
766
- None
767
-
768
638
case EqualTo (left, right) =>
769
639
createBinaryExpr(
770
640
expr,
@@ -1962,42 +1832,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
1962
1832
}
1963
1833
}
1964
1834
1965
- private def createMathExpression (
1966
- expr : Expression ,
1967
- left : Expression ,
1968
- right : Expression ,
1969
- inputs : Seq [Attribute ],
1970
- binding : Boolean ,
1971
- dataType : DataType ,
1972
- failOnError : Boolean ,
1973
- f : (ExprOuterClass .Expr .Builder , ExprOuterClass .MathExpr ) => ExprOuterClass .Expr .Builder )
1974
- : Option [ExprOuterClass .Expr ] = {
1975
- val leftExpr = exprToProtoInternal(left, inputs, binding)
1976
- val rightExpr = exprToProtoInternal(right, inputs, binding)
1977
-
1978
- if (leftExpr.isDefined && rightExpr.isDefined) {
1979
- // create the generic MathExpr message
1980
- val builder = ExprOuterClass .MathExpr .newBuilder()
1981
- builder.setLeft(leftExpr.get)
1982
- builder.setRight(rightExpr.get)
1983
- builder.setFailOnError(failOnError)
1984
- serializeDataType(dataType).foreach { t =>
1985
- builder.setReturnType(t)
1986
- }
1987
- val inner = builder.build()
1988
- // call the user-supplied function to wrap MathExpr in a top-level Expr
1989
- // such as Expr.Add or Expr.Divide
1990
- Some (
1991
- f(
1992
- ExprOuterClass .Expr
1993
- .newBuilder(),
1994
- inner).build())
1995
- } else {
1996
- withInfo(expr, left, right)
1997
- None
1998
- }
1999
- }
2000
-
2001
1835
def in (
2002
1836
expr : Expression ,
2003
1837
value : Expression ,
@@ -2053,25 +1887,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
2053
1887
Some (ExprOuterClass .Expr .newBuilder().setScalarFunc(builder).build())
2054
1888
}
2055
1889
2056
- private def isPrimitive (expression : Expression ): Boolean = expression.dataType match {
2057
- case _ : ByteType | _ : ShortType | _ : IntegerType | _ : LongType | _ : FloatType |
2058
- _ : DoubleType | _ : TimestampType | _ : DateType | _ : BooleanType | _ : DecimalType =>
2059
- true
2060
- case _ => false
2061
- }
2062
-
2063
- private def nullIfWhenPrimitive (expression : Expression ): Expression =
2064
- if (isPrimitive(expression)) {
2065
- val zero = Literal .default(expression.dataType)
2066
- expression match {
2067
- case _ : Literal if expression != zero => expression
2068
- case _ =>
2069
- If (EqualTo (expression, zero), Literal .create(null , expression.dataType), expression)
2070
- }
2071
- } else {
2072
- expression
2073
- }
2074
-
2075
1890
private def nullIfNegative (expression : Expression ): Expression = {
2076
1891
val zero = Literal .default(expression.dataType)
2077
1892
If (LessThanOrEqual (expression, zero), Literal .create(null , expression.dataType), expression)
0 commit comments