Skip to content

Commit 8a3d32b

Browse files
mihailoale-db authored and yhuang-db committed
[SPARK-52385][SQL] Remove TempResolvedColumns from InheritAnalysisRules name
### What changes were proposed in this pull request? In this issue I propose to remove the `TempResolvedColumn` nodes when computing the name for `InheritAnalysisRules` nodes (they are not removed during the `ResolveAggregateFunctions` rule). This is the right behavior as `TempResolvedColumn` is an internal node and shouldn't be exposed to users. The following query: ``` SELECT sum(col1) FROM VALUES(1) GROUP BY ALL HAVING sum(ifnull(col1, 1)) = 1 ``` would have the following analyzed plans: Before the change: ``` Project [sum(col1)#2L] +- Filter (sum(ifnull(tempresolvedcolumn(col1), 1))#4L = cast(1 as bigint)) +- Aggregate [sum(col1#0) AS sum(col1)#2L, sum(ifnull(tempresolvedcolumn(col1#0, col1, false), 1)) AS sum(ifnull(tempresolvedcolumn(col1), 1))#4L] +- LocalRelation [col1#0] ``` After the change: ``` Project [sum(col1)#2L] +- Filter (sum(ifnull(col1, 1))#4L = cast(1 as bigint)) +- Aggregate [sum(col1#0) AS sum(col1)#2L, sum(ifnull(tempresolvedcolumn(col1#0, col1, false), 1)) AS sum(ifnull(col1, 1))#4L] +- LocalRelation [col1#0] ``` ### Why are the changes needed? To improve (correct) `Alias` names. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#51071 from mihailoale-db/trimtempresolvedcolumnforparameters. Authored-by: mihailoale-db <mihailo.aleksic@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 34416f7 commit 8a3d32b

File tree

6 files changed

+240
-21
lines changed

6 files changed

+240
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ import org.apache.spark.sql.catalyst.trees.AlwaysProcess
4949
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
5050
import org.apache.spark.sql.catalyst.trees.TreePattern._
5151
import org.apache.spark.sql.catalyst.types.DataTypeUtils
52-
import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
52+
import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils}
5353
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
5454
import org.apache.spark.sql.connector.catalog.{View => _, _}
5555
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -2962,7 +2962,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
29622962
expr match {
29632963
case ae: AggregateExpression =>
29642964
val cleaned = trimTempResolvedColumn(ae)
2965-
val alias = Alias(cleaned, toPrettySQL(cleaned))()
2965+
val alias =
2966+
Alias(cleaned, toPrettySQL(e = cleaned, shouldTrimTempResolvedColumn = true))()
29662967
aggExprList += alias
29672968
alias.toAttribute
29682969
case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
@@ -2971,7 +2972,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
29712972
aggExprList += ne
29722973
ne.toAttribute
29732974
case other =>
2974-
val alias = Alias(other, toPrettySQL(other))()
2975+
val alias =
2976+
Alias(other, toPrettySQL(e = other, shouldTrimTempResolvedColumn = true))()
29752977
aggExprList += alias
29762978
alias.toAttribute
29772979
}
@@ -3001,10 +3003,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
30013003
}
30023004
}
30033005

3004-
private def trimTempResolvedColumn(input: Expression): Expression = input.transform {
3005-
case t: TempResolvedColumn => t.child
3006-
}
3007-
30083006
def resolveOperatorWithAggregate(
30093007
exprs: Seq[Expression],
30103008
agg: Aggregate,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/LiteralFunctionResolution.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ object LiteralFunctionResolution {
4141
// support CURRENT_DATE, CURRENT_TIMESTAMP, CURRENT_TIME,
4242
// CURRENT_USER, USER, SESSION_USER and grouping__id
4343
private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq(
44-
(CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)),
45-
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)),
46-
(CurrentTime().prettyName, () => CurrentTime(), toPrettySQL(_)),
47-
(CurrentUser().prettyName, () => CurrentUser(), toPrettySQL(_)),
48-
("user", () => CurrentUser(), toPrettySQL(_)),
49-
("session_user", () => CurrentUser(), toPrettySQL(_)),
44+
(CurrentDate().prettyName, () => CurrentDate(), e => toPrettySQL(e)),
45+
(CurrentTimestamp().prettyName, () => CurrentTimestamp(), e => toPrettySQL(e)),
46+
(CurrentTime().prettyName, () => CurrentTime(), e => toPrettySQL(e)),
47+
(CurrentUser().prettyName, () => CurrentUser(), e => toPrettySQL(e)),
48+
("user", () => CurrentUser(), e => toPrettySQL(e)),
49+
("session_user", () => CurrentUser(), e => toPrettySQL(e)),
5050
(VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName)
5151
)
5252
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets.UTF_8
2424
import com.google.common.io.ByteStreams
2525

