Skip to content

Commit d88ecba

Browse files
cloud-fanyhuang-db
authored andcommitted
[SPARK-52218][SQL] Make current datetime functions evaluable again
### What changes were proposed in this pull request? Current date-time functions should only be evaluated all together in the rule `FinishAnalysis`, so that they have the same consistent value across the entire query plan. To do that we mark current date-time functions as `Unevaluable` a while ago. Unfortunately, there are still some places that have to evaluate expressions earlier, and the expression may contain current date-time functions. apache#50800 is one such example. I think the move to mark these functions as `Unevaluable` is too aggressive. This PR proposes to revert this restriction to avoid potential regressions. I think the better way forward is to introduce query context to pre-evaluate these current datetime values, and the expressions just get values from the query context and makes sure it produces consistent values during the query life cycle. ### Why are the changes needed? to avoid potential regressions. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes apache#50936 from cloud-fan/foldable. Lead-authored-by: Wenchen Fan <wenchen@databricks.com> Co-authored-by: Wenchen Fan <cloud0fan@gmail.com> Signed-off-by: Max Gekk <max.gekk@gmail.com>
1 parent 81dc094 commit d88ecba

File tree

7 files changed

+65
-44
lines changed

7 files changed

+65
-44
lines changed

sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,16 @@ trait SparkDateTimeUtils {
134134
}
135135

136136
/**
137-
* Gets the number of microseconds since midnight using the session time zone.
137+
* Gets the number of microseconds since midnight using the given time zone.
138138
*/
139139
def instantToMicrosOfDay(instant: Instant, timezone: String): Long = {
140-
val zoneId = getZoneId(timezone)
140+
instantToMicrosOfDay(instant, getZoneId(timezone))
141+
}
142+
143+
/**
144+
* Gets the number of microseconds since midnight using the given time zone.
145+
*/
146+
def instantToMicrosOfDay(instant: Instant, zoneId: ZoneId): Long = {
141147
val localDateTime = LocalDateTime.ofInstant(instant, zoneId)
142148
localDateTime.toLocalTime.getLong(MICRO_OF_DAY)
143149
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,16 @@ object ExpressionPatternBitMask {
403403
}
404404
}
405405

