
Commit bf7cde6

fix
1 parent: c7c1021

File tree

2 files changed: +46 −37


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 24 additions & 37 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import java.util.HashSet
-
 import scala.collection.mutable
 
 import org.apache.spark.SparkException
@@ -897,24 +895,6 @@ object LimitPushDown extends Rule[LogicalPlan] {
  */
 object PushProjectionThroughUnion extends Rule[LogicalPlan] {
 
-  /**
-   * When pushing a [[Project]] through [[Union]] we need to maintain the invariant that [[Union]]
-   * children must have unique [[ExprId]]s per branch. We can safely deduplicate [[ExprId]]s
-   * without updating any references because those [[ExprId]]s will simply remain unused.
-   * For example, in a `Project(col1#1, col#1)` we will alias the second `col1` and get
-   * `Project(col1#1, col1 as col1#2)`. We don't need to update any references to `col1#1` we
-   * aliased because `col1#1` still exists in [[Project]] output.
-   */
-  private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
-    val existingExprIds = new HashSet[ExprId]
-    projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
-      Alias(attr, attr.name)()
-    } else {
-      existingExprIds.add(attr.exprId)
-      attr
-    })
-  }
-
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
   */
@@ -942,16 +922,20 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
     result.asInstanceOf[A]
   }
 
+  /**
+   * If [[SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED]] is true, [[Project]] can
+   * only be pushed down if there are no duplicate [[ExprId]]s in the project list.
+   */
+  def canPushProjectionThroughUnion(project: Project): Boolean = {
+    !conf.unionIsResolvedWhenDuplicatesPerChildResolved ||
+      project.outputSet.size == project.projectList.size
+  }
+
   def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
-    val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
-      deduplicateProjectList(projectList)
-    } else {
-      projectList
-    }
-    val newFirstChild = Project(deduplicatedProjectList, u.children.head)
+    val newFirstChild = Project(projectList, u.children.head)
     val newOtherChildren = u.children.tail.map { child =>
       val rewrites = buildRewrites(u.children.head, child)
-      Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
+      Project(projectList.map(pushToRight(_, rewrites)), child)
     }
     newFirstChild +: newOtherChildren
   }
@@ -960,8 +944,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
       _.containsAllPatterns(UNION, PROJECT)) {
 
     // Push down deterministic projection through UNION ALL
-    case Project(projectList, u: Union)
-        if projectList.forall(_.deterministic) && u.children.nonEmpty =>
+    case project @ Project(projectList, u: Union)
+        if projectList.forall(_.deterministic) && u.children.nonEmpty &&
+          canPushProjectionThroughUnion(project) =>
      u.copy(children = pushProjectionThroughUnion(projectList, u))
   }
 }
@@ -1586,7 +1571,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
  */
 object CombineUnions extends Rule[LogicalPlan] {
   import CollapseProject.{buildCleanedProjectList, canCollapseExpressions}
-  import PushProjectionThroughUnion.pushProjectionThroughUnion
+  import PushProjectionThroughUnion.{canPushProjectionThroughUnion, pushProjectionThroughUnion}
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
     _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
@@ -1631,17 +1616,19 @@ object CombineUnions extends Rule[LogicalPlan] {
         stack.pushAll(children.reverse)
       // Push down projection through Union and then push pushed plan to Stack if
       // there is a Project.
-      case Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
+      case project @ Project(projectList, Distinct(u @ Union(children, byName, allowMissingCol)))
          if projectList.forall(_.deterministic) && children.nonEmpty &&
-          flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
+            flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol &&
+            canPushProjectionThroughUnion(project) =>
         stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
-      case Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
+      case project @ Project(projectList, Deduplicate(keys: Seq[Attribute], u: Union))
          if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
-          u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet =>
+            u.allowMissingCol == topAllowMissingCol && AttributeSet(keys) == u.outputSet &&
+            canPushProjectionThroughUnion(project) =>
         stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
-      case Project(projectList, u @ Union(children, byName, allowMissingCol))
-          if projectList.forall(_.deterministic) && children.nonEmpty &&
-          byName == topByName && allowMissingCol == topAllowMissingCol =>
+      case project @ Project(projectList, u @ Union(children, byName, allowMissingCol))
+          if projectList.forall(_.deterministic) && children.nonEmpty && byName == topByName &&
+            allowMissingCol == topAllowMissingCol && canPushProjectionThroughUnion(project) =>
         stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
      case child =>
        flattened += child
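
Note on the new gate: `project.outputSet` is an `AttributeSet`, which is keyed by `ExprId`, so if the project list repeats an attribute the set comes out strictly smaller than the list. Below is a minimal self-contained sketch of that size-comparison idea; `ExprId` and `Attr` here are simplified stand-ins, not Catalyst's real `NamedExpression`/`AttributeSet` classes.

// Toy model of the duplicate-ExprId check in canPushProjectionThroughUnion.
object DuplicateExprIdCheckSketch {
  case class ExprId(id: Long)
  case class Attr(name: String, exprId: ExprId)

  // Mirrors `project.outputSet.size == project.projectList.size`:
  // a set keyed by ExprId shrinks whenever the list repeats an ID.
  def hasNoDuplicateExprIds(projectList: Seq[Attr]): Boolean =
    projectList.map(_.exprId).toSet.size == projectList.size

  def main(args: Array[String]): Unit = {
    val a = Attr("a", ExprId(1))
    // SELECT a, a repeats ExprId 1, so pushdown would be skipped.
    assert(!hasNoDuplicateExprIds(Seq(a, a)))
    // Distinct ExprIds keep the old pushdown behavior.
    assert(hasNoDuplicateExprIds(Seq(a, Attr("b", ExprId(2)))))
    println("ok")
  }
}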

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala

Lines changed: 22 additions & 0 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, GreaterThan, GreaterThanO
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DecimalType}
 
 class SetOperationSuite extends PlanTest {
@@ -313,6 +314,27 @@ class SetOperationSuite extends PlanTest {
     comparePlans(unionOptimized, unionCorrectAnswer)
   }
 
+  test("SPARK-52686: no pushdown if project has duplicate expression IDs") {
+    val unionQuery = testUnion.select($"a", $"a")
+    val unionCorrectAnswerWithConfOn = unionQuery.analyze
+    val unionCorrectAnswerWithConfOff = Union(
+      testRelation.select($"a", $"a").analyze ::
+      testRelation2.select($"d", $"d").analyze ::
+      testRelation3.select($"g", $"g").analyze ::
+      Nil
+    )
+
+    withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> "true") {
+      val unionOptimized = Optimize.execute(unionQuery.analyze)
+      comparePlans(unionOptimized, unionCorrectAnswerWithConfOn)
+    }
+
+    withSQLConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> "false") {
+      val unionOptimized = Optimize.execute(unionQuery.analyze)
+      comparePlans(unionOptimized, unionCorrectAnswerWithConfOff)
+    }
+  }
+
  test("CombineUnions only flatten the unions with same byName and allowMissingCol") {
    val union1 = Union(testRelation :: testRelation :: Nil, true, false)
    val union2 = Union(testRelation :: testRelation :: Nil, true, true)
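
The test pins down both behaviors: with the config on, the optimizer leaves `Project(a, a)` on top of the `Union` untouched; with it off, the projection is still pushed into every branch as before. A toy sketch of that decision follows, with a simplified plan ADT standing in for Catalyst's `LogicalPlan` (the real rule also rewrites attributes per branch via `buildRewrites`/`pushToRight`, omitted here).

// Toy model of the gated Project-through-Union pushdown; hypothetical
// Plan types, not Catalyst's API.
object PushThroughUnionSketch {
  case class ExprId(id: Long)
  case class Attr(name: String, exprId: ExprId)

  sealed trait Plan
  case class Relation(name: String, output: Seq[Attr]) extends Plan
  case class Project(projectList: Seq[Attr], child: Plan) extends Plan
  case class Union(children: Seq[Plan]) extends Plan

  // gateOnDuplicates plays the role of the SQLConf flag: when true,
  // duplicate ExprIds in the project list block the pushdown entirely.
  def optimize(plan: Plan, gateOnDuplicates: Boolean): Plan = plan match {
    case Project(list, Union(children))
        if !gateOnDuplicates || list.map(_.exprId).toSet.size == list.size =>
      // Push the projection into every Union branch (per-branch rewrites omitted).
      Union(children.map(Project(list, _)))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val a = Attr("a", ExprId(1))
    val u = Union(Seq(Relation("t1", Seq(a)), Relation("t2", Seq(a))))
    val query = Project(Seq(a, a), u)
    // Gate on: the plan is left untouched, as in unionCorrectAnswerWithConfOn.
    assert(optimize(query, gateOnDuplicates = true) == query)
    // Gate off: the projection lands in each child, as in ...WithConfOff.
    assert(optimize(query, gateOnDuplicates = false) ==
      Union(Seq(Project(Seq(a, a), Relation("t1", Seq(a))),
        Project(Seq(a, a), Relation("t2", Seq(a))))))
    println("ok")
  }
}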
