17
17
18
18
package org .apache .spark .sql .catalyst .optimizer
19
19
20
- import java .util .HashSet
21
-
22
20
import scala .collection .mutable
23
21
24
22
import org .apache .spark .SparkException
@@ -897,24 +895,6 @@ object LimitPushDown extends Rule[LogicalPlan] {
897
895
*/
898
896
object PushProjectionThroughUnion extends Rule [LogicalPlan ] {
899
897
900
- /**
901
- * When pushing a [[Project ]] through [[Union ]] we need to maintain the invariant that [[Union ]]
902
- * children must have unique [[ExprId ]]s per branch. We can safely deduplicate [[ExprId ]]s
903
- * without updating any references because those [[ExprId ]]s will simply remain unused.
904
- * For example, in a `Project(col1#1, col#1)` we will alias the second `col1` and get
905
- * `Project(col1#1, col1 as col1#2)`. We don't need to update any references to `col1#1` we
906
- * aliased because `col1#1` still exists in [[Project ]] output.
907
- */
908
- private def deduplicateProjectList (projectList : Seq [NamedExpression ]) = {
909
- val existingExprIds = new HashSet [ExprId ]
910
- projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
911
- Alias (attr, attr.name)()
912
- } else {
913
- existingExprIds.add(attr.exprId)
914
- attr
915
- })
916
- }
917
-
918
898
/**
919
899
* Maps Attributes from the left side to the corresponding Attribute on the right side.
920
900
*/
@@ -942,16 +922,20 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
942
922
result.asInstanceOf [A ]
943
923
}
944
924
925
+ /**
926
+ * If [[SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED ]] is true, [[Project ]] can
927
+ * only be pushed down if there are no duplicate [[ExprId ]]s in the project list.
928
+ */
929
+ def canPushProjectionThroughUnion (project : Project ): Boolean = {
930
+ ! conf.unionIsResolvedWhenDuplicatesPerChildResolved ||
931
+ project.outputSet.size == project.projectList.size
932
+ }
933
+
945
934
def pushProjectionThroughUnion (projectList : Seq [NamedExpression ], u : Union ): Seq [LogicalPlan ] = {
946
- val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
947
- deduplicateProjectList(projectList)
948
- } else {
949
- projectList
950
- }
951
- val newFirstChild = Project (deduplicatedProjectList, u.children.head)
935
+ val newFirstChild = Project (projectList, u.children.head)
952
936
val newOtherChildren = u.children.tail.map { child =>
953
937
val rewrites = buildRewrites(u.children.head, child)
954
- Project (deduplicatedProjectList .map(pushToRight(_, rewrites)), child)
938
+ Project (projectList .map(pushToRight(_, rewrites)), child)
955
939
}
956
940
newFirstChild +: newOtherChildren
957
941
}
@@ -960,8 +944,9 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
960
944
_.containsAllPatterns(UNION , PROJECT )) {
961
945
962
946
// Push down deterministic projection through UNION ALL
963
- case Project (projectList, u : Union )
964
- if projectList.forall(_.deterministic) && u.children.nonEmpty =>
947
+ case project @ Project (projectList, u : Union )
948
+ if projectList.forall(_.deterministic) && u.children.nonEmpty &&
949
+ canPushProjectionThroughUnion(project) =>
965
950
u.copy(children = pushProjectionThroughUnion(projectList, u))
966
951
}
967
952
}
@@ -1586,7 +1571,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
1586
1571
*/
1587
1572
object CombineUnions extends Rule [LogicalPlan ] {
1588
1573
import CollapseProject .{buildCleanedProjectList , canCollapseExpressions }
1589
- import PushProjectionThroughUnion .pushProjectionThroughUnion
1574
+ import PushProjectionThroughUnion .{ canPushProjectionThroughUnion , pushProjectionThroughUnion }
1590
1575
1591
1576
def apply (plan : LogicalPlan ): LogicalPlan = plan.transformDownWithPruning(
1592
1577
_.containsAnyPattern(UNION , DISTINCT_LIKE ), ruleId) {
@@ -1631,17 +1616,19 @@ object CombineUnions extends Rule[LogicalPlan] {
1631
1616
stack.pushAll(children.reverse)
1632
1617
// Push down projection through Union and then push pushed plan to Stack if
1633
1618
// there is a Project.
1634
- case Project (projectList, Distinct (u @ Union (children, byName, allowMissingCol)))
1619
+ case project @ Project (projectList, Distinct (u @ Union (children, byName, allowMissingCol)))
1635
1620
if projectList.forall(_.deterministic) && children.nonEmpty &&
1636
- flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
1621
+ flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol &&
1622
+ canPushProjectionThroughUnion(project) =>
1637
1623
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
1638
- case Project (projectList, Deduplicate (keys : Seq [Attribute ], u : Union ))
1624
+ case project @ Project (projectList, Deduplicate (keys : Seq [Attribute ], u : Union ))
1639
1625
if projectList.forall(_.deterministic) && flattenDistinct && u.byName == topByName &&
1640
- u.allowMissingCol == topAllowMissingCol && AttributeSet (keys) == u.outputSet =>
1626
+ u.allowMissingCol == topAllowMissingCol && AttributeSet (keys) == u.outputSet &&
1627
+ canPushProjectionThroughUnion(project) =>
1641
1628
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
1642
- case Project (projectList, u @ Union (children, byName, allowMissingCol))
1643
- if projectList.forall(_.deterministic) && children.nonEmpty &&
1644
- byName == topByName && allowMissingCol == topAllowMissingCol =>
1629
+ case project @ Project (projectList, u @ Union (children, byName, allowMissingCol))
1630
+ if projectList.forall(_.deterministic) && children.nonEmpty && byName == topByName &&
1631
+ allowMissingCol == topAllowMissingCol && canPushProjectionThroughUnion(project) =>
1645
1632
stack.pushAll(pushProjectionThroughUnion(projectList, u).reverse)
1646
1633
case child =>
1647
1634
flattened += child
0 commit comments