Skip to content

Commit fa1daa9

Browse files
committed
Fix UT
1 parent 8a30b2a commit fa1daa9

File tree

3 files changed

+50
-25
lines changed

3 files changed

+50
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteTimeCastToTimestampNTZ.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.apache.spark.sql.catalyst.analysis
1818

1919
import org.apache.spark.sql.catalyst.expressions.{Cast, CurrentDate, MakeTimestampNTZ}
20-
import org.apache.spark.sql.catalyst.optimizer.ComputeCurrentTime
2120
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2221
import org.apache.spark.sql.catalyst.rules.Rule
2322
import org.apache.spark.sql.types.{TimestampNTZType, TimeType}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,7 @@ class TimezoneAwareExpressionResolver(expressionResolver: ExpressionResolver)
7171
other
7272
}
7373

74-
coercedExpr match {
75-
case c @ Cast(child, TimestampNTZType, _, _)
76-
if child.resolved && child.dataType.isInstanceOf[TimeType] =>
77-
MakeTimestampNTZ(CurrentDate(), child)
78-
case other =>
79-
other
80-
}
74+
rewriteTimeCastToTimestampNTZ(coercedExpr)
8175
}
8276

8377
/**
@@ -127,6 +121,14 @@ class TimezoneAwareExpressionResolver(expressionResolver: ExpressionResolver)
127121
}
128122
}
129123
}
124+
125+
private def rewriteTimeCastToTimestampNTZ(expr: Expression): Expression = expr match {
126+
case Cast(child, TimestampNTZType, _, _)
127+
if child.resolved && child.dataType.isInstanceOf[TimeType] =>
128+
MakeTimestampNTZ(CurrentDate(), child)
129+
case other =>
130+
other
131+
}
130132
}
131133

132134
object TimezoneAwareExpressionResolver {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolverSuite.scala

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,21 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis.resolver
1919

20+
import java.sql.Time
21+
2022
import org.scalatestplus.mockito.MockitoSugar.mock
2123

2224
import org.apache.spark.SparkFunSuite
2325
import org.apache.spark.sql.catalyst.analysis.FunctionResolution
24-
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, CurrentDate, Expression, Literal, MakeTimestampNTZ, TimeZoneAwareExpression}
26+
import org.apache.spark.sql.catalyst.expressions.{
27+
AttributeReference,
28+
Cast,
29+
CurrentDate,
30+
Expression,
31+
Literal,
32+
MakeTimestampNTZ,
33+
TimeZoneAwareExpression
34+
}
2535
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
2636
import org.apache.spark.sql.connector.catalog.CatalogManager
2737
import org.apache.spark.sql.types.{IntegerType, StringType, TimestampNTZType, TimeType}
@@ -38,6 +48,16 @@ class TimezoneAwareExpressionResolverSuite extends SparkFunSuite {
3848
override def resolve(expression: Expression): Expression = resolvedExpression
3949
}
4050

51+
class NewExpressionResolver(catalogManager: CatalogManager)
52+
extends ExpressionResolver(
53+
resolver = new Resolver(catalogManager),
54+
functionResolution =
55+
new FunctionResolution(catalogManager, Resolver.createRelationResolution(catalogManager)),
56+
planLogger = new PlanLogger
57+
) {
58+
override def resolve(expression: Expression): Expression = expression
59+
}
60+
4161
private val unresolvedChild =
4262
AttributeReference(name = "unresolvedChild", dataType = StringType)()
4363
private val resolvedChild = AttributeReference(name = "resolvedChild", dataType = IntegerType)()
@@ -50,6 +70,14 @@ class TimezoneAwareExpressionResolverSuite extends SparkFunSuite {
5070
expressionResolver
5171
)
5272

73+
private val newExpressionResolver = new NewExpressionResolver(
74+
catalogManager = mock[CatalogManager]
75+
)
76+
private val newTimezoneAwareExpressionResolver = new TimezoneAwareExpressionResolver(
77+
newExpressionResolver
78+
)
79+
80+
5381
test("TimeZoneAwareExpression resolution") {
5482
assert(castExpression.children.head == unresolvedChild)
5583
assert(castExpression.timeZoneId.isEmpty)
@@ -68,23 +96,19 @@ class TimezoneAwareExpressionResolverSuite extends SparkFunSuite {
6896
assert(resolvedExpression.getTagValue(Cast.USER_SPECIFIED_CAST).nonEmpty)
6997
}
7098

71-
test("SPARK-52617: rewrite TIME -> TIMESTAMP_NTZ cast to MakeTimestampNTZ") {
72-
// TIME: 15:30:00 -> seconds = 15*3600 + 30*60 = 55800
73-
val nanos = 55800L * 1_000_000_000L
74-
val timeLiteral = Literal(nanos, TimeType(6))
75-
76-
val castExpr = Cast(timeLiteral, TimestampNTZType)
77-
val rewrittenExpr = timezoneAwareExpressionResolver.resolve(castExpr)
99+
test("SPARK-52617: Rewrite Cast(TimeType -> TimestampNTZType) to MakeTimestampNTZ") {
100+
val millis = Time.valueOf("12:34:56").getTime
101+
val timeExpr = Literal(millis * 1000L, TimeType(6)) // microseconds since midnight
78102

79-
val expectedExpr = MakeTimestampNTZ(CurrentDate(), timeLiteral)
103+
val input = Cast(timeExpr, TimestampNTZType)
104+
val resolved =
105+
newExpressionResolver.getExpressionTreeTraversals.withNewTraversal(OneRowRelation()) {
106+
newTimezoneAwareExpressionResolver.resolve(input)
107+
}
108+
assert(resolved.isInstanceOf[MakeTimestampNTZ])
80109

81-
assert(
82-
rewrittenExpr.semanticEquals(expectedExpr),
83-
s"""
84-
|Expected:
85-
| $expectedExpr
86-
|But got:
87-
| $rewrittenExpr
88-
|""".stripMargin)
110+
val makeTs = resolved.asInstanceOf[MakeTimestampNTZ]
111+
assert(makeTs.left.isInstanceOf[CurrentDate])
112+
assert(makeTs.right.semanticEquals(timeExpr))
89113
}
90114
}

0 commit comments

Comments
 (0)