2626
import org.apache.spark.internal.Logging
27+
import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn
2728
import org.apache.spark.sql.catalyst.expressions._
2829
import org.apache.spark.sql.connector.catalog.MetadataColumn
2930
import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType, StructType}
@@ -91,22 +92,44 @@ package object util extends Logging {
9192

9293
def stackTraceToString(t: Throwable): String = SparkErrorUtils.stackTraceToString(t)
9394

94-
// Replaces attributes, string literals, complex type extractors with their pretty form so that
95-
// generated column names don't contain back-ticks or double-quotes.
96-
def usePrettyExpression(e: Expression): Expression = e transform {
95+
/**
96+
* Replaces attributes, string literals, complex type extractors with their pretty form so that
97+
* generated column names don't contain back-ticks or double-quotes.
98+
* In case value of `shouldTrimTempResolvedColumn` is true, trim [[TempResolvedColumn]]s from the
99+
* expression tree to avoid having it in an [[Alias]] name.
100+
*/
101+
private def usePrettyExpression(
102+
e: Expression,
103+
shouldTrimTempResolvedColumn: Boolean = false): Expression = e transform {
97104
case a: Attribute => new PrettyAttribute(a)
98105
case Literal(s: UTF8String, StringType) => PrettyAttribute(s.toString, StringType)
99106
case Literal(v, t: NumericType) if v != null => PrettyAttribute(v.toString, t)
100107
case Literal(null, dataType) => PrettyAttribute("NULL", dataType)
101108
case e: GetStructField =>
102109
val name = e.name.getOrElse(e.childSchema(e.ordinal).name)
103-
PrettyAttribute(usePrettyExpression(e.child).sql + "." + name, e.dataType)
110+
PrettyAttribute(
111+
usePrettyExpression(e.child, shouldTrimTempResolvedColumn).sql + "." + name,
112+
e.dataType
113+
)
104114
case e: GetArrayStructFields =>
105-
PrettyAttribute(s"${usePrettyExpression(e.child)}.${e.field.name}", e.dataType)
115+
PrettyAttribute(
116+
s"${usePrettyExpression(e.child, shouldTrimTempResolvedColumn)}.${e.field.name}",
117+
e.dataType
118+
)
106119
case r: InheritAnalysisRules =>
107-
PrettyAttribute(r.makeSQLString(r.parameters.map(toPrettySQL)), r.dataType)
120+
val proposedParameters = if (shouldTrimTempResolvedColumn) {
121+
r.parameters.map(trimTempResolvedColumn)
122+
} else {
123+
r.parameters
124+
}
125+
PrettyAttribute(
126+
name = r.makeSQLString(
127+
proposedParameters.map(parameter => toPrettySQL(parameter, shouldTrimTempResolvedColumn))
128+
),
129+
dataType = r.dataType
130+
)
108131
case c: Cast if c.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
109-
PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType)
132+
PrettyAttribute(usePrettyExpression(c.child, shouldTrimTempResolvedColumn).sql, c.dataType)
110133
case p: PythonFuncExpression => PrettyPythonUDF(p.name, p.dataType, p.children)
111134
}
112135

@@ -122,7 +145,8 @@ package object util extends Logging {
122145
QuotingUtils.quoteIfNeeded(part)
123146
}
124147

125-
def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql
148+
def toPrettySQL(e: Expression, shouldTrimTempResolvedColumn: Boolean = false): String =
149+
usePrettyExpression(e, shouldTrimTempResolvedColumn).sql
126150

127151
def escapeSingleQuotedString(str: String): String = {
128152
QuotingUtils.escapeSingleQuotedString(str)
@@ -148,6 +172,14 @@ package object util extends Logging {
148172
SparkStringUtils.truncatedString(seq, "", sep, "", maxFields)
149173
}
150174

175+
/**
176+
* Helper method used to remove all the [[TempResolvedColumn]]s from the provided expression
177+
* tree.
178+
*/
179+
def trimTempResolvedColumn(input: Expression): Expression = input.transform {
180+
case t: TempResolvedColumn => t.child
181+
}
182+
151183
val METADATA_COL_ATTR_KEY = "__metadata_col"
152184

153185
/**

sql/core/src/test/resources/sql-tests/analyzer-results/having.sql.out

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,115 @@ Project [k#x, sum(v)#xL]
208208
+- Project [k#x, v#x]
209209
+- SubqueryAlias hav
210210
+- LocalRelation [k#x, v#x]
211+
212+
213+
-- !query
214+
SELECT sum(v) FROM hav HAVING avg(try_add(v, 1)) = 1
215+
-- !query analysis
216+
Project [sum(v)#xL]
217+
+- Filter (avg(try_add(v, 1))#x = cast(1 as double))
218+
+- Aggregate [sum(v#x) AS sum(v)#xL, avg(try_add(tempresolvedcolumn(v#x, v, false), 1)) AS avg(try_add(v, 1))#x]
219+
+- SubqueryAlias hav
220+
+- View (`hav`, [k#x, v#x])
221+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
222+
+- Project [k#x, v#x]
223+
+- SubqueryAlias hav
224+
+- LocalRelation [k#x, v#x]
225+
226+
227+
-- !query
228+
SELECT sum(v) FROM hav HAVING sum(try_add(v, 1)) = 1
229+
-- !query analysis
230+
Project [sum(v)#xL]
231+
+- Filter (sum(try_add(v, 1))#xL = cast(1 as bigint))
232+
+- Aggregate [sum(v#x) AS sum(v)#xL, sum(try_add(tempresolvedcolumn(v#x, v, false), 1)) AS sum(try_add(v, 1))#xL]
233+
+- SubqueryAlias hav
234+
+- View (`hav`, [k#x, v#x])
235+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
236+
+- Project [k#x, v#x]
237+
+- SubqueryAlias hav
238+
+- LocalRelation [k#x, v#x]
239+
240+
241+
-- !query
242+
SELECT sum(v) FROM hav HAVING sum(ifnull(v, 1)) = 1
243+
-- !query analysis
244+
Project [sum(v)#xL]
245+
+- Filter (sum(ifnull(v, 1))#xL = cast(1 as bigint))
246+
+- Aggregate [sum(v#x) AS sum(v)#xL, sum(ifnull(tempresolvedcolumn(v#x, v, false), 1)) AS sum(ifnull(v, 1))#xL]
247+
+- SubqueryAlias hav
248+
+- View (`hav`, [k#x, v#x])
249+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
250+
+- Project [k#x, v#x]
251+
+- SubqueryAlias hav
252+
+- LocalRelation [k#x, v#x]
253+
254+
255+
-- !query
256+
SELECT sum(v) FROM hav GROUP BY ALL HAVING sum(ifnull(v, 1)) = 1
257+
-- !query analysis
258+
Project [sum(v)#xL]
259+
+- Filter (sum(ifnull(v, 1))#xL = cast(1 as bigint))
260+
+- Aggregate [sum(v#x) AS sum(v)#xL, sum(ifnull(tempresolvedcolumn(v#x, v, false), 1)) AS sum(ifnull(v, 1))#xL]
261+
+- SubqueryAlias hav
262+
+- View (`hav`, [k#x, v#x])
263+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
264+
+- Project [k#x, v#x]
265+
+- SubqueryAlias hav
266+
+- LocalRelation [k#x, v#x]
267+
268+
269+
-- !query
270+
SELECT sum(v) FROM hav GROUP BY v HAVING sum(ifnull(v, 1)) = 1
271+
-- !query analysis
272+
Project [sum(v)#xL]
273+
+- Filter (sum(ifnull(v, 1))#xL = cast(1 as bigint))
274+
+- Aggregate [v#x], [sum(v#x) AS sum(v)#xL, sum(ifnull(tempresolvedcolumn(v#x, v, false), 1)) AS sum(ifnull(v, 1))#xL]
275+
+- SubqueryAlias hav
276+
+- View (`hav`, [k#x, v#x])
277+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
278+
+- Project [k#x, v#x]
279+
+- SubqueryAlias hav
280+
+- LocalRelation [k#x, v#x]
281+
282+
283+
-- !query
284+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1)) = 1
285+
-- !query analysis
286+
Project [(v + 1)#x]
287+
+- Filter (avg(try_add(v, 1))#x = cast(1 as double))
288+
+- Aggregate [(v#x + 1)], [(v#x + 1) AS (v + 1)#x, avg(try_add(tempresolvedcolumn(v#x, v, false), 1)) AS avg(try_add(v, 1))#x]
289+
+- SubqueryAlias hav
290+
+- View (`hav`, [k#x, v#x])
291+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
292+
+- Project [k#x, v#x]
293+
+- SubqueryAlias hav
294+
+- LocalRelation [k#x, v#x]
295+
296+
297+
-- !query
298+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1) + 1) = 1
299+
-- !query analysis
300+
Project [(v + 1)#x]
301+
+- Filter (avg((try_add(v, 1) + 1))#x = cast(1 as double))
302+
+- Aggregate [(v#x + 1)], [(v#x + 1) AS (v + 1)#x, avg((try_add(tempresolvedcolumn(v#x, v, false), 1) + 1)) AS avg((try_add(v, 1) + 1))#x]
303+
+- SubqueryAlias hav
304+
+- View (`hav`, [k#x, v#x])
305+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
306+
+- Project [k#x, v#x]
307+
+- SubqueryAlias hav
308+
+- LocalRelation [k#x, v#x]
309+
310+
311+
-- !query
312+
SELECT sum(v) FROM hav GROUP BY ifnull(v, 1) + 1 order by ifnull(v, 1) + 1
313+
-- !query analysis
314+
Project [sum(v)#xL]
315+
+- Sort [(ifnull(v, 1) + 1)#x ASC NULLS FIRST], true
316+
+- Aggregate [(ifnull(v#x, 1) + 1)], [sum(v#x) AS sum(v)#xL, (ifnull(tempresolvedcolumn(v#x, v, false), 1) + 1) AS (ifnull(v, 1) + 1)#x]
317+
+- SubqueryAlias hav
318+
+- View (`hav`, [k#x, v#x])
319+
+- Project [cast(k#x as string) AS k#x, cast(v#x as int) AS v#x]
320+
+- Project [k#x, v#x]
321+
+- SubqueryAlias hav
322+
+- LocalRelation [k#x, v#x]

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2 ORDER BY sum(v);
3939

4040
-- SPARK-28386: Resolve ORDER BY agg function with HAVING clause, while the agg function does not present on SELECT list
4141
SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2 ORDER BY avg(v);
42+
43+
-- SPARK-52385: Remove TempResolvedColumns from InheritAnalysisRules name
44+
SELECT sum(v) FROM hav HAVING avg(try_add(v, 1)) = 1;
45+
SELECT sum(v) FROM hav HAVING sum(try_add(v, 1)) = 1;
46+
SELECT sum(v) FROM hav HAVING sum(ifnull(v, 1)) = 1;
47+
SELECT sum(v) FROM hav GROUP BY ALL HAVING sum(ifnull(v, 1)) = 1;
48+
SELECT sum(v) FROM hav GROUP BY v HAVING sum(ifnull(v, 1)) = 1;
49+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1)) = 1;
50+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1) + 1) = 1;
51+
SELECT sum(v) FROM hav GROUP BY ifnull(v, 1) + 1 order by ifnull(v, 1) + 1;

sql/core/src/test/resources/sql-tests/results/having.sql.out

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,70 @@ struct<k:string,sum(v):bigint>
152152
-- !query output
153153
one 6
154154
three 3
155+
156+
157+
-- !query
158+
SELECT sum(v) FROM hav HAVING avg(try_add(v, 1)) = 1
159+
-- !query schema
160+
struct<sum(v):bigint>
161+
-- !query output
162+
163+
164+
165+
-- !query
166+
SELECT sum(v) FROM hav HAVING sum(try_add(v, 1)) = 1
167+
-- !query schema
168+
struct<sum(v):bigint>
169+
-- !query output
170+
171+
172+
173+
-- !query
174+
SELECT sum(v) FROM hav HAVING sum(ifnull(v, 1)) = 1
175+
-- !query schema
176+
struct<sum(v):bigint>
177+
-- !query output
178+
179+
180+
181+
-- !query
182+
SELECT sum(v) FROM hav GROUP BY ALL HAVING sum(ifnull(v, 1)) = 1
183+
-- !query schema
184+
struct<sum(v):bigint>
185+
-- !query output
186+
187+
188+
189+
-- !query
190+
SELECT sum(v) FROM hav GROUP BY v HAVING sum(ifnull(v, 1)) = 1
191+
-- !query schema
192+
struct<sum(v):bigint>
193+
-- !query output
194+
1
195+
196+
197+
-- !query
198+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1)) = 1
199+
-- !query schema
200+
struct<(v + 1):int>
201+
-- !query output
202+
203+
204+
205+
-- !query
206+
SELECT v + 1 FROM hav GROUP BY ALL HAVING avg(try_add(v, 1) + 1) = 1
207+
-- !query schema
208+
struct<(v + 1):int>
209+
-- !query output
210+
211+
212+
213+
-- !query
214+
SELECT sum(v) FROM hav GROUP BY ifnull(v, 1) + 1 order by ifnull(v, 1) + 1
215+
-- !query schema
216+
struct<sum(v):bigint>
217+
-- !query output
218+
1
219+
2
220+
3
221+
5

0 commit comments

Comments
 (0)