Skip to content

[SPARK-51554][SQL] Add the time_trunc() function #51547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -3430,6 +3430,11 @@
"expects a string literal, but got <invalidValue>."
]
},
"TIMETRUNC_UNIT" : {
"message" : [
"expects one of the units 'HOUR', 'MINUTE', 'SECOND', 'MILLISECOND', 'MICROSECOND', but got the string literal <invalidValue>."
]
},
"ZERO_INDEX" : {
"message" : [
"expects %1$, %2$ and so on, but got %0$."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ object FunctionRegistry {
expression[WindowTime]("window_time"),
expression[MakeDate]("make_date"),
expression[MakeTime]("make_time"),
expression[TimeTrunc]("time_trunc"),
expression[MakeTimestamp]("make_timestamp"),
expression[TryMakeTimestamp]("try_make_timestamp"),
expression[MonthName]("monthname"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,56 @@ case class SubtractTimes(left: Expression, right: Expression)
newLeft: Expression, newRight: Expression): SubtractTimes =
copy(left = newLeft, right = newRight)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(unit, expr) - Returns time `expr` truncated to the unit specified by the unit `unit`.
  """,
  arguments = """
    Arguments:
      * unit - the unit to truncate the time to
          - "HOUR" - zero out the minutes and the seconds with fraction part
          - "MINUTE" - zero out the seconds with fraction part
          - "SECOND" - zero out the fraction part of seconds
          - "MILLISECOND" - zero out the microseconds
          - "MICROSECOND" - zero out the nanoseconds
      * expr - a TIME with a valid time format
  """,
  examples = """
    Examples:
      > SELECT _FUNC_('HOUR', TIME'09:32:05.359');
       09:00:00
      > SELECT _FUNC_('MILLISECOND', TIME'09:32:05.123456');
       09:32:05.123
  """,
  group = "datetime_funcs",
  since = "4.1.0")
// scalastyle:on line.size.limit
/**
 * Truncates a TIME value to the given unit.
 *
 * The expression is `RuntimeReplaceable`: it is rewritten into a static call to
 * `DateTimeUtils.timeTrunc`, which throws `SparkIllegalArgumentException` when the
 * unit string is not one of HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND
 * (matched case-insensitively).
 *
 * @param unit a string expression naming the truncation unit
 * @param time the TIME expression to truncate
 */
case class TimeTrunc(unit: Expression, time: Expression)
  extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {

  override def left: Expression = unit
  override def right: Expression = time

  // The unit accepts any collated string (trim collation included); the time
  // argument accepts TIME of any precision.
  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringTypeWithCollation(supportsTrimCollation = true), AnyTimeType)

  // The result preserves the precision of the input TIME type.
  override def dataType: DataType = time.dataType

  override def prettyName: String = "time_trunc"

  override protected def withNewChildrenInternal(
      newUnit: Expression, newTime: Expression): TimeTrunc =
    copy(unit = newUnit, time = newTime)

  // Replaced during analysis/optimization by a direct invocation of
  // DateTimeUtils.timeTrunc(unit, nanosOfDay).
  override def replacement: Expression = {
    StaticInvoke(
      classOf[DateTimeUtils.type],
      dataType,
      "timeTrunc",
      Seq(unit, time),
      Seq(unit.dataType, time.dataType)
    )
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,43 @@ object DateTimeUtils extends SparkDateTimeUtils {
}
}

/**
 * Parses a case-insensitive unit name into the corresponding `ChronoUnit` used for
 * TIME truncation. This only resolves the unit; the actual truncation happens in
 * `timeTrunc` below.
 *
 * Supported names: HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND.
 * Throws the invalid-unit error for `null` or any other string.
 */
private def parseTimeTruncLevel(level: UTF8String): ChronoUnit = {
// A null unit is reported as the literal string "null" in the error message.
if (level == null) {
throw QueryExecutionErrors.invalidTimeTruncUnitError("time_trunc", "null")
}
level.toString.toUpperCase(Locale.ROOT) match {
case "HOUR" => ChronoUnit.HOURS
case "MINUTE" => ChronoUnit.MINUTES
case "SECOND" => ChronoUnit.SECONDS
case "MILLISECOND" => ChronoUnit.MILLIS
case "MICROSECOND" => ChronoUnit.MICROS
case _ =>
throw QueryExecutionErrors.invalidTimeTruncUnitError("time_trunc", level.toString)
}
}

/**
 * Truncates a TIME value, given as nanoseconds of day, to the specified unit.
 *
 * @param level case-insensitive unit name (HOUR, MINUTE, SECOND, MILLISECOND or
 *              MICROSECOND); parsed by `parseTimeTruncLevel`, which throws
 *              `SparkIllegalArgumentException` for null/unsupported units
 * @param nanos the time as nanoseconds since midnight
 * @return the truncated time as nanoseconds since midnight
 */
def timeTrunc(level: UTF8String, nanos: Long): Long = {
localTimeToNanos(nanosToLocalTime(nanos).truncatedTo(parseTimeTruncLevel(level)))
}

/**
 * Set of supported time truncation levels for TIME values.
 *
 * NOTE(review): this private val is not referenced anywhere in the visible part of
 * this file — `timeTrunc` above uses `ChronoUnit` via `parseTimeTruncLevel` instead
 * of these TRUNC_TO_* constants. Confirm it is used elsewhere in the file or remove
 * it as dead code.
 */
private val supportedTimeTruncLevels = Set(
TRUNC_TO_HOUR,
TRUNC_TO_MINUTE,
TRUNC_TO_SECOND,
TRUNC_TO_MILLISECOND,
TRUNC_TO_MICROSECOND
)

/**
* Returns the truncate level, could be from TRUNC_TO_MICROSECOND to TRUNC_TO_YEAR,
* or TRUNC_INVALID, TRUNC_INVALID means unsupported truncate level.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3067,6 +3067,21 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

// Builds the SparkIllegalArgumentException raised when time_trunc receives an
// unsupported truncation unit. Supported units are HOUR, MINUTE, SECOND,
// MILLISECOND and MICROSECOND.
def invalidTimeTruncUnitError(
    functionName: String,
    invalidValue: String): Throwable = {
  // Quote the function name and parameter as SQL identifiers, and the offending
  // unit as a SQL value, so the message renders consistently with other errors.
  val params = Map(
    "functionName" -> toSQLId(functionName),
    "parameter" -> toSQLId("unit"),
    "invalidValue" -> toSQLValue(invalidValue))
  new SparkIllegalArgumentException(
    errorClass = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
    messageParameters = params)
}

// Throws a SparkRuntimeException when a CHECK constraint is violated, including details of the
// violation. This is a Java-friendly version of the above method.
def checkViolationJava(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import java.time.{Duration, LocalTime}

import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite}
import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLValue}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.localTimeToNanos
import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, IntegerType, StringType, TimeType}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, SECOND}

Expand Down Expand Up @@ -418,4 +419,108 @@ class TimeExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}

test("SPARK-51554: TimeTrunc") {
// End-to-end evaluation tests for the TimeTrunc expression: every supported unit,
// case-insensitive unit matching, invalid units, null propagation, boundary times
// (midnight / just-before-midnight), and precision of the result type.
// NOTE(review): localTime appears to take (hour, minute, second, microseconds)
// per DateTimeTestUtils — confirm against the helper's signature.

// Test cases for different truncation units - 09:32:05.359123.
val testTime = localTime(9, 32, 5, 359123)

// Test HOUR truncation.
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
// Test MINUTE truncation.
checkEvaluation(
TimeTrunc(Literal("MINUTE"), Literal(testTime, TimeType())),
localTime(9, 32, 0, 0)
)
// Test SECOND truncation.
checkEvaluation(
TimeTrunc(Literal("SECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 0)
)
// Test MILLISECOND truncation (sub-millisecond digits are zeroed).
checkEvaluation(
TimeTrunc(Literal("MILLISECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 359000)
)
// Test MICROSECOND truncation (identity at microsecond precision).
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(testTime, TimeType())),
testTime
)

// Test case-insensitive units: lower, capitalized, and mixed case all resolve.
checkEvaluation(
TimeTrunc(Literal("hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("Hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("hoUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)

// Test invalid units: each must raise INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT
// with the offending unit echoed back as a quoted SQL value.
val invalidUnits: Seq[String] = Seq("MS", "INVALID", "ABC", "XYZ", " ", "")
invalidUnits.foreach { unit =>
checkError(
exception = intercept[SparkIllegalArgumentException] {
TimeTrunc(Literal(unit), Literal(testTime, TimeType())).eval()
},
condition = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
parameters = Map(
"functionName" -> "`time_trunc`",
"parameter" -> "`unit`",
"invalidValue" -> s"'$unit'"
)
)
}

// Test null inputs: a null in either argument (or both) yields null.
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal(testTime, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal.create(null, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal.create(null, TimeType())),
null
)

// Test edge cases: midnight is a fixed point for every supported unit.
val midnightTime = localTime(0, 0, 0, 0)
val supportedUnits: Seq[String] = Seq("HOUR", "MINUTE", "SECOND", "MILLISECOND", "MICROSECOND")
supportedUnits.foreach { unit =>
checkEvaluation(
TimeTrunc(Literal(unit), Literal(midnightTime, TimeType())),
midnightTime
)
}

// 23:59:59.999999999 — the maximum representable LocalTime.
val maxTime = localTimeToNanos(LocalTime.of(23, 59, 59, 999999999))
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(maxTime, TimeType())),
localTime(23, 0, 0, 0)
)
// MICROSECOND truncation drops the trailing nanosecond digits (999 ns).
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(maxTime, TimeType())),
localTimeToNanos(LocalTime.of(23, 59, 59, 999999000))
)

// Test precision loss: the result type preserves the input's TimeType precision.
val timeWithMicroPrecision = localTime(15, 30, 45, 123456)
val timeTruncMin = TimeTrunc(Literal("MINUTE"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncMin.dataType == TimeType(3))
checkEvaluation(timeTruncMin, localTime(15, 30, 0, 0))
val timeTruncSec = TimeTrunc(Literal("SECOND"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncSec.dataType == TimeType(3))
checkEvaluation(timeTruncSec, localTime(15, 30, 45, 0))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,43 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
}
}

test("SPARK-51554: time truncation using timeTrunc") {
  // Input time: 01:02:03.400500600, as nanoseconds of day.
  val input = localTimeToNanos(LocalTime.of(1, 2, 3, 400500600))
  // Truncate the minutes, seconds, and fractions of seconds. Result is: 01:00:00.
  assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), input) === 3600000000000L)
  // Truncate the seconds and fractions of seconds. Result is: 01:02:00.
  assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MINUTE"), input) === 3720000000000L)
  // Truncate the fractions of seconds. Result is: 01:02:03.
  assert(DateTimeUtils.timeTrunc(UTF8String.fromString("SECOND"), input) === 3723000000000L)
  // Truncate to milliseconds. Result is: 01:02:03.400.
  assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MILLISECOND"), input) === 3723400000000L)
  // Truncate to microseconds. Result is: 01:02:03.400500.
  assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MICROSECOND"), input) === 3723400500000L)

  // Midnight (00:00:00) is a fixed point of truncation for every supported unit.
  val midnight = localTimeToNanos(LocalTime.MIDNIGHT)
  Seq("HOUR", "MINUTE", "SECOND", "MILLISECOND", "MICROSECOND").foreach { unit =>
    assert(DateTimeUtils.timeTrunc(UTF8String.fromString(unit), midnight) === 0)
  }

  // Unsupported truncation levels, including null (UTF8String.fromString(null)
  // yields null). Use one intercept per call: wrapping both calls in a single
  // intercept would skip the second call entirely once the first one throws,
  // leaving the midnight case untested.
  Seq("DAY", "WEEK", "MONTH", "QUARTER", "YEAR", "INVALID", "ABC", "XYZ", "MS", " ", "", null).
    map(UTF8String.fromString).foreach { level =>
    intercept[IllegalArgumentException] {
      DateTimeUtils.timeTrunc(level, input)
    }
    intercept[IllegalArgumentException] {
      DateTimeUtils.timeTrunc(level, midnight)
    }
  }
}

test("SPARK-35664: microseconds to LocalDateTime") {
assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@
| org.apache.spark.sql.catalyst.expressions.Subtract | - | SELECT 2 - 1 | struct<(2 - 1):int> |
| org.apache.spark.sql.catalyst.expressions.Tan | tan | SELECT tan(0) | struct<TAN(0):double> |
| org.apache.spark.sql.catalyst.expressions.Tanh | tanh | SELECT tanh(0) | struct<TANH(0):double> |
| org.apache.spark.sql.catalyst.expressions.TimeTrunc | time_trunc | SELECT time_trunc('HOUR', TIME'09:32:05.359') | struct<time_trunc(HOUR, TIME '09:32:05.359'):time(6)> |
| org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct<a:string,start:timestamp,end:timestamp,cnt:bigint> |
| org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct<to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct<to_char(454, 999):string> |
Expand Down
Loading