Skip to content

Commit b078fcc

Browse files
pavle-martinovic_data authored and cloud-fan committed
[SPARK-52232][SQL] Fix non-deterministic queries to produce different results at every step
### What changes were proposed in this pull request? Enable non-deterministic queries to work with rCTEs. Fix a bug where non-deterministic queries produce the same result in every iteration. ### Why are the changes needed? Currently, recursive CTEs create a new plan for each iteration of the recursion, so the expressions that use randomness "roll back" to their initial value, causing things like "rand()" to produce the same result every time. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New tests in golden files. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#50957 from Pajaraja/pavle-martinovic_data/RandomInrCTEs. Lead-authored-by: pavle-martinovic_data <pavle.martinovic@databricks.com> Co-authored-by: Wenchen Fan <cloud0fan@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent a0f7865 commit b078fcc

File tree

7 files changed

+269
-3
lines changed

7 files changed

+269
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,9 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U
12671267

12681268
override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
12691269

1270+
override def withShiftedSeed(shift: Long): Shuffle =
1271+
copy(randomSeed = randomSeed.map(_ + shift))
1272+
12701273
override lazy val resolved: Boolean =
12711274
childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined
12721275

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non
260260

261261
override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
262262

263+
override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ + shift))
264+
263265
override lazy val resolved: Boolean = randomSeed.isDefined
264266

265267
override def nullable: Boolean = false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression {
7676

7777
def seedExpression: Expression
7878
def withNewSeed(seed: Long): Expression
79+
def withShiftedSeed(shift: Long): Expression
7980
}
8081

8182
private[catalyst] object ExpressionWithRandomSeed {
@@ -114,6 +115,9 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi
114115

115116
override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed)
116117

118+
override def withShiftedSeed(shift: Long): Rand =
119+
Rand(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
120+
117121
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
118122

119123
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -165,6 +169,9 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm
165169

166170
override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed)
167171

172+
override def withShiftedSeed(shift: Long): Randn =
173+
Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
174+
168175
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
169176

170177
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -268,6 +275,9 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression,
268275
override def withNewSeed(newSeed: Long): Expression =
269276
Uniform(min, max, Literal(newSeed, LongType), hideSeed)
270277