406+
406407
/**
407-
* An expression that cannot be evaluated but is guaranteed to be replaced with a foldable value
408-
* by query optimizer (e.g. CurrentDate).
408+
* An expression that cannot be evaluated. These expressions don't live past analysis or
409+
* optimization time (e.g. Star) and should not be evaluated during query planning and
410+
* execution.
409411
*/
410-
trait FoldableUnevaluable extends Expression {
411-
override def foldable: Boolean = true
412+
trait Unevaluable extends Expression {
413+
414+
/** Unevaluable is not foldable because we don't have an eval for it. */
415+
final override def foldable: Boolean = false
412416

413417
override def eval(input: InternalRow = null): Any =
414418
throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
@@ -417,19 +421,6 @@ trait FoldableUnevaluable extends Expression {
417421
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
418422
}
419423

420-
/**
421-
* An expression that cannot be evaluated. These expressions don't live past analysis or
422-
* optimization time (e.g. Star) and should not be evaluated during query planning and
423-
* execution.
424-
*/
425-
trait Unevaluable extends Expression with FoldableUnevaluable {
426-
427-
/** Unevaluable is not foldable by default because we don't have an eval for it.
428-
* Exception are expressions that will be replaced by a literal by Optimizer (e.g. CurrentDate).
429-
* Hence we allow overriding overriding of this field in special cases.
430-
*/
431-
final override def foldable: Boolean = false
432-
}
433424

434425
/**
435426
* An expression that gets replaced at runtime (currently by the optimizer) into a different

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,18 @@ case class CurrentTimeZone()
145145
since = "1.5.0")
146146
// scalastyle:on line.size.limit
147147
case class CurrentDate(timeZoneId: Option[String] = None)
148-
extends LeafExpression with TimeZoneAwareExpression with FoldableUnevaluable {
148+
extends LeafExpression with TimeZoneAwareExpression with CodegenFallback {
149149
def this() = this(None)
150+
override def foldable: Boolean = true
150151
override def nullable: Boolean = false
151152
override def dataType: DataType = DateType
152153
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
153154
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
154155
copy(timeZoneId = Option(timeZoneId))
155156

156157
override def prettyName: String = "current_date"
158+
159+
override def eval(input: InternalRow): Any = currentDate(zoneId)
157160
}
158161

159162
// scalastyle:off line.size.limit
@@ -180,9 +183,11 @@ object CurDateExpressionBuilder extends ExpressionBuilder {
180183
}
181184
}
182185

183-
abstract class CurrentTimestampLike() extends LeafExpression with FoldableUnevaluable {
186+
abstract class CurrentTimestampLike() extends LeafExpression with CodegenFallback {
187+
override def foldable: Boolean = true
184188
override def nullable: Boolean = false
185189
override def dataType: DataType = TimestampType
190+
override def eval(input: InternalRow): Any = instantToMicros(java.time.Instant.now())
186191
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
187192
}
188193

@@ -246,13 +251,15 @@ case class Now() extends CurrentTimestampLike {
246251
group = "datetime_funcs",
247252
since = "3.4.0")
248253
case class LocalTimestamp(timeZoneId: Option[String] = None) extends LeafExpression
249-
with TimeZoneAwareExpression with FoldableUnevaluable {
254+
with TimeZoneAwareExpression with CodegenFallback {
250255
def this() = this(None)
256+
override def foldable: Boolean = true
251257
override def nullable: Boolean = false
252258
override def dataType: DataType = TimestampNTZType
253259
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
254260
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
255261
copy(timeZoneId = Option(timeZoneId))
262+
override def eval(input: InternalRow): Any = localDateTimeToMicros(LocalDateTime.now(zoneId))
256263
override def prettyName: String = "localtimestamp"
257264
}
258265

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@ package org.apache.spark.sql.catalyst.expressions
2020
import java.time.DateTimeException
2121
import java.util.Locale
2222

23+
import org.apache.spark.sql.catalyst.InternalRow
2324
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
2425
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
2526
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType, toSQLValue}
27+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2628
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
2729
import org.apache.spark.sql.catalyst.trees.TreePattern.{CURRENT_LIKE, TreePattern}
2830
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2931
import org.apache.spark.sql.catalyst.util.TimeFormatter
30-
import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber}
32+
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
3133
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
3234
import org.apache.spark.sql.internal.types.StringTypeWithCollation
3335
import org.apache.spark.sql.types.{AbstractDataType, DataType, DecimalType, IntegerType, ObjectType, TimeType, TypeCollection}
@@ -429,17 +431,25 @@ object SecondExpressionBuilder extends ExpressionBuilder {
429431
group = "datetime_funcs",
430432
since = "4.1.0"
431433
)
432-
case class CurrentTime(child: Expression = Literal(TimeType.MICROS_PRECISION))
433-
extends UnaryExpression with FoldableUnevaluable with ImplicitCastInputTypes {
434+
case class CurrentTime(
435+
child: Expression = Literal(TimeType.MICROS_PRECISION),
436+
timeZoneId: Option[String] = None) extends UnaryExpression
437+
with TimeZoneAwareExpression with ImplicitCastInputTypes with CodegenFallback {
434438

435439
def this() = {
436-
this(Literal(TimeType.MICROS_PRECISION))
440+
this(Literal(TimeType.MICROS_PRECISION), None)
437441
}
438442

439-
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
443+
def this(child: Expression) = {
444+
this(child, None)
445+
}
446+
447+
final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CURRENT_LIKE)
440448

441449
override def nullable: Boolean = false
442450

451+
override def foldable: Boolean = true
452+
443453
override def checkInputDataTypes(): TypeCheckResult = {
444454
// Check foldability
445455
if (!child.foldable) {
@@ -496,11 +506,19 @@ case class CurrentTime(child: Expression = Literal(TimeType.MICROS_PRECISION))
496506

497507
override def prettyName: String = "current_time"
498508

509+
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
510+
copy(timeZoneId = Option(timeZoneId))
511+
499512
override protected def withNewChildInternal(newChild: Expression): Expression = {
500513
copy(child = newChild)
501514
}
502515

503516
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)
517+
518+
override def eval(input: InternalRow): Any = {
519+
val currentTimeOfDayMicros = DateTimeUtils.instantToMicrosOfDay(java.time.Instant.now(), zoneId)
520+
DateTimeUtils.truncateTimeMicrosToPrecision(currentTimeOfDayMicros, precision)
521+
}
504522
}
505523

506524
// scalastyle:off line.size.limit

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
121121
Seq(CurrentTime())
122122
)
123123
)
124-
val resolved = ResolveInlineTables(table)
124+
val resolved = ResolveInlineTables(ResolveTimeZone(table))
125125
assert(resolved.isInstanceOf[ResolvedInlineTable],
126126
"Expected an inline table to be resolved into a ResolvedInlineTable")
127127

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import scala.language.postfixOps
2828
import scala.reflect.ClassTag
2929
import scala.util.Random
3030

31-
import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkException, SparkFunSuite, SparkIllegalArgumentException, SparkUpgradeException}
31+
import org.apache.spark.{SparkArithmeticException, SparkDateTimeException, SparkFunSuite, SparkIllegalArgumentException, SparkUpgradeException}
3232
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3333
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
3434
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, TimestampFormatter}
@@ -2140,14 +2140,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
21402140
}
21412141
}
21422142
}
2143-
2144-
test("datetime function CurrentDate and localtimestamp are Unevaluable") {
2145-
checkError(exception = intercept[SparkException] { CurrentDate(UTC_OPT).eval(EmptyRow) },
2146-
condition = "INTERNAL_ERROR",
2147-
parameters = Map("message" -> "Cannot evaluate expression: current_date(Some(UTC))"))
2148-
2149-
checkError(exception = intercept[SparkException] { LocalTimestamp(UTC_OPT).eval(EmptyRow) },
2150-
condition = "INTERNAL_ERROR",
2151-
parameters = Map("message" -> "Cannot evaluate expression: localtimestamp(Some(UTC))"))
2152-
}
21532143
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -919,17 +919,26 @@ private[hive] trait HiveInspectors {
919919
// We will enumerate all of the possible constant expressions, throw exception if we missed
920920
case Literal(_, dt) =>
921921
throw SparkException.internalError(s"Hive doesn't support the constant type [$dt].")
922-
// FoldableUnevaluable will be replaced with a foldable value in FinishAnalysis rule,
923-
// skip eval() for them.
924-
case _ if expr.collectFirst { case e: FoldableUnevaluable => e }.isDefined =>
925-
toInspector(expr.dataType)
926922
// ideally, we don't test the foldable here(but in optimizer), however, some of the
927923
// Hive UDF / UDAF requires its argument to be constant objectinspector, we do it eagerly.
928-
case _ if expr.foldable => toInspector(Literal.create(expr.eval(), expr.dataType))
924+
case _ if expr.foldable && canEarlyEval(expr) =>
925+
toInspector(Literal.create(expr.eval(), expr.dataType))
929926
// For those non constant expression, map to object inspector according to its data type
930927
case _ => toInspector(expr.dataType)
931928
}
932929

930+
// TODO: hard-coding a list here is not very robust. A better idea is to have some kind of query
931+
// context to pre-evaluate these current datetime values, and evaluating these expressions
932+
// just get the pre-evaluated values from the query context, so that we don't need to wait
933+
// for the rule `FinishAnalysis` to compute the values.
934+
private def canEarlyEval(e: Expression): Boolean = e match {
935+
case _: CurrentDate => false
936+
case _: CurrentTime => false
937+
case _: CurrentTimestampLike => false
938+
case _: LocalTimestamp => false
939+
case _ => e.children.forall(canEarlyEval)
940+
}
941+
933942
def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
934943
case s: StructObjectInspector =>
935944
StructType(s.getAllStructFieldRefs.asScala.map(f =>

0 commit comments

Comments
 (0)