Skip to content

Commit f0876d6

Browse files
committed
fix
1 parent ff93cbb commit f0876d6

File tree

8 files changed

+54
-9
lines changed

8 files changed

+54
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
5959
case e @ Except(left, right, _) if !e.duplicateResolved && noMissingInput(right) =>
6060
e.copy(right = dedupRight(left, right))
6161
      // Deduplicate Union children only after by-name resolution for the Union has finished
62-
case u: Union if !u.byName && !u.duplicateResolved =>
62+
case u: Union if !u.byName && !u.duplicatesResolvedBetweenBranches =>
6363
val unionWithChildOutputsDeduplicated =
6464
DeduplicateUnionChildOutput.deduplicateOutputPerChild(u)
6565
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import java.util.HashSet
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.spark.SparkException
@@ -895,6 +897,24 @@ object LimitPushDown extends Rule[LogicalPlan] {
895897
*/
896898
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
897899

900+
/**
901+
* When pushing a [[Project]] through [[Union]] we need to maintain the invariant that
902+
   * no [[ExprId]] appears more than once in the output of any single [[Union]] child.
903+
*/
904+
private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
905+
val existingExprIds = new HashSet[ExprId]
906+
projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
907+
val newMetadata = new MetadataBuilder()
908+
.withMetadata(attr.metadata)
909+
.putNull("__is_duplicate")
910+
.build()
911+
Alias(attr, attr.name)(explicitMetadata = Some(newMetadata))
912+
} else {
913+
existingExprIds.add(attr.exprId)
914+
attr
915+
})
916+
}
917+
898918
/**
899919
* Maps Attributes from the left side to the corresponding Attribute on the right side.
900920
*/
@@ -923,10 +943,11 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
923943
}
924944

925945
def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
926-
val newFirstChild = Project(projectList, u.children.head)
946+
val deduplicatedProjectList = deduplicateProjectList(projectList)
947+
val newFirstChild = Project(deduplicatedProjectList, u.children.head)
927948
val newOtherChildren = u.children.tail.map { child =>
928949
val rewrites = buildRewrites(u.children.head, child)
929-
Project(projectList.map(pushToRight(_, rewrites)), child)
950+
Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
930951
}
931952
newFirstChild +: newOtherChildren
932953
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,13 +614,20 @@ case class Union(
614614
Some(sum.toLong)
615615
}
616616

617-
def duplicateResolved: Boolean = {
617+
private def duplicatesResolvedPerBranch: Boolean =
618+
children.forall(child => child.outputSet.size == child.output.size)
619+
620+
def duplicatesResolvedBetweenBranches: Boolean = {
618621
children.map(_.outputSet.size).sum ==
619622
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
620623
}
621624

622625
override lazy val resolved: Boolean = {
623-
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
626+
children.length > 1 &&
627+
!(byName || allowMissingCol) &&
628+
childrenResolved &&
629+
allChildrenCompatible &&
630+
duplicatesResolvedPerBranch
624631
}
625632

626633
override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0, b#0, null AS payload#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0, null AS b#0, payload#0]

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4963,6 +4963,23 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
49634963
)
49644964
}
49654965

4966+
test("SPARK-52686: Union should be resolved only if there are no duplicates") {
4967+
withTable("t1") {
4968+
withView("v1") {
4969+
sql("CREATE TABLE t1(col1 STRING)")
4970+
sql("CREATE VIEW v1 AS SELECT * FROM t1")
4971+
val analyzedPlan = sql(
4972+
"SELECT * FROM (SELECT col1, col1 FROM v1 UNION SELECT col1, col1 FROM v1)"
4973+
).queryExecution.analyzed
4974+
4975+
// Resolving * should wait for Union output to be deduplicated, otherwise we are
4976+
// left with duplicate ExprIds in the result.
4977+
val exprIds = analyzedPlan.asInstanceOf[Project].projectList.map(_.exprId)
4978+
assert(exprIds.size == exprIds.distinct.size)
4979+
}
4980+
}
4981+
}
4982+
49664983
Seq(true, false).foreach { codegenEnabled =>
49674984
test(s"SPARK-52060: one row relation with codegen enabled - $codegenEnabled") {
49684985
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {

0 commit comments

Comments
 (0)