Commit 7a689e9
fix(spark): remove internal functions MakeDecimal and UnscaledValue (#386)
These two functions are inserted by the Catalyst optimizer for queries that involve aggregation (sum and average) of decimal values. Approximately 50% of the TPC-DS tests rely on these internal functions, which makes the resulting plans non-interchangeable with other query processors. This commit reverses that particular optimisation before conversion to Substrait, and removes MakeDecimal and UnscaledValue from the `spark.yml` file.
1 parent 134c224 commit 7a689e9
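
For context, here is a minimal spark-shell sketch (illustrative, not part of the commit) of the Catalyst rewrite being undone. The DecimalAggregates optimizer rule turns sum/avg over small-precision decimals into unscaled-long arithmetic; the exact plan text and precision/scale values vary by Spark version:

    // Inspect the optimized plan for a decimal aggregation:
    val df = spark.sql("select sum(num), avg(num) from (values (2.5), (-0.5)) as t(num)")
    println(df.queryExecution.optimizedPlan)
    // Approximate shape of the output, showing the two internal functions:
    // Aggregate [MakeDecimal(sum(UnscaledValue(num)), 12, 1) AS sum(num),
    //            cast((avg(UnscaledValue(num)) / 10.0) as decimal(6,5)) AS avg(num)]
    // +- LocalRelation [num]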

File tree

7 files changed: +51 -88 lines

spark/src/main/resources/spark.yml

Lines changed: 0 additions & 19 deletions
@@ -15,25 +15,6 @@
 %YAML 1.2
 ---
 scalar_functions:
-  -
-    name: unscaled
-    description: >-
-      Return the unscaled Long value of a Decimal, assuming it fits in a Long.
-      Note: this expression is internal and created only by the optimizer,
-      we don't need to do type check for it.
-    impls:
-      - args:
-          - value: DECIMAL<P,S>
-        return: i64
-  -
-    name: make_decimal
-    description: >-
-      Return the Decimal value of an unscaled Long.
-      Note: this expression is internal and created only by the optimizer,
-    impls:
-      - args:
-          - value: i64
-        return: DECIMAL<P,S>
   - name: add
     description: >-
       Adds days to a date

spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala

Lines changed: 1 addition & 5 deletions
@@ -148,11 +148,7 @@ class FunctionMappings {
     ss[Year]("extract"),
     ss[Quarter]("extract"),
     ss[Month]("extract"),
-    ss[DayOfMonth]("extract"),
-
-    // internal
-    s[MakeDecimal]("make_decimal"),
-    s[UnscaledValue]("unscaled")
+    ss[DayOfMonth]("extract")
   )

   val AGGREGATE_SIGS: Seq[Sig] = Seq(

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala

Lines changed: 17 additions & 29 deletions
@@ -235,34 +235,22 @@ class ToSparkExpression(
       arg.accept(expr.declaration(), i, this)
     }.toList

-    expr.declaration.name match {
-      case "make_decimal" if expr.declaration.uri == SparkExtension.uri =>
-        expr.outputType match {
-          // Need special case handing of this internal function.
-          // Because the precision and scale arguments are extracted from the output type,
-          // we can't use the generic scalar function conversion mechanism here.
-          case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
-          case _ =>
-            throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
-        }
-      case _ =>
-        scalarFunctionConverter
-          .getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
-          .getOrElse({
-            val msg = String.format(
-              "Unable to convert scalar function %s(%s).",
-              expr.declaration.name,
-              expr.arguments.asScala
-                .map {
-                  case ea: exp.EnumArg => ea.value.toString
-                  case e: SExpression => e.getType.accept(new StringTypeVisitor)
-                  case t: Type => t.accept(new StringTypeVisitor)
-                  case a => throw new IllegalStateException("Unexpected value: " + a)
-                }
-                .mkString(", ")
-            )
-            throw new IllegalArgumentException(msg)
-          })
-    }
+    scalarFunctionConverter
+      .getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
+      .getOrElse({
+        val msg = String.format(
+          "Unable to convert scalar function %s(%s).",
+          expr.declaration.name,
+          expr.arguments.asScala
+            .map {
+              case ea: exp.EnumArg => ea.value.toString
+              case e: SExpression => e.getType.accept(new StringTypeVisitor)
+              case t: Type => t.accept(new StringTypeVisitor)
+              case a => throw new IllegalStateException("Unexpected value: " + a)
+            }
+            .mkString(", ")
+        )
+        throw new IllegalArgumentException(msg)
+      })
   }
 }

spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ import io.substrait.spark.{HasOutputStack, ToSubstraitType}

 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.types.LongType
 import org.apache.spark.substrait.SparkTypeUtil

 import io.substrait.expression.{EnumArg, Expression => SExpression, ExpressionCreator, FieldReference, ImmutableEnumArg, ImmutableExpression}
@@ -172,6 +173,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
         translateUp(child)
           .map(ExpressionCreator
             .cast(ToSubstraitType.apply(dataType, c.nullable), _, FailureBehavior.THROW_EXCEPTION))
+      case UnscaledValue(value) => translateUp(Cast(value, LongType))
       case c @ CheckOverflow(child, dataType, _) =>
         // CheckOverflow similar with cast
         translateUp(child)

spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala

Lines changed: 26 additions & 4 deletions
@@ -22,7 +22,7 @@ import io.substrait.spark.expression._

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
@@ -133,7 +133,30 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
    */
   override def visitAggregate(agg: Aggregate): relation.Rel = {
     val input = visit(agg.child)
-    val actualResultExprs = agg.aggregateExpressions
+    val actualResultExprs = agg.aggregateExpressions.map {
+      // eliminate the internal MakeDecimal and UnscaledValue functions by undoing the spark optimisation:
+      // https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L2223
+      case Alias(expr, name) =>
+        Alias(
+          expr.transform {
+            case MakeDecimal(
+                  ae @ AggregateExpression(Sum(UnscaledValue(e), _), _, _, _, _),
+                  _,
+                  _,
+                  _) =>
+              ae.copy(aggregateFunction = Sum(e))
+            case Cast(
+                  Divide(ae @ AggregateExpression(Average(UnscaledValue(e), _), _, _, _, _), _, _),
+                  _,
+                  _,
+                  _) =>
+              ae.copy(aggregateFunction = Average(e))
+            case e => e
+          },
+          name
+        )()
+      case e => e
+    }
     val actualGroupExprs = agg.groupingExpressions

     val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
@@ -198,8 +221,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   override def visitWindow(window: Window): relation.Rel = {
     val windowExpressions = window.windowExpressions.map {
       case w: WindowExpression => fromWindowCall(w, window.child.output)
-      case a: Alias if a.child.isInstanceOf[WindowExpression] =>
-        fromWindowCall(a.child.asInstanceOf[WindowExpression], window.child.output)
+      case Alias(w: WindowExpression, _) => fromWindowCall(w, window.child.output)
       case other =>
         throw new UnsupportedOperationException(s"Unsupported window expression: $other")
     }.asJava
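
For reference, the two patterns matched in visitAggregate above correspond to the shapes the DecimalAggregates rule produces, sketched below with p and s standing for the input decimal's precision and scale (the constants follow recent Spark sources and may differ across versions):

    // Sum over decimal(p, s), when p + 10 digits fit in a long:
    //   MakeDecimal(sum(UnscaledValue(e)), p + 10, s)                =>  sum(e)
    // Average over decimal(p, s):
    //   cast(avg(UnscaledValue(e)) / 10^s as decimal(p + 4, s + 4))  =>  avg(e)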

spark/src/test/scala/io/substrait/spark/NumericSuite.scala

Lines changed: 5 additions & 0 deletions
@@ -42,4 +42,9 @@ class NumericSuite extends SparkFunSuite with SharedSparkSession with SubstraitP
       "select floor(num), ceil(num), round(num, 0), round(num, 1) from (values (0.5), (-0.5)) as table(num)"
     )
   }
+
+  test("decimal aggregation") {
+    assertSqlSubstraitRelRoundTrip(
+      "select sum(num), avg(num) from (values (2.5), (-0.5)) as table(num)")
+  }
 }
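
To run just this suite locally, an invocation like the following should work, assuming the repository's :spark Gradle module layout (the module name is an assumption, not taken from this commit):

    ./gradlew :spark:test --tests "io.substrait.spark.NumericSuite"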

spark/src/test/scala/io/substrait/spark/expression/YamlTest.scala

Lines changed: 0 additions & 31 deletions
This file was deleted.
