From 688eb79e77c95a5a63ae09e92299a9fa9e34c578 Mon Sep 17 00:00:00 2001
From: Mihailo Timotic
Date: Fri, 4 Jul 2025 14:44:36 +0200
Subject: [PATCH] fix

---
 .../analysis/DeduplicateRelations.scala       |  2 +-
 .../sql/catalyst/optimizer/Optimizer.scala    | 29 +++++++++++++-
 .../plans/logical/basicLogicalOperators.scala | 11 +++++-
 .../apache/spark/sql/internal/SQLConf.scala   | 12 ++++++
 .../query-tests/explain-results/union.explain |  2 +-
 .../explain-results/unionAll.explain          |  2 +-
 .../explain-results/unionByName.explain       |  2 +-
 .../unionByName_allowMissingColumns.explain   |  2 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 39 +++++++++++++++++++
 9 files changed, 92 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index da940c9b8ead3..b6181a2d54faf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -59,7 +59,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
       case e @ Except(left, right, _) if !e.duplicateResolved && noMissingInput(right) =>
         e.copy(right = dedupRight(left, right))
       // Only after we finish by-name resolution for Union
-      case u: Union if !u.byName && !u.duplicateResolved =>
+      case u: Union if !u.byName && !u.duplicatesResolvedBetweenBranches =>
         val unionWithChildOutputsDeduplicated =
           DeduplicateUnionChildOutput.deduplicateOutputPerChild(u)
         // Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2701a79077ce6..2ae507c831f1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import java.util.HashSet
+
 import scala.collection.mutable
 
 import org.apache.spark.SparkException
@@ -895,6 +897,24 @@ object LimitPushDown extends Rule[LogicalPlan] {
  */
 object PushProjectionThroughUnion extends Rule[LogicalPlan] {
 
+  /**
+   * When pushing a [[Project]] through [[Union]] we need to maintain the invariant that [[Union]]
+   * children must have unique [[ExprId]]s per branch. We can safely deduplicate [[ExprId]]s
+   * without updating any references because those [[ExprId]]s will simply remain unused.
+   * For example, in a `Project(col1#1, col1#1)` we will alias the second `col1` and get
+   * `Project(col1#1, col1 as col1#2)`. We don't need to update any references to the `col1#1` we
+   * aliased because `col1#1` still exists in the [[Project]] output.
+   */
+  private def deduplicateProjectList(projectList: Seq[NamedExpression]) = {
+    val existingExprIds = new HashSet[ExprId]
+    projectList.map(attr => if (existingExprIds.contains(attr.exprId)) {
+      Alias(attr, attr.name)()
+    } else {
+      existingExprIds.add(attr.exprId)
+      attr
+    })
+  }
+
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
   */
@@ -923,10 +943,15 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] {
   }
 
   def pushProjectionThroughUnion(projectList: Seq[NamedExpression], u: Union): Seq[LogicalPlan] = {
-    val newFirstChild = Project(projectList, u.children.head)
+    val deduplicatedProjectList = if (conf.unionIsResolvedWhenDuplicatesPerChildResolved) {
+      deduplicateProjectList(projectList)
+    } else {
+      projectList
+    }
+    val newFirstChild = Project(deduplicatedProjectList, u.children.head)
     val newOtherChildren = u.children.tail.map { child =>
       val rewrites = buildRewrites(u.children.head, child)
-      Project(projectList.map(pushToRight(_, rewrites)), child)
+      Project(deduplicatedProjectList.map(pushToRight(_, rewrites)), child)
     }
     newFirstChild +: newOtherChildren
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5215e8a9568bc..01e49a3eda81a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -614,13 +614,20 @@ case class Union(
     Some(sum.toLong)
   }
 
-  def duplicateResolved: Boolean = {
+  private def duplicatesResolvedPerBranch: Boolean =
+    children.forall(child => child.outputSet.size == child.output.size)
+
+  def duplicatesResolvedBetweenBranches: Boolean = {
     children.map(_.outputSet.size).sum ==
       AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
   }
 
   override lazy val resolved: Boolean = {
-    children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
+    children.length > 1 &&
+      !(byName || allowMissingCol) &&
+      childrenResolved &&
+      allChildrenCompatible &&
+      (!conf.unionIsResolvedWhenDuplicatesPerChildResolved || duplicatesResolvedPerBranch)
   }
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): Union =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 34d7c64d200f2..0fef83fcdcc3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -241,6 +241,15 @@ object SQLConf {
     }
   }
 
+  val UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED =
+    buildConf("spark.sql.analyzer.unionIsResolvedWhenDuplicatesPerChildResolved")
+      .internal()
+      .doc(
+        "When true, union should only be resolved once there are no duplicate attributes in " +
+        "each branch.")
+      .booleanConf
+      .createWithDefault(true)
+
   val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
     buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
       .internal()
@@ -6852,6 +6861,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def useNullsForMissingDefaultColumnValues: Boolean =
     getConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES)
 
+  def unionIsResolvedWhenDuplicatesPerChildResolved: Boolean =
+    getConf(SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED)
+
   override def enforceReservedKeywords: Boolean = ansiEnabled && getConf(ENFORCE_RESERVED_KEYWORDS)
 
   override def doubleQuotedIdentifiers: Boolean = ansiEnabled && getConf(DOUBLE_QUOTED_IDENTIFIERS)
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/union.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/union.explain
index 4d5d1f53b8412..252774510896c 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/union.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/union.explain
@@ -1,3 +1,3 @@
-Union false, false
+'Union false, false
 :- LocalRelation <empty>, [id#0L, a#0, b#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/unionAll.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/unionAll.explain
index 4d5d1f53b8412..252774510896c 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/unionAll.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/unionAll.explain
@@ -1,3 +1,3 @@
-Union false, false
+'Union false, false
 :- LocalRelation <empty>, [id#0L, a#0, b#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain
index 6ec8eb37f50ed..2877c7cef0fda 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName.explain
@@ -1,4 +1,4 @@
-Union false, false
+'Union false, false
 :- Project [id#0L, a#0]
 :  +- LocalRelation <empty>, [id#0L, a#0, b#0]
 +- Project [id#0L, a#0]
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain
index 96bd9f281c15e..dc0d1d94f85c1 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/unionByName_allowMissingColumns.explain
@@ -1,4 +1,4 @@
-Union false, false
+'Union false, false
 :- Project [id#0L, a#0, b#0, null AS payload#0]
 :  +- LocalRelation <empty>, [id#0L, a#0, b#0]
 +- Project [id#0L, a#0, null AS b#0, payload#0]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 26aa4b6b5210f..f405989520e33 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4963,6 +4963,45 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     )
   }
 
+  test("SPARK-52686: Union should be resolved only if there are no duplicates") {
+    withTable("t1", "t2", "t3") {
+      sql("CREATE TABLE t1 (col1 STRING, col2 STRING, col3 STRING)")
+      sql("CREATE TABLE t2 (col1 STRING, col2 DOUBLE, col3 STRING)")
+      sql("CREATE TABLE t3 (col1 STRING, col2 DOUBLE, a STRING, col3 STRING)")
+
+      for (confValue <- Seq(false, true)) {
+        withSQLConf(
+          SQLConf.UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED.key -> confValue.toString
+        ) {
+          val analyzedPlan = sql(
+            """SELECT
+              | *
+              |FROM (
+              | SELECT col1, col2, NULL AS a, col1 FROM t1
+              | UNION
+              | SELECT col1, col2, NULL AS a, col3 FROM t2
+              | UNION
+              | SELECT * FROM t3
+              |)""".stripMargin
+          ).queryExecution.analyzed
+
+          val projectCount = analyzedPlan.collect {
+            case project: Project => project
+          }.size
+
+          // When UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED is disabled, we resolve
+          // the outer Union before deduplicating ExprIds in the inner Union. Because of this we
+          // get an additional, unnecessary Project (see SPARK-52686).
+          if (confValue) {
+            assert(projectCount == 7)
+          } else {
+            assert(projectCount == 8)
+          }
+        }
+      }
+    }
+  }
+
   Seq(true, false).foreach { codegenEnabled =>
     test(s"SPARK-52060: one row relation with codegen enabled - $codegenEnabled") {
       withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled.toString) {
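
The sketch below is a standalone illustration of the ExprId-deduplication idea behind deduplicateProjectList; it is not part of the patch. It uses plain case classes (ExprIdDedupSketch, SimpleAttr, SimpleAlias, dedup are hypothetical names) instead of Catalyst's NamedExpression and Alias so it compiles and runs without Spark; the real rule additionally pushes the deduplicated project list into every Union branch via pushToRight.

// Standalone sketch only -- not part of the patch above.
import java.util.HashSet

object ExprIdDedupSketch {
  sealed trait Expr { def name: String; def exprId: Long }
  final case class SimpleAttr(name: String, exprId: Long) extends Expr
  // An alias keeps the original name but carries a fresh exprId,
  // mirroring Alias(attr, attr.name)() in the patch.
  final case class SimpleAlias(child: Expr, name: String, exprId: Long) extends Expr

  private var nextId = 100L
  private def freshId(): Long = { nextId += 1; nextId }

  // Keep the first occurrence of each exprId; wrap every later occurrence in an
  // alias with a fresh id so the project list has unique ids, without touching
  // any other references to the original id.
  def dedup(projectList: Seq[Expr]): Seq[Expr] = {
    val seen = new HashSet[Long]
    projectList.map { e =>
      if (seen.contains(e.exprId)) {
        SimpleAlias(e, e.name, freshId())
      } else {
        seen.add(e.exprId)
        e
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val col1 = SimpleAttr("col1", 1L)
    // The duplicated col1#1 comes back wrapped in an alias with a new id,
    // analogous to Project(col1#1, col1 as col1#2) in the scaladoc above.
    println(dedup(Seq(col1, col1)))
  }
}

In the patch itself this rewrite is guarded by the new internal conf spark.sql.analyzer.unionIsResolvedWhenDuplicatesPerChildResolved (default true), which is the same flag the SPARK-52686 test above toggles to compare the analyzed plans with and without per-branch deduplication.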