Commit 2ece486

feat(spark): add some date functions (#373)

The date/time functions in Spark don't map directly to their Substrait equivalents. For example:

- `date ± interval-days` is handled by the `DateAdd` and `DateSub` functions in Spark, but by a variant of the arithmetic `add` function in Substrait.
- The date/time component extraction functions are all distinct functions in Spark, but a single `extract` function in Substrait, with an `enum` argument that specifies which component to extract.

Neither scenario could be handled by the existing function-mapping capabilities in the `spark` module. This commit extends that capability so it can now handle both in (I hope) a generic way. I've added a few variants of the `extract` function; more can follow. This gives us a 100% pass rate for all the TPC-DS queries. The README is updated accordingly.
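
To make the mismatch concrete, here is a minimal sketch using Spark's catalyst expression classes. The Substrait side is shown in comments; `add:date_i32` is the compound function key this commit maps `DateAdd` onto, and the printed SQL forms are approximate:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, DateAdd, Literal, Year}
import org.apache.spark.sql.types.DateType

object DateMappingSketch extends App {
  val d = AttributeReference("d", DateType)()

  // Spark: two unrelated expression classes...
  val plusTen = DateAdd(d, Literal(10)) // d + 10 days
  val yearOf = Year(d) // year component of d

  // ...Substrait: a variant of the arithmetic `add` function and a single
  // generic extractor whose enum argument names the component:
  //   DateAdd(d, 10) -> add(d: date, n: i32): date  (key "add:date_i32")
  //   Year(d)        -> extract(YEAR, d)
  println(plusTen.sql) // something like: date_add(`d`, 10)
  println(yearOf.sql) // something like: year(`d`)
}
```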
1 parent b67599e commit 2ece486

File tree: 13 files changed, +218 -73 lines changed

core/src/main/java/io/substrait/expression/EnumArg.java

Lines changed: 4 additions & 0 deletions

```diff
@@ -26,5 +26,9 @@ static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) {
     return ImmutableEnumArg.builder().value(Optional.of(option)).build();
   }
 
+  static EnumArg of(String value) {
+    return ImmutableEnumArg.builder().value(Optional.of(value)).build();
+  }
+
   EnumArg UNSPECIFIED_ENUM_ARG = ImmutableEnumArg.builder().value(Optional.empty()).build();
 }
```
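
For illustration, a sketch of the new overload as called from Scala; `value` is the `Optional`-returning accessor that `ToSparkExpression` also uses later in this diff:

```scala
import io.substrait.expression.EnumArg

// Build a Substrait enum argument directly from its string value.
val component: EnumArg = EnumArg.of("YEAR")
assert(component.value.get == "YEAR")

// The pre-existing constant still models the "no value" case.
assert(!EnumArg.UNSPECIFIED_ENUM_ARG.value.isPresent)
```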

readme.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 Substrait Java is a project that makes it easier to build [Substrait](https://substrait.io/) plans through Java. The project has two main parts:
 1) **Core** is the module that supports building Substrait plans directly through Java. This is much easier than manipulating the Substrait protobuf directly. It has no direct support for going from SQL to Substrait (that's covered by the second part)
 2) **Isthmus** is the module that allows going from SQL to a Substrait plan. Both Java APIs and a top level script for conversion are present. Not all SQL is supported yet by this module, but a lot is. For example, all of the TPC-H queries and all but a few of the TPC-DS queries are translatable.
-3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations are supported, including those generated from all of the TPC-H queries, but there are currently some gaps in support that prevent all of the TPC-DS queries from being translatable.
+3) **Spark** is the module that provides an API for translating a Substrait plan to and from a Spark query plan. The most commonly used logical relations and functions are supported, including those generated from all of the TPC-H and TPC-DS queries.
 
 ## Building
 After you've cloned the project through git, Substrait Java is built with a tool called [Gradle](https://gradle.org/). To build, execute the following:
```

spark/src/main/resources/spark.yml

Lines changed: 10 additions & 7 deletions

```diff
@@ -15,13 +15,6 @@
 %YAML 1.2
 ---
 scalar_functions:
-  -
-    name: year
-    description: Returns the year component of the date/timestamp
-    impls:
-      - args:
-          - value: date
-        return: i32
   -
     name: unscaled
     description: >-
@@ -41,6 +34,16 @@ scalar_functions:
       - args:
           - value: i64
         return: DECIMAL<P,S>
+  - name: add
+    description: >-
+      Adds days to a date
+    impls:
+      - args:
+          - name: start_date
+            value: date
+          - name: days
+            value: i32
+        return: date
   - name: shift_right
     description: >-
       Bitwise (signed) shift right.
```
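
The `ss[DateAdd]("add:date_i32")` mapping added below depends on the compound key derived from this declaration: the function name joined to its argument type strings. A sketch of that convention, matching how `FunctionFinder` builds its lookup keys:

```scala
// Compound function key: name + ":" + argument types joined by "_".
val name = "add"
val argTypes = Seq("date", "i32") // start_date: date, days: i32
val key = s"$name:${argTypes.mkString("_")}"
assert(key == "add:date_i32") // the signature used by ss[DateAdd] below
```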
spark/src/main/scala/io/substrait/spark/expression/Enum.scala (new file)

Lines changed: 39 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.substrait.spark.expression

import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Unevaluable}
import org.apache.spark.sql.types.{DataType, NullType}

/**
 * For internal use only. This represents the equivalent of a Substrait enum parameter type for use
 * during conversion. It must not become part of a final Spark logical plan.
 *
 * @param value
 *   The enum string value.
 */
case class Enum(value: String) extends LeafExpression with Unevaluable {
  override def nullable: Boolean = false

  override def dataType: DataType = NullType

  override def equals(that: Any): Boolean = that match {
    case Enum(other) => other == value
    case _ => false
  }
}
```
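
A short sketch of how the placeholder behaves: equality is on the enum string, the data type is a fixed `NullType`, and being `Unevaluable` it can never survive into an executable plan:

```scala
import org.apache.spark.sql.types.NullType

val component = Enum("YEAR")
assert(component == Enum("YEAR")) // structural equality on the string value
assert(component != Enum("MONTH"))
assert(!component.nullable && component.dataType == NullType)
// component.eval() would throw: Unevaluable expressions must be eliminated
// before execution, which is exactly the intent of this marker class.
```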

spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala

Lines changed: 36 additions & 19 deletions

```diff
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.DataType
 
 import com.google.common.collect.{ArrayListMultimap, Multimap}
 import io.substrait.`type`.Type
-import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg}
+import io.substrait.expression.{EnumArg, Expression => SExpression, ExpressionCreator, FunctionArg}
 import io.substrait.expression.Expression.FailureBehavior
 import io.substrait.extension.SimpleExtension
 import io.substrait.function.{ParameterizedType, ToTypeString}
@@ -93,14 +93,28 @@ abstract class FunctionConverter[F <: SimpleExtension.Function, T](functions: Se
     (matcherMap, keyMap)
   }
 
-  def getSparkExpressionFromSubstraitFunc(key: String, outputType: Type): Option[Sig] = {
-    val sigs = substraitFuncKeyToSig.get(key)
-    sigs.size() match {
-      case 0 => None
-      case 1 => Some(sigs.iterator().next())
-      case _ => None
+  def getSparkExpressionFromSubstraitFunc(
+      key: String,
+      args: Seq[Expression]): Option[Expression] = {
+    val candidates = substraitFuncKeyToSig.get(key).asScala.toList
+    val sigs = if (candidates.length > 1) {
+      // attempt to disambiguate with the key (if it's been set)
+      val specific = candidates.filter {
+        case SpecialSig(_, _, Some(sig), _) if sig == key => true
+        case _ => false
+      }
+      if (specific.nonEmpty) {
+        specific
+      } else {
+        // no matching signature, so select the generic one(s)
+        candidates
+      }
+    } else {
+      candidates
     }
+    sigs.headOption.map(sig => sig.makeCall(args))
   }
+
   private def createFinder(name: String, functions: Seq[F]): FunctionFinder[F, T] = {
     new FunctionFinder[F, T](
       name,
@@ -237,10 +251,14 @@
     val singularInputType: Option[SingularArgumentMatcher[F]],
     val parent: FunctionConverter[F, T]) {
 
-  def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = {
-    val opTypes = operands.map(_.getType)
+  def attemptMatch(expression: Expression, operands: Seq[FunctionArg]): Option[T] = {
     val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable)
-    val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE))
+
+    val opTypesStr = operands.map {
+      case e: SExpression => e.getType.accept(ToTypeString.INSTANCE)
+      case t: Type => t.accept(ToTypeString.INSTANCE)
+      case _: EnumArg => "req"
+    }
 
     val possibleKeys =
       Util.crossProduct(opTypesStr.map(s => Seq(s))).map(list => list.mkString("_"))
@@ -251,11 +269,11 @@
 
     if (operands.isEmpty) {
       val variant = directMap(name + ":")
-      variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
+      // TODO validate the output type
       Option(parent.generateBinding(expression, variant, operands, outputType))
     } else if (directMatchKey.isDefined) {
       val variant = directMap(directMatchKey.get)
-      variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
+      // TODO validate the output type
       val funcArgs: Seq[FunctionArg] = operands
       Option(parent.generateBinding(expression, variant, funcArgs, outputType))
     } else if (singularInputType.isDefined) {
@@ -277,9 +295,7 @@
         .map(
           declaration => {
             val coercedArgs = coerceArguments(operands, leastRestrictiveSubstraitT)
-            declaration.validateOutputType(
-              JavaConverters.bufferAsJavaList(coercedArgs.toBuffer),
-              outputType)
+            // TODO validate the output type
             val funcArgs: Seq[FunctionArg] = coercedArgs
             parent.generateBinding(expression, declaration, funcArgs, outputType)
           })
@@ -293,14 +309,15 @@
    * Coerced types according to an expected output type. Coercion is only done for type mismatches,
    * not for nullability or parameter mismatches.
    */
-  private def coerceArguments(arguments: Seq[SExpression], t: Type): Seq[SExpression] = {
-    arguments.map(
-      a => {
+  private def coerceArguments(arguments: Seq[FunctionArg], t: Type): Seq[FunctionArg] = {
+    arguments.map {
+      case a: SExpression =>
         if (FunctionFinder.isMatch(t, a.getType)) {
          a
        } else {
          ExpressionCreator.cast(t, a, FailureBehavior.THROW_EXCEPTION)
        }
-      })
+      case other => other
+    }
   }
 }
```
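
Since an `EnumArg` operand contributes the fixed token "req" (the marker Substrait uses for a required enum parameter in a compound signature), the key formed by `attemptMatch` for an extraction looks like this sketch:

```scala
// Year(d), carried to Substrait as Seq(EnumArg("YEAR"), d) with d a date
// column, yields these operand type strings and this lookup key:
val opTypesStr = Seq("req", "date") // EnumArg -> "req", date expression -> "date"
val key = "extract:" + opTypesStr.mkString("_")
assert(key == "extract:req_date") // matched against the extension's variants
```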

spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala

Lines changed: 67 additions & 10 deletions

```diff
@@ -19,19 +19,75 @@ package io.substrait.spark.expression
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType}
+
+import io.substrait.utils.Util
 
 import scala.reflect.ClassTag
 
-case class Sig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression) {
-  def makeCall(args: Seq[Expression]): Expression =
+trait Sig {
+  def name: String
+  def expClass: Class[_]
+  def makeCall(args: Seq[Expression]): Expression
+}
+
+case class GenericSig(expClass: Class[_], name: String, builder: Seq[Expression] => Expression)
+  extends Sig {
+  override def makeCall(args: Seq[Expression]): Expression = {
     builder(args)
+  }
+}
+
+case class SpecialSig(
+    expClass: Class[_],
+    name: String,
+    key: Option[String],
+    builder: Seq[Expression] => Expression)
+  extends Sig {
+  override def makeCall(args: Seq[Expression]): Expression = {
+    builder(args)
+  }
+}
+
+object DateFunction {
+  def unapply(e: Expression): Option[Seq[Expression]] = e match {
+    case DateAdd(startDate, days) => Some(Seq(startDate, days))
+    // The following map to the Substrait `extract` function.
+    case Year(date) => Some(Seq(Enum("YEAR"), date))
+    case Quarter(date) => Some(Seq(Enum("QUARTER"), Enum("ONE"), date))
+    case Month(date) => Some(Seq(Enum("MONTH"), Enum("ONE"), date))
+    case DayOfMonth(date) => Some(Seq(Enum("DAY"), Enum("ONE"), date))
+    case _ => None
+  }
+
+  def unapply(name_args: (String, Seq[Expression])): Option[Expression] = name_args match {
+    case ("add:date_i32", Seq(startDate, days)) => Some(DateAdd(startDate, days))
+    case ("extract", Seq(Enum("YEAR"), date)) => Some(Year(date))
+    case ("extract", Seq(Enum("QUARTER"), Enum("ONE"), date)) => Some(Quarter(date))
+    case ("extract", Seq(Enum("MONTH"), Enum("ONE"), date)) => Some(Month(date))
+    case ("extract", Seq(Enum("DAY"), Enum("ONE"), date)) => Some(DayOfMonth(date))
+    case _ => None
+  }
 }
 
 class FunctionMappings {
 
-  private def s[T <: Expression: ClassTag](name: String): Sig = {
+  private def s[T <: Expression: ClassTag](name: String): GenericSig = {
     val builder = FunctionRegistryBase.build[T](name, None)._2
-    Sig(scala.reflect.classTag[T].runtimeClass, name, builder)
+    GenericSig(scala.reflect.classTag[T].runtimeClass, name, builder)
+  }
+
+  private def ss[T <: Expression: ClassTag](signature: String): SpecialSig = {
+    val (name, key) = if (signature.contains(":")) {
+      (signature.split(':').head, Some(signature))
+    } else {
+      (signature, None)
+    }
+    val builder = (args: Seq[Expression]) =>
+      (signature, args) match {
+        case DateFunction(expr) => expr
+      }
+    SpecialSig(scala.reflect.classTag[T].runtimeClass, name, key, builder)
   }
 
   val SCALAR_SIGS: Seq[Sig] = Seq(
@@ -82,12 +138,18 @@ class FunctionMappings {
     s[Lower]("lower"),
     s[Concat]("concat"),
     s[Coalesce]("coalesce"),
-    s[Year]("year"),
     s[ShiftRight]("shift_right"),
     s[BitwiseAnd]("bitwise_and"),
     s[BitwiseOr]("bitwise_or"),
     s[BitwiseXor]("bitwise_xor"),
 
+    // date/time functions require special handling
+    ss[DateAdd]("add:date_i32"),
+    ss[Year]("extract"),
+    ss[Quarter]("extract"),
+    ss[Month]("extract"),
+    ss[DayOfMonth]("extract"),
+
     // internal
     s[MakeDecimal]("make_decimal"),
     s[UnscaledValue]("unscaled")
@@ -115,11 +177,6 @@ class FunctionMappings {
     s[Lag]("lag"),
     s[NthValue]("nth_value")
   )
-
-  lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
-  lazy val aggregate_functions_map: Map[Class[_], Sig] =
-    AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
-  lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
 }
 
 object FunctionMappings extends FunctionMappings
```
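
The two `unapply` overloads in `DateFunction` give a round trip between a Spark expression and the `(name, args)` shape used on the Substrait side, as this sketch shows:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Year}
import org.apache.spark.sql.types.DateType

val d = AttributeReference("d", DateType)()

// Spark -> Substrait: flatten Year(d) into an enum-tagged argument list.
Year(d) match {
  case DateFunction(args) => assert(args == Seq(Enum("YEAR"), d))
}

// Substrait -> Spark: rebuild the Spark expression from (name, args).
("extract", Seq[Expression](Enum("YEAR"), d)) match {
  case DateFunction(expr) => assert(expr == Year(d))
}
```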

spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -40,7 +40,7 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar
       .build()
   }
 
-  def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = {
+  def convert(expression: Expression, operands: Seq[FunctionArg]): Option[SExpression] = {
     Option(signatures.get(expression.getClass))
       .flatMap(m => m.attemptMatch(expression, operands))
   }
```

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala

Lines changed: 17 additions & 9 deletions

```diff
@@ -22,13 +22,14 @@ import io.substrait.spark.logical.ToLogicalPlan
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, Decimal}
+import org.apache.spark.sql.types.Decimal
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
 
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
-import io.substrait.expression.{Expression => SExpression}
+import io.substrait.expression.{EnumArg, Expression => SExpression}
+import io.substrait.extension.SimpleExtension
 import io.substrait.util.DecimalUtil
 import io.substrait.utils.Util
@@ -162,10 +163,11 @@
   override def visit(expr: SExpression.Cast): Expression = {
     val childExp = expr.input().accept(this)
     val tt = ToSparkType.convert(expr.getType)
-    val tz = childExp.dataType match {
-      case DateType => Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
-      case _ => None
-    }
+    val tz =
+      if (Cast.needsTimeZone(childExp.dataType, tt))
+        Some(SQLConf.get.getConf(SQLConf.SESSION_LOCAL_TIMEZONE))
+      else
+        None
     Cast(childExp, tt, tz)
   }
 
@@ -219,12 +221,19 @@
     }
   }
 
+  override def visitEnumArg(
+      fnDef: SimpleExtension.Function,
+      argIdx: Int,
+      e: EnumArg): Expression = {
+    Enum(e.value.orElse(""))
+  }
+
   override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
     val eArgs = expr.arguments().asScala
     val args = eArgs.zipWithIndex.map {
       case (arg, i) =>
         arg.accept(expr.declaration(), i, this)
-    }
+    }.toList
 
     expr.declaration.name match {
       case "make_decimal" if expr.declaration.uri == SparkExtension.uri =>
@@ -238,8 +247,7 @@
           }
       case _ =>
         scalarFunctionConverter
-          .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
-          .flatMap(sig => Option(sig.makeCall(args)))
+          .getSparkExpressionFromSubstraitFunc(expr.declaration.key, args)
           .getOrElse({
             val msg = String.format(
               "Unable to convert scalar function %s(%s).",
```
