Skip to content

Commit 688eb79

Browse files
committed
fix
1 parent ff93cbb commit 688eb79

File tree

9 files changed

+92
-9
lines changed

9 files changed

+92
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
5959
case e @ Except(left, right, _) if !e.duplicateResolved && noMissingInput(right) =>
6060
e.copy(right = dedupRight(left, right))
6161
// Only after we finish by-name resolution for Union
62-
case u: Union if !u.byName && !u.duplicateResolved =>
62+
case u: Union if !u.byName && !u.duplicatesResolvedBetweenBranches =>
6363
val unionWithChildOutputsDeduplicated =
6464
DeduplicateUnionChildOutput.deduplicateOutputPerChild(u)
6565
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import java.util.HashSet
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.spark.SparkException
@@ -895,6 +897,24 @@ object LimitPushDown extends Rule[LogicalPlan] {
895897
*/
896898
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
897899

900+
/**
901+
* When pushing a [[Project]] through [[Union]] we need to maintain the invariant that [[Union]]
902+
* children must have unique [[ExprId]]s per branch. We can safely deduplicate [[ExprId]]s
903+
* without updating any references because those [[ExprId]]s will simply remain unused.
904+
* For example, in a `Project(col1#1, col1#1)` we will alias the second `col1` and get
905+
* `Project(col1#1, col1 as col1#2)`. We don't need to update any references to the `col1#1`
906+
* we aliased, because `col1#1` still exists in the [[Project]] output.
907+
*/
908+
private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
909+
val existingExprIds = new HashSet[ExprId]
910+
projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
911+
Alias(attr, attr.name)()
912+
} else {
913+
existingExprIds.add(attr.exprId)
914+
attr
915+
})
916+
}
917+
898918
/**
899919
* Maps Attributes from the left side to the corresponding Attribute on the right side.
900920
*/
@@ -923,10 +943,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
923943
}
924944

925945
def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
926-
val newFirstChild = Project(projectList, u.children.head)
946+
val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
947+
deduplicateProjectList(projectList)
948+
} else {
949+
projectList
950+
}
951+
val newFirstChild = Project(deduplicatedProjectList, u.children.head)
927952
val newOtherChildren = u.children.tail.map { child =>
928953
val rewrites = buildRewrites(u.children.head, child)
929-
Project(projectList.map(pushToRight(_, rewrites)), child)
954+
Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
930955
}
931956
newFirstChild +: newOtherChildren
932957
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,13 +614,20 @@ case class Union(
614614
Some(sum.toLong)
615615
}
616616

617-
def duplicateResolved: Boolean = {
617+
private def duplicatesResolvedPerBranch: Boolean =
618+
children.forall(child => child.outputSet.size == child.output.size)
619+
620+
def duplicatesResolvedBetweenBranches: Boolean = {
618621
children.map(_.outputSet.size).sum ==
619622
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
620623
}
621624

622625
override lazy val resolved: Boolean = {
623-
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
626+
children.length > 1 &&
627+
!(byName || allowMissingCol) &&
628+
childrenResolved &&
629+
allChildrenCompatible &&
630+
(!conf.unionIsResolvedWhenDuplicatesPerChildResolved || duplicatesResolvedPerBranch)
624631
}
625632

626633
override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,15 @@ object SQLConf {
241241
}
242242
}
243243

244+
val UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED =
245+
buildConf("spark.sql.analyzer.unionIsResolvedWhenDuplicatesPerChildResolved")
246+
.internal()
247+
.doc(
248+
"When true, union should only be resolved once there are no duplicate attributes in " +
249+
"each branch.")
250+
.booleanConf
251+
.createWithDefault(true)
252+
244253
val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
245254
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
246255
.internal()
@@ -6852,6 +6861,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
68526861
def useNullsForMissingDefaultColumnValues: Boolean =
68536862
getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)
68546863

6864+
def unionIsResolvedWhenDuplicatesPerChildResolved: Boolean =
6865+
getConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED)
6866+
68556867
override def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS)
68566868

68576869
override def doubleQuotedIdentifiers: Boolean = ansiEnabled && getConf(DOUBLE_QUOTED_IDENTIFIERS)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0, b#0, null AS payload#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0, null AS b#0, payload#0]

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4963,6 +4963,45 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
49634963
)
49644964
}
49654965

4966+
test("SPARK-52686: Union should be resolved only if there are no duplicates") {
4967+
withTable("t1", "t2", "t3") {
4968+
sql("CREATE TABLE t1 (col1 STRING, col2 STRING, col3 STRING)")
4969+
sql("CREATE TABLE t2 (col1 STRING, col2 DOUBLE, col3 STRING)")
4970+
sql("CREATE TABLE t3 (col1 STRING, col2 DOUBLE, a STRING, col3 STRING)")
4971+
4972+
for (confValue <- Seq(false, true)) {
4973+
withSQLConf(
4974+
SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> confValue.toString
4975+
) {
4976+
val analyzedPlan = sql(
4977+
"""SELECT
4978+
| *
4979+
|FROM (
4980+
| SELECT col1, col2, NULL AS a, col1 FROM t1
4981+
| UNION
4982+
| SELECT col1, col2, NULL AS a, col3 FROM t2
4983+
| UNION
4984+
| SELECT * FROM t3
4985+
|)""".stripMargin
4986+
).queryExecution.analyzed
4987+
4988+
val projectCount = analyzedPlan.collect {
4989+
case project: Project => project
4990+
}.size
4991+
4992+
// When UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED is disabled, we resolve the
4993+
// outer Union before deduplicating ExprIds in the inner Union. Because of this we get an
4994+
// additional unnecessary Project (see SPARK-52686).
4995+
if (confValue) {
4996+
assert(projectCount == 7)
4997+
} else {
4998+
assert(projectCount == 8)
4999+
}
5000+
}
5001+
}
5002+
}
5003+
}
5004+
49665005
Seq(true, false).foreach { codegenEnabled =>
49675006
test(s"SPARK-52060: one row relation with codegen enabled - $codegenEnabled") {
49685007
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {

0 commit comments

Comments
 (0)