Commit ba3c82c

fix: Refactor arithmetic serde and fix correctness issues with EvalMode::TRY (apache#2018)
1 parent b256458 commit ba3c82c

File tree

5 files changed: +376 -237 lines changed

native/core/src/execution/planner.rs

Lines changed: 72 additions & 39 deletions

@@ -229,45 +229,78 @@ impl PhysicalPlanner {
         input_schema: SchemaRef,
     ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
         match spark_expr.expr_struct.as_ref().unwrap() {
-            ExprStruct::Add(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Plus,
-                input_schema,
-            ),
-            ExprStruct::Subtract(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Minus,
-                input_schema,
-            ),
-            ExprStruct::Multiply(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Multiply,
-                input_schema,
-            ),
-            ExprStruct::Divide(expr) => self.create_binary_expr(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Divide,
-                input_schema,
-            ),
-            ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options(
-                expr.left.as_ref().unwrap(),
-                expr.right.as_ref().unwrap(),
-                expr.return_type.as_ref(),
-                DataFusionOperator::Divide,
-                input_schema,
-                BinaryExprOptions {
-                    is_integral_div: true,
-                },
-            ),
+            ExprStruct::Add(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/536
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Plus,
+                    input_schema,
+                )
+            }
+            ExprStruct::Subtract(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/535
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Minus,
+                    input_schema,
+                )
+            }
+            ExprStruct::Multiply(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/534
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Multiply,
+                    input_schema,
+                )
+            }
+            ExprStruct::Divide(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/533
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Divide,
+                    input_schema,
+                )
+            }
+            ExprStruct::IntegralDivide(expr) => {
+                // TODO respect eval mode
+                // https://github.com/apache/datafusion-comet/issues/2021
+                // https://github.com/apache/datafusion-comet/issues/533
+                let _eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                self.create_binary_expr_with_options(
+                    expr.left.as_ref().unwrap(),
+                    expr.right.as_ref().unwrap(),
+                    expr.return_type.as_ref(),
+                    DataFusionOperator::Divide,
+                    input_schema,
+                    BinaryExprOptions {
+                        is_integral_div: true,
+                    },
+                )
+            }
             ExprStruct::Remainder(expr) => {
+                let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
+                // TODO add support for EvalMode::TRY
+                // https://github.com/apache/datafusion-comet/issues/2021
                 let left =
                     self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
                 let right =
@@ -278,7 +311,7 @@ impl PhysicalPlanner {
                     right,
                     expr.return_type.as_ref().map(to_arrow_datatype).unwrap(),
                     input_schema,
-                    expr.fail_on_error,
+                    eval_mode == EvalMode::Ansi,
                     &self.session_ctx.state(),
                 );
                 result.map_err(|e| GeneralError(e.to_string()))

native/proto/src/proto/expr.proto

Lines changed: 7 additions & 7 deletions

@@ -220,19 +220,19 @@ message Literal {
   bool is_null = 12;
 }

-message MathExpr {
-  Expr left = 1;
-  Expr right = 2;
-  bool fail_on_error = 3;
-  DataType return_type = 4;
-}
-
 enum EvalMode {
   LEGACY = 0;
   TRY = 1;
   ANSI = 2;
 }

+message MathExpr {
+  Expr left = 1;
+  Expr right = 2;
+  DataType return_type = 4;
+  EvalMode eval_mode = 5;
+}
+
 message Cast {
   Expr child = 1;
   DataType datatype = 2;
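
The proto change drops MathExpr's boolean fail_on_error flag and replaces it with the shared EvalMode enum, so the Scala serde side can send the full LEGACY/TRY/ANSI mode instead of collapsing it to a flag. Below is a minimal sketch, not the commit's own code, of populating the new field from inside QueryPlanSerde (where ExprOuterClass and serializeDataType are already in scope); toProtoEvalMode is a hypothetical helper, and setEvalMode is simply the setter protoc generates for field 5.

// Sketch only: map Spark's EvalMode onto the generated ExprOuterClass.EvalMode
// and set it on the MathExpr builder in place of the removed fail_on_error flag.
def toProtoEvalMode(mode: EvalMode.Value): ExprOuterClass.EvalMode = mode match {
  case EvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
  case EvalMode.TRY => ExprOuterClass.EvalMode.TRY
  case EvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
}

def buildMathExpr(
    left: ExprOuterClass.Expr,
    right: ExprOuterClass.Expr,
    returnType: Option[ExprOuterClass.DataType],
    mode: EvalMode.Value): ExprOuterClass.MathExpr = {
  val builder = ExprOuterClass.MathExpr
    .newBuilder()
    .setLeft(left)
    .setRight(right)
    .setEvalMode(toProtoEvalMode(mode)) // previously: setFailOnError(mode == EvalMode.ANSI)
  returnType.foreach(t => builder.setReturnType(t))
  builder.build()
}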

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 6 additions & 191 deletions

@@ -23,7 +23,6 @@ import java.util.Locale

 import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
-import scala.math.min

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
@@ -67,6 +66,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
    * Mapping of Spark expression class to Comet expression handler.
    */
  private val exprSerdeMap: Map[Class[_], CometExpressionSerde] = Map(
+    classOf[Add] -> CometAdd,
+    classOf[Subtract] -> CometSubtract,
+    classOf[Multiply] -> CometMultiply,
+    classOf[Divide] -> CometDivide,
+    classOf[IntegralDivide] -> CometIntegralDivide,
+    classOf[Remainder] -> CometRemainder,
     classOf[ArrayAppend] -> CometArrayAppend,
     classOf[ArrayContains] -> CometArrayContains,
     classOf[ArrayDistinct] -> CometArrayDistinct,
@@ -630,141 +635,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case c @ Cast(child, dt, timeZoneId, _) =>
         handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))

-      case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          add.dataType,
-          add.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setAdd(mathExpr))
-
-      case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
-        withInfo(add, s"Unsupported datatype ${left.dataType}")
-        None
-
-      case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          sub.dataType,
-          sub.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setSubtract(mathExpr))
-
-      case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
-        withInfo(sub, s"Unsupported datatype ${left.dataType}")
-        None
-
-      case mul @ Multiply(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          mul.dataType,
-          mul.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setMultiply(mathExpr))
-
-      case mul @ Multiply(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(mul, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case div @ Divide(left, right, _) if supportedDataType(left.dataType) =>
-        // Datafusion now throws an exception for dividing by zero
-        // See https://github.com/apache/arrow-datafusion/pull/6792
-        // For now, use NullIf to swap zeros with nulls.
-        val rightExpr = nullIfWhenPrimitive(right)
-
-        createMathExpression(
-          expr,
-          left,
-          rightExpr,
-          inputs,
-          binding,
-          div.dataType,
-          div.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setDivide(mathExpr))
-
-      case div @ Divide(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(div, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) =>
-        val rightExpr = nullIfWhenPrimitive(right)
-
-        val dataType = (left.dataType, right.dataType) match {
-          case (l: DecimalType, r: DecimalType) =>
-            // copy from IntegralDivide.resultDecimalType
-            val intDig = l.precision - l.scale + r.scale
-            DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
-          case _ => left.dataType
-        }
-
-        val divideExpr = createMathExpression(
-          expr,
-          left,
-          rightExpr,
-          inputs,
-          binding,
-          dataType,
-          div.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setIntegralDivide(mathExpr))
-
-        if (divideExpr.isDefined) {
-          val childExpr = if (dataType.isInstanceOf[DecimalType]) {
-            // check overflow for decimal type
-            val builder = ExprOuterClass.CheckOverflow.newBuilder()
-            builder.setChild(divideExpr.get)
-            builder.setFailOnError(div.evalMode == EvalMode.ANSI)
-            builder.setDatatype(serializeDataType(dataType).get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCheckOverflow(builder)
-                .build())
-          } else {
-            divideExpr
-          }
-
-          // cast result to long
-          castToProto(expr, None, LongType, childExpr.get, CometEvalMode.LEGACY)
-        } else {
-          None
-        }
-
-      case div @ IntegralDivide(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(div, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
-      case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) =>
-        createMathExpression(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          rem.dataType,
-          rem.evalMode == EvalMode.ANSI,
-          (builder, mathExpr) => builder.setRemainder(mathExpr))
-
-      case rem @ Remainder(left, _, _) =>
-        if (!supportedDataType(left.dataType)) {
-          withInfo(rem, s"Unsupported datatype ${left.dataType}")
-        }
-        None
-
       case EqualTo(left, right) =>
         createBinaryExpr(
           expr,
@@ -1962,42 +1832,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     }
   }

-  private def createMathExpression(
-      expr: Expression,
-      left: Expression,
-      right: Expression,
-      inputs: Seq[Attribute],
-      binding: Boolean,
-      dataType: DataType,
-      failOnError: Boolean,
-      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
-      : Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(left, inputs, binding)
-    val rightExpr = exprToProtoInternal(right, inputs, binding)
-
-    if (leftExpr.isDefined && rightExpr.isDefined) {
-      // create the generic MathExpr message
-      val builder = ExprOuterClass.MathExpr.newBuilder()
-      builder.setLeft(leftExpr.get)
-      builder.setRight(rightExpr.get)
-      builder.setFailOnError(failOnError)
-      serializeDataType(dataType).foreach { t =>
-        builder.setReturnType(t)
-      }
-      val inner = builder.build()
-      // call the user-supplied function to wrap MathExpr in a top-level Expr
-      // such as Expr.Add or Expr.Divide
-      Some(
-        f(
-          ExprOuterClass.Expr
-            .newBuilder(),
-          inner).build())
-    } else {
-      withInfo(expr, left, right)
-      None
-    }
-  }
-
   def in(
       expr: Expression,
       value: Expression,
@@ -2053,25 +1887,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
   }

-  private def isPrimitive(expression: Expression): Boolean = expression.dataType match {
-    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
-        _: DoubleType | _: TimestampType | _: DateType | _: BooleanType | _: DecimalType =>
-      true
-    case _ => false
-  }
-
-  private def nullIfWhenPrimitive(expression: Expression): Expression =
-    if (isPrimitive(expression)) {
-      val zero = Literal.default(expression.dataType)
-      expression match {
-        case _: Literal if expression != zero => expression
-        case _ =>
-          If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression)
-      }
-    } else {
-      expression
-    }
-
   private def nullIfNegative(expression: Expression): Expression = {
     val zero = Literal.default(expression.dataType)
     If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression)
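
With the arithmetic cases removed from the large pattern match, Add, Subtract, Multiply, Divide, IntegralDivide, and Remainder are now dispatched through the exprSerdeMap entries registered above; the CometAdd, CometSubtract, and related handler objects added by this commit live in a serde source file that is not shown in this excerpt. As a rough sketch only, assuming the CometExpressionSerde trait exposes a convert(expr, inputs, binding): Option[ExprOuterClass.Expr] method and reusing the hypothetical toProtoEvalMode helper sketched after the proto change, one such handler could be shaped roughly like this (the real handlers may differ):

// Illustrative sketch under the stated assumptions, reusing names in scope inside
// QueryPlanSerde (exprToProtoInternal, supportedDataType, withInfo, serializeDataType).
object CometAdd extends CometExpressionSerde {
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val add = expr.asInstanceOf[Add]
    if (!supportedDataType(add.left.dataType)) {
      withInfo(add, s"Unsupported datatype ${add.left.dataType}")
      None
    } else {
      val leftExpr = exprToProtoInternal(add.left, inputs, binding)
      val rightExpr = exprToProtoInternal(add.right, inputs, binding)
      (leftExpr, rightExpr) match {
        case (Some(l), Some(r)) =>
          val mathExpr = ExprOuterClass.MathExpr
            .newBuilder()
            .setLeft(l)
            .setRight(r)
            .setEvalMode(toProtoEvalMode(add.evalMode)) // carries LEGACY/TRY/ANSI, not a flag
          serializeDataType(add.dataType).foreach(t => mathExpr.setReturnType(t))
          Some(ExprOuterClass.Expr.newBuilder().setAdd(mathExpr).build())
        case _ =>
          withInfo(expr, add.left, add.right)
          None
      }
    }
  }
}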
