Skip to content

Commit 170ffef

Browse files
committed
fix
1 parent ff93cbb commit 170ffef

File tree

9 files changed

+89
-9
lines changed

9 files changed

+89
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
5959
case e @ Except(left, right, _) if !e.duplicateResolved && noMissingInput(right) =>
6060
e.copy(right = dedupRight(left, right))
6161
// Only after we finish by-name resolution for Union
62-
case u: Union if !u.byName && !u.duplicateResolved =>
62+
case u: Union if !u.byName && !u.duplicatesResolvedBetweenBranches =>
6363
val unionWithChildOutputsDeduplicated =
6464
DeduplicateUnionChildOutput.deduplicateOutputPerChild(u)
6565
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20+
import java.util.HashSet
21+
2022
import scala.collection.mutable
2123

2224
import org.apache.spark.SparkException
@@ -895,6 +897,20 @@ object LimitPushDown extends Rule[LogicalPlan] {
895897
*/
896898
object PushProjectionThroughUnion extends Rule[LogicalPlan] {
897899

900+
/**
901+
* When pushing a [[Project]] through [[Union]] we need to maintain the invariant that
902+
* the output of each [[Union]] branch must not contain duplicate [[ExprId]]s.
903+
*/
904+
private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
905+
val existingExprIds = new HashSet[ExprId]
906+
projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
907+
Alias(attr, attr.name)()
908+
} else {
909+
existingExprIds.add(attr.exprId)
910+
attr
911+
})
912+
}
913+
898914
/**
899915
* Maps Attributes from the left side to the corresponding Attribute on the right side.
900916
*/
@@ -923,10 +939,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
923939
}
924940

925941
def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
926-
val newFirstChild = Project(projectList, u.children.head)
942+
val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
943+
deduplicateProjectList(projectList)
944+
} else {
945+
projectList
946+
}
947+
val newFirstChild = Project(deduplicatedProjectList, u.children.head)
927948
val newOtherChildren = u.children.tail.map { child =>
928949
val rewrites = buildRewrites(u.children.head, child)
929-
Project(projectList.map(pushToRight(_, rewrites)), child)
950+
Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
930951
}
931952
newFirstChild +: newOtherChildren
932953
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,13 +614,20 @@ case class Union(
614614
Some(sum.toLong)
615615
}
616616

617-
def duplicateResolved: Boolean = {
617+
private def duplicatesResolvedPerBranch: Boolean =
618+
children.forall(child => child.outputSet.size == child.output.size)
619+
620+
def duplicatesResolvedBetweenBranches: Boolean = {
618621
children.map(_.outputSet.size).sum ==
619622
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
620623
}
621624

622625
override lazy val resolved: Boolean = {
623-
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
626+
children.length > 1 &&
627+
!(byName || allowMissingCol) &&
628+
childrenResolved &&
629+
allChildrenCompatible &&
630+
(!conf.unionIsResolvedWhenDuplicatesPerChildResolved || duplicatesResolvedPerBranch)
624631
}
625632

626633
override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,16 @@ object SQLConf {
241241
}
242242
}
243243

244+
val UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED =
245+
buildConf("spark.sql.analyzer.unionIsResolvedWhenDuplicatesPerChildResolved")
246+
.internal()
247+
.doc(
248+
"When true, union should only be resolved once there are no duplicate attributes in " +
249+
"each branch."
250+
)
251+
.booleanConf
252+
.createWithDefault(true)
253+
244254
val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
245255
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
246256
.internal()
@@ -6852,6 +6862,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
68526862
def useNullsForMissingDefaultColumnValues: Boolean =
68536863
getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)
68546864

6865+
def unionIsResolvedWhenDuplicatesPerChildResolved: Boolean =
6866+
getConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED)
6867+
68556868
override def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS)
68566869

68576870
override def doubleQuotedIdentifiers: Boolean = ansiEnabled && getConf(DOUBLE_QUOTED_IDENTIFIERS)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
Union false, false
1+
'Union false, false
22
:- LocalRelation <empty>, [id#0L, a#0, b#0]
33
+- LocalRelation <empty>, [id#0L, a#0, b#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0]

sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Union false, false
1+
'Union false, false
22
:- Project [id#0L, a#0, b#0, null AS payload#0]
33
: +- LocalRelation <empty>, [id#0L, a#0, b#0]
44
+- Project [id#0L, a#0, null AS b#0, payload#0]

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4963,6 +4963,45 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
49634963
)
49644964
}
49654965

4966+
test("SPARK-52686: Union should be resolved only if there are no duplicates") {
4967+
withTable("t1", "t2", "t3") {
4968+
sql("CREATE TABLE t1 (col1 STRING, col2 STRING, col3 STRING)")
4969+
sql("CREATE TABLE t2 (col1 STRING, col2 DOUBLE, col3 STRING)")
4970+
sql("CREATE TABLE t3 (col1 STRING, col2 DOUBLE, a STRING, col3 STRING)")
4971+
4972+
for (confValue <- Seq(false, true)) {
4973+
withSQLConf(
4974+
SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> confValue.toString
4975+
) {
4976+
val analyzedPlan = sql(
4977+
"""SELECT
4978+
| *
4979+
|FROM (
4980+
| SELECT col1, col2, NULL AS a, col1 FROM t1
4981+
| UNION
4982+
| SELECT col1, col2, NULL AS a, col3 FROM t2
4983+
| UNION
4984+
| SELECT * FROM t3
4985+
|)""".stripMargin
4986+
).queryExecution.analyzed
4987+
4988+
val projectCount = analyzedPlan.collect {
4989+
case project: Project => project
4990+
}.size
4991+
4992+
// When UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED is disabled, we resolve
4993+
// outer Union before deduplicating ExprIds in the inner Union. Because of this we get an
4994+
// additional unnecessary Project (see SPARK-52686).
4995+
if (confValue) {
4996+
assert(projectCount == 7)
4997+
} else {
4998+
assert(projectCount == 8)
4999+
}
5000+
}
5001+
}
5002+
}
5003+
}
5004+
49665005
Seq(true, false).foreach { codegenEnabled =>
49675006
test(s"SPARK-52060: one row relation with codegen enabled - $codegenEnabled") {
49685007
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {

0 commit comments

Comments
 (0)