[SPARK-51554][SQL] Add the time_trunc() function #51547

Open · wants to merge 17 commits into base: master

Changes from 3 commits
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3430,6 +3430,11 @@
"expects a string literal, but got <invalidValue>."
]
},
"TIMETRUNC_UNIT" : {
"message" : [
"expects one of the units without quotes HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND, but got the string literal <invalidValue>."
]
},
"ZERO_INDEX" : {
"message" : [
"expects %1$, %2$ and so on, but got %0$."
@@ -674,6 +674,7 @@ object FunctionRegistry {
expression[WindowTime]("window_time"),
expression[MakeDate]("make_date"),
expression[MakeTime]("make_time"),
expression[TimeTrunc]("time_trunc"),
Review comment (Contributor):

Snowflake has a date_trunc function that supports all datetime types: https://docs.snowflake.com/en/sql-reference/functions/date_trunc

What about other systems? Which one should we follow?

expression[MakeTimestamp]("make_timestamp"),
expression[TryMakeTimestamp]("try_make_timestamp"),
expression[MonthName]("monthname"),
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.util.TimeFormatter
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, AnyTimeType, ByteType, DataType, DayTimeIntervalType, DecimalType, IntegerType, ObjectType, TimeType}
import org.apache.spark.sql.types.{AbstractDataType, AnyTimeType, ByteType, DataType, DayTimeIntervalType, DecimalType, IntegerType, ObjectType, StringType, TimeType, TypeCollection}
import org.apache.spark.sql.types.DayTimeIntervalType.{HOUR, SECOND}
import org.apache.spark.unsafe.types.UTF8String

@@ -630,3 +630,72 @@ case class SubtractTimes(left: Expression, right: Expression)
newLeft: Expression, newRight: Expression): SubtractTimes =
copy(left = newLeft, right = newRight)
}

/**
* Returns the time truncated to the specified unit.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(unit, expr) - Returns time `expr` truncated to the unit specified by `unit`.
""",
arguments = """
Arguments:
* unit - the unit to truncate the time to
- "HOUR" - zero out the minutes and seconds with fraction part
- "MINUTE" - zero out the seconds with fraction part
- "SECOND" - zero out the fraction part of seconds
- "MILLISECOND" - zero out the microseconds
- "MICROSECOND" - zero out the nanoseconds
* expr - a TIME or STRING with a valid time format
""",
examples = """
Examples:
> SELECT _FUNC_('HOUR', '09:32:05.359');
09:00:00
> SELECT _FUNC_('MILLISECOND', '09:32:05.123456');
09:32:05.123
""",
group = "datetime_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class TimeTrunc(unit: Expression, time: Expression)
extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {

override def left: Expression = unit
override def right: Expression = time

override def inputTypes: Seq[AbstractDataType] =
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
TypeCollection(AnyTimeType, StringTypeWithCollation(supportsTrimCollation = true))
)

override def dataType: DataType = {
time.dataType match {
case TimeType(precision) => TimeType(precision)
case _ => TimeType()
}
}

override def prettyName: String = "time_trunc"

override protected def withNewChildrenInternal(
newUnit: Expression, newTime: Expression): TimeTrunc =
copy(unit = newUnit, time = newTime)

override def replacement: Expression = {
val castedTime = if (time.dataType.isInstanceOf[StringType]) {
Cast(time, TimeType())
} else {
time
}
StaticInvoke(
classOf[DateTimeUtils.type],
dataType,
"timeTrunc",
Seq(unit, castedTime),
Seq(unit.dataType, castedTime.dataType)
)
}
}
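Because TimeTrunc is RuntimeReplaceable, evaluation bottoms out in the static call wired up by StaticInvoke. A minimal sketch of the equivalent direct call (assuming the DateTimeUtils.timeTrunc(UTF8String, Long) overload added in the next hunk):

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

// 09:32:05.359 encoded as nanoseconds since midnight.
val nanos = ((9L * 60 + 32) * 60 + 5) * 1000000000L + 359000000L
// Truncating to the hour yields 09:00:00, i.e. 9 * 3600 * 10^9 nanoseconds.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), nanos) == 9L * 3600 * 1000000000L)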
@@ -508,6 +508,57 @@ object DateTimeUtils extends SparkDateTimeUtils {
}
}

/**
* Returns the time truncated to the unit specified by the truncation level.
*/
private def timeTrunc(level: Int, nanos: Long): Long = {
level match {
case TRUNC_TO_HOUR =>
val truncatedTime = nanosToLocalTime(nanos).
withMinute(0).
withSecond(0).
withNano(0)
localTimeToNanos(truncatedTime)
case TRUNC_TO_MINUTE =>
val localTime = nanosToLocalTime(nanos)
val truncatedTime = localTime.withSecond(0).withNano(0)
localTimeToNanos(truncatedTime)
case TRUNC_TO_SECOND =>
nanos - Math.floorMod(nanos, NANOS_PER_SECOND)
case TRUNC_TO_MILLISECOND =>
nanos - Math.floorMod(nanos, NANOS_PER_MILLIS)
case TRUNC_TO_MICROSECOND =>
nanos - Math.floorMod(nanos, NANOS_PER_MICROS)
case _ =>
throw new IllegalArgumentException(s"Unsupported time truncation level: $level")
}
}

/**
* Set of supported time truncation levels for TIME values.
*/
private val supportedTimeTruncLevels = Set(
TRUNC_TO_HOUR,
TRUNC_TO_MINUTE,
TRUNC_TO_SECOND,
TRUNC_TO_MILLISECOND,
TRUNC_TO_MICROSECOND
)

/**
* Returns the time truncated to the unit specified by the level. The truncation level should be
* generated using `parseTruncLevel()` and must be between TRUNC_TO_HOUR and TRUNC_TO_MICROSECOND.
*/
def timeTrunc(level: UTF8String, nanos: Long): Long = {
require(level != null, "Truncation level must not be null.")
require(nanos >= 0, "Nanoseconds must be non-negative.")
val truncLevel = parseTruncLevel(level)
if (!supportedTimeTruncLevels.contains(truncLevel)) {
throw QueryExecutionErrors.invalidTimeTruncUnitError("time_trunc", level.toString)
}
timeTrunc(truncLevel, nanos)
}

/**
* Returns the truncation level, which can be from TRUNC_TO_MICROSECOND to TRUNC_TO_YEAR,
* or TRUNC_INVALID; TRUNC_INVALID means an unsupported truncation level.
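The sub-second branches above are plain modular arithmetic on the nanos-of-day value. A standalone sketch of that arithmetic, with the unit constants redefined so the snippet runs outside Spark:

import java.time.LocalTime

// Unit sizes in nanoseconds, mirroring Spark's constants of the same names.
val NANOS_PER_SECOND = 1000000000L
val NANOS_PER_MILLIS = 1000000L
val NANOS_PER_MICROS = 1000L

// Truncation subtracts the remainder of the nanos-of-day modulo the unit size.
def truncNanos(nanosOfDay: Long, unitNanos: Long): Long =
  nanosOfDay - Math.floorMod(nanosOfDay, unitNanos)

val nanos = LocalTime.of(1, 2, 3, 400500600).toNanoOfDay
assert(LocalTime.ofNanoOfDay(truncNanos(nanos, NANOS_PER_SECOND)) == LocalTime.of(1, 2, 3))
assert(LocalTime.ofNanoOfDay(truncNanos(nanos, NANOS_PER_MILLIS)) == LocalTime.of(1, 2, 3, 400000000))
assert(LocalTime.ofNanoOfDay(truncNanos(nanos, NANOS_PER_MICROS)) == LocalTime.of(1, 2, 3, 400500000))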
@@ -3067,6 +3067,20 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
)
}

// Throws a SparkIllegalArgumentException when an invalid time truncation unit is specified.
// Note that the supported units are: HOUR, MINUTE, SECOND, MILLISECOND, MICROSECOND.
def invalidTimeTruncUnitError(
functionName: String,
invalidValue: String): Throwable = {
new SparkIllegalArgumentException(
errorClass = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
messageParameters = Map(
"functionName" -> toSQLId(functionName),
"parameter" -> toSQLId("unit"),
"invalidValue" -> s"'$invalidValue'")
)
}
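A minimal sketch of exercising the helper (hypothetical caller; the object is private[sql], so this only compiles inside Spark's sql packages):

// Produces a SparkIllegalArgumentException carrying the
// INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT condition added in error-conditions.json.
val err = QueryExecutionErrors.invalidTimeTruncUnitError("time_trunc", "MS")
assert(err.isInstanceOf[org.apache.spark.SparkIllegalArgumentException])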

// Throws a SparkRuntimeException when a CHECK constraint is violated, including details of the
// violation. This is a Java-friendly version of the above method.
def checkViolationJava(
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import java.time.{Duration, LocalTime}

import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite}
import org.apache.spark.{SPARK_DOC_ROOT, SparkDateTimeException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLValue}
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils.localTimeToNanos
import org.apache.spark.sql.types.{DayTimeIntervalType, Decimal, DecimalType, IntegerType, StringType, TimeType}
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, SECOND}

@@ -418,4 +419,122 @@ class TimeExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
}

test("SPARK-51554: TimeTrunc") {
// Test cases for different truncation units - 09:32:05.359123.
val testTime = localTime(9, 32, 5, 359123)

// Test HOUR truncation.
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
// Test MINUTE truncation.
checkEvaluation(
TimeTrunc(Literal("MINUTE"), Literal(testTime, TimeType())),
localTime(9, 32, 0, 0)
)
// Test SECOND truncation.
checkEvaluation(
TimeTrunc(Literal("SECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 0)
)
// Test MILLISECOND truncation.
checkEvaluation(
TimeTrunc(Literal("MILLISECOND"), Literal(testTime, TimeType())),
localTime(9, 32, 5, 359000)
)
// Test MICROSECOND truncation.
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(testTime, TimeType())),
testTime
)

// Test case-insensitive units.
checkEvaluation(
TimeTrunc(Literal("hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("Hour"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("hoUR"), Literal(testTime, TimeType())),
localTime(9, 0, 0, 0)
)

// Test invalid units.
val invalidUnits: Seq[String] = Seq("MS", "INVALID", "ABC", "XYZ", " ", "")
invalidUnits.foreach { unit =>
checkError(
exception = intercept[SparkIllegalArgumentException] {
TimeTrunc(Literal(unit), Literal(testTime, TimeType())).eval()
},
condition = "INVALID_PARAMETER_VALUE.TIMETRUNC_UNIT",
parameters = Map(
"functionName" -> "`time_trunc`",
"parameter" -> "`unit`",
"invalidValue" -> s"'$unit'"
)
)
}

// Test null inputs.
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal(testTime, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal.create(null, TimeType())),
null
)
checkEvaluation(
TimeTrunc(Literal.create(null, StringType), Literal.create(null, TimeType())),
null
)

// Test edge cases.
val midnightTime = localTime(0, 0, 0, 0)
val supportedUnits: Seq[String] = Seq("HOUR", "MINUTE", "SECOND", "MILLISECOND", "MICROSECOND")
supportedUnits.foreach { unit =>
checkEvaluation(
TimeTrunc(Literal(unit), Literal(midnightTime, TimeType())),
midnightTime
)
}

val maxTime = localTimeToNanos(LocalTime.of(23, 59, 59, 999999999))
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal(maxTime, TimeType())),
localTime(23, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal(maxTime, TimeType())),
localTimeToNanos(LocalTime.of(23, 59, 59, 999999000))
)

// Test precision loss.
val timeWithMicroPrecision = localTime(15, 30, 45, 123456)
val timeTruncMin = TimeTrunc(Literal("MINUTE"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncMin.dataType == TimeType(3))
checkEvaluation(timeTruncMin, localTime(15, 30, 0, 0))
val timeTruncSec = TimeTrunc(Literal("SECOND"), Literal(timeWithMicroPrecision, TimeType(3)))
assert(timeTruncSec.dataType == TimeType(3))
checkEvaluation(timeTruncSec, localTime(15, 30, 45, 0))

// Test string time overload.
checkEvaluation(
TimeTrunc(Literal("HOUR"), Literal("15:30:45")),
localTime(15, 0, 0, 0)
)
checkEvaluation(
TimeTrunc(Literal("SECOND"), Literal("15:30:45.123")),
localTime(15, 30, 45, 0)
)
checkEvaluation(
TimeTrunc(Literal("MICROSECOND"), Literal("15:30:45.123456789")),
localTime(15, 30, 45, 123456)
)
}
}
@@ -766,6 +766,43 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper {
}
}

test("SPARK-51554: time truncation using timeTrunc") {
// 01:02:03.400500600
val input = localTimeToNanos(LocalTime.of(1, 2, 3, 400500600))
// Truncate the minutes, seconds, and fractions of seconds. Result is: 01:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), input) === 3600000000000L)
// Truncate the seconds and fractions of seconds. Result is: 01:02:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MINUTE"), input) === 3720000000000L)
// Truncate the fractions of seconds. Result is: 01:02:03.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("SECOND"), input) === 3723000000000L)
// Truncate to milliseconds (zero out sub-millisecond digits). Result is: 01:02:03.400.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MILLISECOND"), input) === 3723400000000L)
// Truncate to microseconds (zero out sub-microsecond digits). Result is: 01:02:03.400500.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MICROSECOND"), input) === 3723400500000L)

// 00:00:00
val midnight = localTimeToNanos(LocalTime.MIDNIGHT)
// Truncate the minutes, seconds, and fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("HOUR"), midnight) === 0)
// Truncate the seconds and fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MINUTE"), midnight) === 0)
// Truncate the fractions of seconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("SECOND"), midnight) === 0)
// Truncate to milliseconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MILLISECOND"), midnight) === 0)
// Truncate to microseconds. Result is: 00:00:00.
assert(DateTimeUtils.timeTrunc(UTF8String.fromString("MICROSECOND"), midnight) === 0)

// Unsupported truncation levels.
Seq("DAY", "WEEK", "MONTH", "QUARTER", "YEAR", "INVALID", "ABC", "XYZ", "MS", " ", "", null).
map(UTF8String.fromString).foreach { level =>
intercept[IllegalArgumentException] {
DateTimeUtils.timeTrunc(level, input)
DateTimeUtils.timeTrunc(level, midnight)
}
}
}

test("SPARK-35664: microseconds to LocalDateTime") {
assert(microsToLocalDateTime(0) == LocalDateTime.parse("1970-01-01T00:00:00"))
assert(microsToLocalDateTime(100) == LocalDateTime.parse("1970-01-01T00:00:00.0001"))
@@ -341,6 +341,7 @@
| org.apache.spark.sql.catalyst.expressions.Subtract | - | SELECT 2 - 1 | struct<(2 - 1):int> |
| org.apache.spark.sql.catalyst.expressions.Tan | tan | SELECT tan(0) | struct<TAN(0):double> |
| org.apache.spark.sql.catalyst.expressions.Tanh | tanh | SELECT tanh(0) | struct<TANH(0):double> |
| org.apache.spark.sql.catalyst.expressions.TimeTrunc | time_trunc | SELECT time_trunc('HOUR', '09:32:05.359') | struct<time_trunc(HOUR, 09:32:05.359):time(6)> |
| org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct<a:string,start:timestamp,end:timestamp,cnt:bigint> |
| org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct<to_binary(abc, utf-8):binary> |
| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct<to_char(454, 999):string> |
@@ -480,4 +481,4 @@
| org.apache.spark.sql.catalyst.expressions.xml.XPathList | xpath | SELECT xpath('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>','a/b/text()') | struct<xpath(<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>, a/b/text()):array<string>> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathLong | xpath_long | SELECT xpath_long('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_long(<a><b>1</b><b>2</b></a>, sum(a/b)):bigint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathShort | xpath_short | SELECT xpath_short('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_short(<a><b>1</b><b>2</b></a>, sum(a/b)):smallint> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathString | xpath_string | SELECT xpath_string('<a><b>b</b><c>cc</c></a>','a/c') | struct<xpath_string(<a><b>b</b><c>cc</c></a>, a/c):string> |