Skip to content

Commit 134c224

Browse files
authored
feat(isthmus): support more datetime extract variants (#360)
feat(isthmus): support conversion of TimestampTZLiteral feat(isthmus): support conversion of PrecisionTimestampLiteral feat(isthmus): support conversion of PrecisionTimestampTZLiteral
1 parent 0ed821c commit 134c224

File tree

6 files changed

+442
-54
lines changed

6 files changed

+442
-54
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import io.substrait.expression.Expression.IfThen;
99
import io.substrait.expression.Expression.SwitchClause;
1010
import io.substrait.expression.FieldReference;
11+
import io.substrait.expression.FunctionArg;
1112
import io.substrait.expression.ImmutableExpression.Cast;
1213
import io.substrait.expression.ImmutableExpression.SingleOrList;
1314
import io.substrait.expression.ImmutableExpression.Switch;
@@ -640,7 +641,7 @@ public Expression.ScalarFunctionInvocation or(Expression... args) {
640641
}
641642

642643
public Expression.ScalarFunctionInvocation scalarFn(
643-
String namespace, String key, Type outputType, Expression... args) {
644+
String namespace, String key, Type outputType, FunctionArg... args) {
644645
var declaration =
645646
extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
646647
return Expression.ScalarFunctionInvocation.builder()

isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package io.substrait.isthmus.expression;
22

3-
import com.google.common.collect.BiMap;
4-
import com.google.common.collect.HashBiMap;
53
import io.substrait.expression.EnumArg;
64
import io.substrait.extension.DefaultExtensionCatalog;
75
import io.substrait.extension.SimpleExtension;
6+
import io.substrait.extension.SimpleExtension.Argument;
7+
import java.util.HashMap;
8+
import java.util.List;
9+
import java.util.Map;
10+
import java.util.Objects;
811
import java.util.Optional;
912
import java.util.function.Supplier;
1013
import org.apache.calcite.avatica.util.TimeUnitRange;
@@ -25,16 +28,34 @@
2528
*/
2629
public class EnumConverter {
2730

28-
private static final BiMap<Class<? extends Enum>, ArgAnchor> calciteEnumMap = HashBiMap.create();
31+
private static final Map<ArgAnchor, Class<? extends Enum<?>>> calciteEnumMap = new HashMap<>();
2932

3033
static {
34+
// deprecated {@link io.substrait.type.Type.Timestamp}
3135
calciteEnumMap.put(
32-
TimeUnitRange.class,
33-
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", 0));
36+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ts", 0),
37+
TimeUnitRange.class);
38+
// deprecated {@link io.substrait.type.Type.TimestampTZ}
39+
calciteEnumMap.put(
40+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_tstz_str", 0),
41+
TimeUnitRange.class);
42+
43+
calciteEnumMap.put(
44+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_pts", 0),
45+
TimeUnitRange.class);
46+
calciteEnumMap.put(
47+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_ptstz_str", 0),
48+
TimeUnitRange.class);
49+
calciteEnumMap.put(
50+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_date", 0),
51+
TimeUnitRange.class);
52+
calciteEnumMap.put(
53+
argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_time", 0),
54+
TimeUnitRange.class);
3455
}
3556

36-
private static Optional<Enum> constructValue(
37-
Class<? extends Enum> cls, Supplier<Optional<String>> option) {
57+
private static Optional<Enum<?>> constructValue(
58+
Class<? extends Enum<?>> cls, Supplier<Optional<String>> option) {
3859
if (cls.isAssignableFrom(TimeUnitRange.class)) {
3960
return option.get().map(TimeUnitRange::valueOf);
4061
} else {
@@ -44,8 +65,9 @@ private static Optional<Enum> constructValue(
4465

4566
static Optional<RexLiteral> toRex(
4667
RexBuilder rexBuilder, SimpleExtension.Function fnDef, int argIdx, EnumArg e) {
47-
var aAnch = argAnchor(fnDef, argIdx);
48-
var v = Optional.ofNullable(calciteEnumMap.inverse().getOrDefault(aAnch, null));
68+
ArgAnchor aAnch = argAnchor(fnDef, argIdx);
69+
Optional<Class<? extends Enum<?>>> v =
70+
Optional.ofNullable(calciteEnumMap.getOrDefault(aAnch, null));
4971

5072
Supplier<Optional<String>> sOptionVal =
5173
() -> {
@@ -66,11 +88,11 @@ private static Optional<SimpleExtension.EnumArgument> findEnumArg(
6688
return Optional.empty();
6789
} else {
6890

69-
var args = function.args();
91+
List<Argument> args = function.args();
7092
if (args.size() <= enumAnchor.argIdx) {
7193
return Optional.empty();
7294
}
73-
var arg = args.get(enumAnchor.argIdx);
95+
Argument arg = args.get(enumAnchor.argIdx);
7496
if (arg instanceof SimpleExtension.EnumArgument ea) {
7597
return Optional.of(ea);
7698
} else {
@@ -79,17 +101,15 @@ private static Optional<SimpleExtension.EnumArgument> findEnumArg(
79101
}
80102
}
81103

82-
static Optional<EnumArg> fromRex(SimpleExtension.Function function, RexLiteral literal) {
104+
static Optional<EnumArg> fromRex(
105+
SimpleExtension.Function function, RexLiteral literal, int argIdx) {
83106
return switch (literal.getType().getSqlTypeName()) {
84107
case SYMBOL -> {
85108
Object v = literal.getValue();
86109
if (!literal.isNull() && (v instanceof Enum)) {
87-
Enum value = (Enum) v;
88-
Optional<ArgAnchor> enumAnchor =
89-
Optional.ofNullable(calciteEnumMap.getOrDefault(value.getClass(), null));
90-
yield enumAnchor
91-
.flatMap(en -> findEnumArg(function, en))
92-
.map(ea -> EnumArg.of(ea, value.name()));
110+
Enum<?> value = (Enum<?>) v;
111+
ArgAnchor enumAnchor = argAnchor(function, argIdx);
112+
yield findEnumArg(function, enumAnchor).map(ea -> EnumArg.of(ea, value.name()));
93113
} else {
94114
yield Optional.empty();
95115
}
@@ -98,8 +118,8 @@ static Optional<EnumArg> fromRex(SimpleExtension.Function function, RexLiteral l
98118
};
99119
}
100120

101-
static boolean canConvert(Enum value) {
102-
return value != null && calciteEnumMap.containsKey(value.getClass());
121+
static boolean canConvert(Enum<?> value) {
122+
return value != null && calciteEnumMap.containsValue(value.getClass());
103123
}
104124

105125
static boolean isEnumValue(RexNode value) {
@@ -116,6 +136,23 @@ public ArgAnchor(final SimpleExtension.FunctionAnchor fn, final int argIdx) {
116136
this.fn = fn;
117137
this.argIdx = argIdx;
118138
}
139+
140+
@Override
141+
public int hashCode() {
142+
return Objects.hash(fn, argIdx);
143+
}
144+
145+
@Override
146+
public boolean equals(Object obj) {
147+
if (this == obj) {
148+
return true;
149+
}
150+
if (!(obj instanceof ArgAnchor)) {
151+
return false;
152+
}
153+
ArgAnchor other = (ArgAnchor) obj;
154+
return Objects.equals(fn, other.fn) && argIdx == other.argIdx;
155+
}
119156
}
120157

121158
private static ArgAnchor argAnchor(String fnNS, String fnSig, int argIdx) {

isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55
import io.substrait.expression.EnumArg;
66
import io.substrait.expression.Expression;
77
import io.substrait.expression.Expression.FailureBehavior;
8+
import io.substrait.expression.Expression.PrecisionTimestampLiteral;
9+
import io.substrait.expression.Expression.PrecisionTimestampTZLiteral;
810
import io.substrait.expression.Expression.ScalarSubquery;
911
import io.substrait.expression.Expression.SetPredicate;
1012
import io.substrait.expression.Expression.SingleOrList;
1113
import io.substrait.expression.Expression.Switch;
14+
import io.substrait.expression.Expression.TimestampTZLiteral;
1215
import io.substrait.expression.FieldReference;
1316
import io.substrait.expression.FunctionArg;
1417
import io.substrait.expression.WindowBound;
@@ -209,20 +212,65 @@ public RexNode visit(Expression.DateLiteral expr) throws RuntimeException {
209212

210213
@Override
211214
public RexNode visit(Expression.TimestampLiteral expr) throws RuntimeException {
212-
// Expression.TimestampLiteral is microseconds
213-
// Construct a TimeStampString :
214-
// 1. Truncate microseconds to seconds
215-
// 2. Get the fraction seconds in precision of nanoseconds.
216-
// 3. Construct TimeStampString : seconds + fraction_seconds part.
217-
long microSec = expr.value();
218-
long seconds = TimeUnit.MICROSECONDS.toSeconds(microSec);
219-
int fracSecondsInNano =
220-
(int) (TimeUnit.MICROSECONDS.toNanos(microSec) - TimeUnit.SECONDS.toNanos(seconds));
215+
return rexBuilder.makeLiteral(
216+
getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType()));
217+
}
221218

222-
TimestampString tsString =
223-
TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
224-
.withNanos(fracSecondsInNano);
225-
return rexBuilder.makeLiteral(tsString, typeConverter.toCalcite(typeFactory, expr.getType()));
219+
@Override
220+
public RexNode visit(TimestampTZLiteral expr) throws RuntimeException {
221+
return rexBuilder.makeLiteral(
222+
getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType()));
223+
}
224+
225+
@Override
226+
public RexNode visit(PrecisionTimestampLiteral expr) throws RuntimeException {
227+
return rexBuilder.makeLiteral(
228+
getTimestampString(expr.value(), expr.precision()),
229+
typeConverter.toCalcite(typeFactory, expr.getType()));
230+
}
231+
232+
@Override
233+
public RexNode visit(PrecisionTimestampTZLiteral expr) throws RuntimeException {
234+
return rexBuilder.makeLiteral(
235+
getTimestampString(expr.value(), expr.precision()),
236+
typeConverter.toCalcite(typeFactory, expr.getType()));
237+
}
238+
239+
private TimestampString getTimestampString(long microSec) {
240+
return getTimestampString(microSec, 6);
241+
}
242+
243+
private TimestampString getTimestampString(long value, int precision) {
244+
switch (precision) {
245+
case 0:
246+
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(value));
247+
case 3:
248+
{
249+
long seconds = TimeUnit.MILLISECONDS.toSeconds(value);
250+
int fracSecondsInNano =
251+
(int) (TimeUnit.MILLISECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds));
252+
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
253+
.withNanos(fracSecondsInNano);
254+
}
255+
case 6:
256+
{
257+
long seconds = TimeUnit.MICROSECONDS.toSeconds(value);
258+
int fracSecondsInNano =
259+
(int) (TimeUnit.MICROSECONDS.toNanos(value) - TimeUnit.SECONDS.toNanos(seconds));
260+
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
261+
.withNanos(fracSecondsInNano);
262+
}
263+
case 9:
264+
{
265+
long seconds = TimeUnit.NANOSECONDS.toSeconds(value);
266+
int fracSecondsInNano = (int) (value - TimeUnit.SECONDS.toNanos(seconds));
267+
return TimestampString.fromMillisSinceEpoch(TimeUnit.SECONDS.toMillis(seconds))
268+
.withNanos(fracSecondsInNano);
269+
}
270+
default:
271+
throw new UnsupportedOperationException(
272+
String.format("Cannot handle PrecisionTimestamp with precision %d.", precision));
273+
}
226274
}
227275

228276
@Override

isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.Set;
3030
import java.util.function.Function;
3131
import java.util.stream.Collectors;
32+
import java.util.stream.IntStream;
3233
import java.util.stream.Stream;
3334
import org.apache.calcite.rel.type.RelDataType;
3435
import org.apache.calcite.rel.type.RelDataTypeFactory;
@@ -124,7 +125,7 @@ public Optional<SqlOperator> getSqlOperatorFromSubstraitFunc(String key, Type ou
124125
operator ->
125126
resolver.containsKey(operator)
126127
&& resolver.get(operator).types().contains(outputTypeStr))
127-
.collect(java.util.stream.Collectors.toList());
128+
.collect(Collectors.toList());
128129
// only one SqlOperator is possible
129130
if (resolvedOperators.size() == 1) {
130131
return Optional.of(resolvedOperators.get(0));
@@ -331,7 +332,7 @@ private Stream<String> matchKeys(List<RexNode> rexOperands, List<String> opTypes
331332
}
332333
return isOption ? List.of("req", "opt") : List.of(opType);
333334
})
334-
.collect(java.util.stream.Collectors.toList());
335+
.collect(Collectors.toList());
335336

336337
return Utils.crossProduct(argTypeLists)
337338
.map(typList -> typList.stream().collect(Collectors.joining("_")));
@@ -346,42 +347,43 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
346347
* Once a FunctionVariant is resolved we can map the String Literal
347348
* to a EnumArg.
348349
*/
349-
var operands =
350-
call.getOperands().map(topLevelConverter).collect(java.util.stream.Collectors.toList());
351-
var opTypes =
352-
operands.stream().map(Expression::getType).collect(java.util.stream.Collectors.toList());
350+
List<RexNode> operandsList = call.getOperands().collect(Collectors.toList());
351+
List<Expression> operands =
352+
call.getOperands().map(topLevelConverter).collect(Collectors.toList());
353+
List<Type> opTypes = operands.stream().map(Expression::getType).collect(Collectors.toList());
353354

354-
var outputType = typeConverter.toSubstrait(call.getType());
355+
Type outputType = typeConverter.toSubstrait(call.getType());
355356

356357
// try to do a direct match
357-
var typeStrings =
358+
List<String> typeStrings =
358359
opTypes.stream().map(t -> t.accept(ToTypeString.INSTANCE)).collect(Collectors.toList());
359-
var possibleKeys =
360-
matchKeys(call.getOperands().collect(java.util.stream.Collectors.toList()), typeStrings);
360+
Stream<String> possibleKeys =
361+
matchKeys(call.getOperands().collect(Collectors.toList()), typeStrings);
361362

362-
var directMatchKey =
363+
Optional<String> directMatchKey =
363364
possibleKeys
364365
.map(argList -> name + ":" + argList)
365366
.filter(k -> directMap.containsKey(k))
366367
.findFirst();
367368

368369
if (directMatchKey.isPresent()) {
369-
var variant = directMap.get(directMatchKey.get());
370+
F variant = directMap.get(directMatchKey.get());
370371
variant.validateOutputType(operands, outputType);
371-
372372
List<FunctionArg> funcArgs =
373-
Streams.zip(
374-
call.getOperands(),
375-
operands.stream(),
376-
(r, o) -> {
373+
IntStream.range(0, operandsList.size())
374+
.mapToObj(
375+
i -> {
376+
RexNode r = operandsList.get(i);
377+
Expression o = operands.get(i);
377378
if (EnumConverter.isEnumValue(r)) {
378-
return EnumConverter.fromRex(variant, (RexLiteral) r).orElseGet(() -> null);
379+
return EnumConverter.fromRex(variant, (RexLiteral) r, i)
380+
.orElseGet(() -> null);
379381
} else {
380382
return o;
381383
}
382384
})
383-
.collect(java.util.stream.Collectors.toList());
384-
var allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
385+
.collect(Collectors.toList());
386+
boolean allArgsMapped = funcArgs.stream().filter(e -> e == null).findFirst().isEmpty();
385387
if (allArgsMapped) {
386388
return Optional.of(generateBinding(call, variant, funcArgs, outputType));
387389
} else {

0 commit comments

Comments
 (0)