
Commit 255249f

pavle-martinovic_data authored and yhuang-db committed
[SPARK-52353][SQL] Fix bug with wrong constraints in LogicalRDDs referencing previous iterations in UnionLoop
### What changes were proposed in this pull request?

Modify the way that we write statistics and constraints in LogicalRDDs that refer to previous iterations in UnionLoopExec.

### Why are the changes needed?

LogicalRDD constraints are currently written incorrectly when the recursion has multiple columns using the same name. This leads to incorrectly pruning out filters, which can lead to infinite recursion.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New golden file test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51070 from Pajaraja/pavle-martinovic_data/ConstraintsFixII.

Authored-by: pavle-martinovic_data <pavle.martinovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 1ea7f11 commit 255249f
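To make the failure mode concrete, here is a minimal sketch in plain Scala, assuming a hypothetical `Attr` stand-in for Catalyst's `Attribute` (with `exprId` modeling its unique expression id); it is not Spark's API, just an illustration of why a name-keyed association breaks once two output columns share a name:

```scala
// Hypothetical stand-in for Catalyst's Attribute; exprId models its unique id.
case class Attr(name: String, exprId: Int)

object NameCollisionSketch extends App {
  // The previous iteration exposes two columns both named "x" (e.g. SELECT x, x ...);
  // the next iteration refers to them as x and y.
  val prevOutput = Seq(Attr("x", 1), Attr("x", 2))
  val newOutput  = Seq(Attr("x", 3), Attr("y", 4))

  // Keying by name collapses both "x" columns into one entry, so a constraint
  // recorded on one of them can be silently reattached to the wrong column.
  val byName = prevOutput.map(a => a.name -> a).toMap
  assert(byName.size == 1) // one column's identity is lost

  // Pairing by position keeps the columns distinct, which is the direction of this fix.
  val byPosition = prevOutput.zip(newOutput).toMap
  assert(byPosition.size == 2)
}
```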

File tree: 5 files changed, +125 −5 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 28 additions & 2 deletions
@@ -229,10 +229,36 @@ object LogicalRDD extends Logging {
     }
   }
 
+  // A version of buildOutputAssocForRewrite which doesn't assume that the names are the same,
+  // because the new output can have different names. Used when copying the LogicalRDD with a new
+  // output.
+  private[sql] def buildOutputAssocForRewriteWithNewOutput(
+      source: Seq[Attribute],
+      destination: Seq[Attribute]): Option[Map[Attribute, Attribute]] = {
+    val rewrite = source.zip(destination).flatMap { case (attr1, attr2) =>
+      if (attr1.dataType == attr2.dataType) {
+        Some(attr1 -> attr2)
+      } else {
+        None
+      }
+    }.toMap
+
+    if (rewrite.size == source.size) {
+      Some(rewrite)
+    } else {
+      None
+    }
+  }
+
   private[sql] def rewriteStatsAndConstraints(
       logicalPlan: LogicalPlan,
-      optimizedPlan: LogicalPlan): (Option[Statistics], Option[ExpressionSet]) = {
-    val rewrite = buildOutputAssocForRewrite(optimizedPlan.output, logicalPlan.output)
+      optimizedPlan: LogicalPlan,
+      sameOutput: Boolean = true): (Option[Statistics], Option[ExpressionSet]) = {
+    val rewrite = if (sameOutput) {
+      buildOutputAssocForRewrite(optimizedPlan.output, logicalPlan.output)
+    } else {
+      buildOutputAssocForRewriteWithNewOutput(optimizedPlan.output, logicalPlan.output)
+    }
 
     rewrite.map { rw =>
       val rewrittenStatistics = rewriteStatistics(optimizedPlan.stats, rw)
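For readers skimming the diff, here is a minimal, self-contained sketch of the positional association the new helper performs. `Attribute`, `assocByPosition`, and the sample data below are illustrative stand-ins, not the Catalyst API; the all-or-nothing size check mirrors the `rewrite.size == source.size` guard in the real code:

```scala
// Simplified stand-in for Catalyst's Attribute.
case class Attribute(name: String, exprId: Int, dataType: String)

object PositionalAssocSketch extends App {
  def assocByPosition(
      source: Seq[Attribute],
      destination: Seq[Attribute]): Option[Map[Attribute, Attribute]] = {
    // Pair columns by position and keep only type-compatible pairs.
    val rewrite = source.zip(destination).collect {
      case (a, b) if a.dataType == b.dataType => a -> b
    }.toMap
    // All-or-nothing: a partial map would rewrite some constraints and drop others.
    // (zip also truncates on a length mismatch, which this check catches too.)
    if (rewrite.size == source.size) Some(rewrite) else None
  }

  // Duplicate names no longer collide because pairing ignores names entirely.
  val src = Seq(Attribute("x", 1, "int"), Attribute("x", 2, "int"))
  val dst = Seq(Attribute("x", 3, "int"), Attribute("y", 4, "int"))
  assert(assocByPosition(src, dst).exists(_.size == 2))
}
```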

sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ case class UnionLoopExec(
     "numIterations" -> SQLMetrics.createMetric(sparkContext, "number of recursive iterations"),
     "numAnchorOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of anchor output rows"))
 
-  val localRelationLimit =
+  private val localRelationLimit =
     conf.getConf(SQLConf.CTE_RECURSION_ANCHOR_ROWS_LIMIT_TO_CONVERT_TO_LOCAL_RELATION)
 
   /**
@@ -220,9 +220,9 @@ case class UnionLoopExec(
        val logicalRDD = LogicalRDD.fromDataset(prevDF.queryExecution.toRdd, prevDF,
          prevDF.isStreaming).newInstance()
        prevPlan = logicalRDD
-       val logicalPlan = prevDF.logicalPlan
        val optimizedPlan = prevDF.queryExecution.optimizedPlan
-       val (stats, constraints) = rewriteStatsAndConstraints(logicalPlan, optimizedPlan)
+       val (stats, constraints) = rewriteStatsAndConstraints(r, optimizedPlan,
+         sameOutput = false)
        logicalRDD.copy(output = r.output)(prevDF.sparkSession, stats, constraints)
      }
    }
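The call site now rewrites constraints against the `UnionLoopRef`'s output (`r`) with `sameOutput = false`, since the new output's names can differ from the optimized plan's. To see why the old path was dangerous for termination, here is a loose illustrative model of constraint-based filter pruning; `Constraint` and `prune` are hypothetical names, not Spark's optimizer API:

```scala
// Hypothetical model of constraint-based filter pruning, not Spark's implementation.
case class Constraint(column: String, op: String, bound: Int)

object PruneSketch extends App {
  // A filter already implied by the known constraints is dropped as redundant.
  def prune(filters: Seq[Constraint], known: Set[Constraint]): Seq[Constraint] =
    filters.filterNot(known.contains)

  // Bug scenario: the previous iteration's constraint really concerned a different
  // column that happened to share the name, but got recorded against `x`.
  val wronglyKnown = Set(Constraint("x", "<", 5))

  // The recursion's termination guard WHERE x < 5 now looks redundant and vanishes,
  // so each iteration keeps producing rows and the loop never stops.
  assert(prune(Seq(Constraint("x", "<", 5)), wronglyKnown).isEmpty)
}
```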

sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out

Lines changed: 45 additions & 0 deletions
@@ -1664,6 +1664,51 @@ WithCTE
       +- CTERelationRef xxxx, true, [x#x, y#x], false, false
 
 
+-- !query
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=0
+-- !query analysis
+SetCommand (spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation,Some(0))
+
+
+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias tmp
+:     +- Project [col1#x AS x#x]
+:        +- LocalRelation [col1#x]
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias rcte
+:     +- Project [x#x AS x#x, x#x AS y#x]
+:        +- UnionLoop xxxx
+:           :- Project [x#x, x#x]
+:           :  +- Filter (x#x = 1)
+:           :     +- SubqueryAlias tmp
+:           :        +- CTERelationRef xxxx, true, [x#x], false, false, 5
+:           +- Project [(x#x + 1) AS (x + 1)#x, x#x]
+:              +- Filter (x#x < 5)
+:                 +- SubqueryAlias rcte
+:                    +- Project [x#x AS x#x, x#x AS y#x]
+:                       +- UnionLoopRef xxxx, [x#x, x#x], false
++- Project [x#x, y#x]
+   +- SubqueryAlias rcte
+      +- CTERelationRef xxxx, true, [x#x, y#x], false, false
+
+
+-- !query
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=100
+-- !query analysis
+SetCommand (spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation,Some(100))
+
+
 -- !query
 WITH RECURSIVE tmp(x) AS (
   values (1), (2), (3), (4), (5)

sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql

Lines changed: 14 additions & 0 deletions
@@ -598,6 +598,20 @@ WITH RECURSIVE tmp(x) AS (
 )
 SELECT * FROM rcte;
 
+-- Previous query without converting to local relations
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=0;
+
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte;
+
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=100;
+
 -- Recursive CTE with multiple of the same reference in the anchor, which get referenced as different variables in subsequent iterations.
 WITH RECURSIVE tmp(x) AS (
   values (1), (2), (3), (4), (5)

sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out

Lines changed: 35 additions & 0 deletions
@@ -1496,6 +1496,41 @@ struct<x:int,y:int>
 5	4
 
 
+-- !query
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=0
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation	0
+
+
+-- !query
+WITH RECURSIVE tmp(x) AS (
+  values (1), (2), (3), (4), (5)
+), rcte(x, y) AS (
+  SELECT x, x FROM tmp WHERE x = 1
+  UNION ALL
+  SELECT x + 1, x FROM rcte WHERE x < 5
+)
+SELECT * FROM rcte
+-- !query schema
+struct<x:int,y:int>
+-- !query output
+1	1
+2	1
+3	2
+4	3
+5	4
+
+
+-- !query
+SET spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation=100
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation	100
+
+
 -- !query
 WITH RECURSIVE tmp(x) AS (
   values (1), (2), (3), (4), (5)