278+
override def withShiftedSeed(shift: Long): Expression =
279+
Uniform(min, max, Literal(seed + shift, LongType), hideSeed)
280+
271281
override def withNewChildrenInternal(
272282
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
273283
Uniform(newFirst, newSecond, newThird, hideSeed)
@@ -348,6 +358,10 @@ case class RandStr(
348358

349359
override def withNewSeed(newSeed: Long): Expression =
350360
RandStr(length, Literal(newSeed, LongType), hideSeed)
361+
362+
override def withShiftedSeed(shift: Long): Expression =
363+
RandStr(length, Literal(seed + shift, LongType), hideSeed)
364+
351365
override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression =
352366
RandStr(newFirst, newSecond, hideSeed)
353367

sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import scala.collection.mutable
2222
import org.apache.spark.SparkException
2323
import org.apache.spark.rdd.{EmptyRDD, RDD}
2424
import org.apache.spark.sql.catalyst.InternalRow
25-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, InterpretedMutableProjection, Literal}
25+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionWithRandomSeed, InterpretedMutableProjection, Literal}
2626
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr
2727
import org.apache.spark.sql.catalyst.plans.QueryPlan
2828
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef}
@@ -183,11 +183,24 @@ case class UnionLoopExec(
183183
// Main loop for obtaining the result of the recursive query.
184184
while (prevCount > 0 && !limitReached) {
185185
var prevPlan: LogicalPlan = null
186+
187+
// If the recursive part contains non-deterministic expressions that depend on a seed, we
188+
// need to create a new seed since the seed for this expression is set in the analysis, and
189+
// we avoid re-triggering the analysis for every iterative step.
190+
val recursionReseeded = if (currentLevel == 1 || recursion.deterministic) {
191+
recursion
192+
} else {
193+
recursion.transformAllExpressionsWithSubqueries {
194+
case e: ExpressionWithRandomSeed =>
195+
e.withShiftedSeed(currentLevel - 1)
196+
}
197+
}
198+
186199
// the current plan is created by substituting UnionLoopRef node with the project node of
187200
// the previous plan.
188201
// This way we support only UNION ALL case. Additional case should be added for UNION case.
189202
// One way of supporting UNION case can be seen at SPARK-24497 PR from Peter Toth.
190-
val newRecursion = recursion.transformWithSubqueries {
203+
val newRecursion = recursionReseeded.transformWithSubqueries {
191204
case r: UnionLoopRef if r.loopId == loopId =>
192205
prevDF.queryExecution.optimizedPlan match {
193206
case l: LocalRelation =>

sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,3 +1629,75 @@ WithCTE
16291629
+- Project [n#x]
16301630
+- SubqueryAlias t1
16311631
+- CTERelationRef xxxx, true, [n#x], false, false
1632+
1633+
1634+
-- !query
1635+
WITH RECURSIVE randoms(val) AS (
1636+
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
1637+
UNION ALL
1638+
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
1639+
FROM randoms
1640+
)
1641+
SELECT val FROM randoms LIMIT 5
1642+
-- !query analysis
1643+
[Analyzer test output redacted due to nondeterminism]
1644+
1645+
1646+
-- !query
1647+
WITH RECURSIVE randoms(val) AS (
1648+
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
1649+
UNION ALL
1650+
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
1651+
FROM randoms
1652+
)
1653+
SELECT val FROM randoms LIMIT 5
1654+
-- !query analysis
1655+
[Analyzer test output redacted due to nondeterminism]
1656+
1657+
1658+
-- !query
1659+
WITH RECURSIVE randoms(val) AS (
1660+
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
1661+
UNION ALL
1662+
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
1663+
FROM randoms
1664+
)
1665+
SELECT val FROM randoms LIMIT 5
1666+
-- !query analysis
1667+
[Analyzer test output redacted due to nondeterminism]
1668+
1669+
1670+
-- !query
1671+
WITH RECURSIVE randoms(val) AS (
1672+
SELECT randstr(10, 82374)
1673+
UNION ALL
1674+
SELECT randstr(10, 237685)
1675+
FROM randoms
1676+
)
1677+
SELECT val FROM randoms LIMIT 5
1678+
-- !query analysis
1679+
[Analyzer test output redacted due to nondeterminism]
1680+
1681+
1682+
-- !query
1683+
WITH RECURSIVE randoms(val) AS (
1684+
SELECT UUID(82374)
1685+
UNION ALL
1686+
SELECT UUID(237685)
1687+
FROM randoms
1688+
)
1689+
SELECT val FROM randoms LIMIT 5
1690+
-- !query analysis
1691+
[Analyzer test output redacted due to nondeterminism]
1692+
1693+
1694+
-- !query
1695+
WITH RECURSIVE randoms(val) AS (
1696+
SELECT ARRAY(1,2,3,4,5)
1697+
UNION ALL
1698+
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
1699+
FROM randoms
1700+
)
1701+
SELECT val FROM randoms LIMIT 5
1702+
-- !query analysis
1703+
[Analyzer test output redacted due to nondeterminism]

sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,4 +586,58 @@ WITH RECURSIVE t1 AS (
586586
SELECT 1 AS n
587587
UNION ALL
588588
SELECT n+1 FROM t2 WHERE n < 5)
589-
SELECT * FROM t1;
589+
SELECT * FROM t1;
590+
591+
-- Non-deterministic query with rand with seed
592+
WITH RECURSIVE randoms(val) AS (
593+
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
594+
UNION ALL
595+
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
596+
FROM randoms
597+
)
598+
SELECT val FROM randoms LIMIT 5;
599+
600+
-- Non-deterministic query with uniform with seed
601+
WITH RECURSIVE randoms(val) AS (
602+
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
603+
UNION ALL
604+
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
605+
FROM randoms
606+
)
607+
SELECT val FROM randoms LIMIT 5;
608+
609+
-- Non-deterministic query with randn with seed
610+
WITH RECURSIVE randoms(val) AS (
611+
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
612+
UNION ALL
613+
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
614+
FROM randoms
615+
)
616+
SELECT val FROM randoms LIMIT 5;
617+
618+
-- Non-deterministic query with randstr
619+
WITH RECURSIVE randoms(val) AS (
620+
SELECT randstr(10, 82374)
621+
UNION ALL
622+
SELECT randstr(10, 237685)
623+
FROM randoms
624+
)
625+
SELECT val FROM randoms LIMIT 5;
626+
627+
-- Non-deterministic query with UUID
628+
WITH RECURSIVE randoms(val) AS (
629+
SELECT UUID(82374)
630+
UNION ALL
631+
SELECT UUID(237685)
632+
FROM randoms
633+
)
634+
SELECT val FROM randoms LIMIT 5;
635+
636+
-- Non-deterministic query with shuffle
637+
WITH RECURSIVE randoms(val) AS (
638+
SELECT ARRAY(1,2,3,4,5)
639+
UNION ALL
640+
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
641+
FROM randoms
642+
)
643+
SELECT val FROM randoms LIMIT 5;

sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,3 +1475,111 @@ struct<n:int>
14751475
3
14761476
4
14771477
5
1478+
1479+
1480+
-- !query
1481+
WITH RECURSIVE randoms(val) AS (
1482+
SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
1483+
UNION ALL
1484+
SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
1485+
FROM randoms
1486+
)
1487+
SELECT val FROM randoms LIMIT 5
1488+
-- !query schema
1489+
struct<val:int>
1490+
-- !query output
1491+
1
1492+
3
1493+
4
1494+
4
1495+
5
1496+
1497+
1498+
-- !query
1499+
WITH RECURSIVE randoms(val) AS (
1500+
SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
1501+
UNION ALL
1502+
SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
1503+
FROM randoms
1504+
)
1505+
SELECT val FROM randoms LIMIT 5
1506+
-- !query schema
1507+
struct<val:int>
1508+
-- !query output
1509+
1
1510+
3
1511+
4
1512+
4
1513+
5
1514+
1515+
1516+
-- !query
1517+
WITH RECURSIVE randoms(val) AS (
1518+
SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
1519+
UNION ALL
1520+
SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
1521+
FROM randoms
1522+
)
1523+
SELECT val FROM randoms LIMIT 5
1524+
-- !query schema
1525+
struct<val:int>
1526+
-- !query output
1527+
-2
1528+
2
1529+
2
1530+
5
1531+
6
1532+
1533+
1534+
-- !query
1535+
WITH RECURSIVE randoms(val) AS (
1536+
SELECT randstr(10, 82374)
1537+
UNION ALL
1538+
SELECT randstr(10, 237685)
1539+
FROM randoms
1540+
)
1541+
SELECT val FROM randoms LIMIT 5
1542+
-- !query schema
1543+
struct<val:string>
1544+
-- !query output
1545+
IpXzdTW03I
1546+
Zj7uI2Ex6e
1547+
dBlWnfo7rO
1548+
fmfDBMf60f
1549+
kFeBV7dQWi
1550+
1551+
1552+
-- !query
1553+
WITH RECURSIVE randoms(val) AS (
1554+
SELECT UUID(82374)
1555+
UNION ALL
1556+
SELECT UUID(237685)
1557+
FROM randoms
1558+
)
1559+
SELECT val FROM randoms LIMIT 5
1560+
-- !query schema
1561+
struct<val:string>
1562+
-- !query output
1563+
19974dca-21f6-47ef-b58c-73908ab52aa0
1564+
4ea190e3-c088-4ddd-a545-fb431059ae3c
1565+
8b88900e-f862-468c-8d3b-828188116155
1566+
be4f5346-1c7f-4697-8a2c-1343347872c5
1567+
d0032efe-ae60-461b-8582-f6a7c649f238
1568+
1569+
1570+
-- !query
1571+
WITH RECURSIVE randoms(val) AS (
1572+
SELECT ARRAY(1,2,3,4,5)
1573+
UNION ALL
1574+
SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
1575+
FROM randoms
1576+
)
1577+
SELECT val FROM randoms LIMIT 5
1578+
-- !query schema
1579+
struct<val:array<int>>
1580+
-- !query output
1581+
[1,2,3,4,5]
1582+
[1,2,3,5,4]
1583+
[2,1,5,3,4]
1584+
[4,3,2,5,1]
1585+
[4,5,1,2,3]

0 commit comments

Comments
 (0)