[SPARK-52921][SQL] Specify outputPartitioning for UnionExec for partitioner aware case #51623

Status: Open. Wants to merge 10 commits into master.
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1606,11 +1606,16 @@ class SparkContext(config: SparkConf) extends Logging {
new ReliableCheckpointRDD[T](this, path)
}

// Note that the input rdds must all be non-empty, i.e., rdds.filter(_.partitions.isEmpty).isEmpty
protected[spark] def isPartitionerAwareUnion[T: ClassTag](rdds: Seq[RDD[T]]): Boolean = {
@dongjoon-hyun (Member) commented on Jul 22, 2025:
Could you add a comment about the assumption, rdds.filter(!_.partitions.isEmpty)? Otherwise, it may cause correctness issues later if we use this blindly.

Otherwise, we had better include the assumption inside this method.

Member Author replied:
Added comment and a check.

assert(!rdds.exists(_.partitions.isEmpty), "Must not have empty RDDs")
Member replied:
Nice!

rdds.forall(_.partitioner.isDefined) && rdds.flatMap(_.partitioner).toSet.size == 1
}

/** Build the union of a list of RDDs. */
def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = withScope {
val nonEmptyRdds = rdds.filter(!_.partitions.isEmpty)
val partitioners = nonEmptyRdds.flatMap(_.partitioner).toSet
if (nonEmptyRdds.forall(_.partitioner.isDefined) && partitioners.size == 1) {
if (isPartitionerAwareUnion(nonEmptyRdds)) {
new PartitionerAwareUnionRDD(this, nonEmptyRdds)
} else {
new UnionRDD(this, nonEmptyRdds)
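As a hedged illustration of what the refactored `SparkContext.union` path relies on (a standalone sketch, not part of the PR; the data, partitioner size, and app name are made up): unioning non-empty RDDs that all share one partitioner should produce a partitioner-aware union that keeps that partitioner, while empty RDDs are filtered out beforehand.

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

object UnionPartitionerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("union-partitioner-sketch"))
    val part = new HashPartitioner(4)

    // Two non-empty pair RDDs sharing the same partitioner.
    val rdd1 = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(part)
    val rdd2 = sc.parallelize(Seq(3 -> "c", 4 -> "d")).partitionBy(part)
    // An empty RDD, which `union` filters out before the partitioner-aware check.
    val empty = sc.emptyRDD[(Int, String)]

    // Expected to be partitioner-aware: the union should report Some(part)
    // rather than None, so downstream key-based operations can avoid a shuffle.
    val unioned = sc.union(Seq(rdd1, rdd2, empty))
    println(unioned.partitioner)

    sc.stop()
  }
}
```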
@@ -5993,6 +5993,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val UNION_OUTPUT_PARTITIONING =
buildConf("spark.sql.unionOutputPartitioning")
.internal()
.doc("When set to true, the output partitioning of UnionExec will be the same as the " +
"input partitioning if its children have same partitioner. Otherwise, it will be a " +
"default partitioning.")
.version("4.1.0")
.booleanConf
.createWithDefault(true)
Member Author commented:
For safety, added an internal config for it.


val LEGACY_PARSE_QUERY_WITHOUT_EOF = buildConf("spark.sql.legacy.parseQueryWithoutEof")
.internal()
.doc(
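For reference, a minimal sketch (not from the PR) of how this internal flag could be toggled at runtime; the key string comes from the diff above, while the session setup is illustrative:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("union-output-partitioning-conf")
  .getOrCreate()

// Disable the new behavior: UnionExec falls back to the default (unknown) partitioning.
spark.conf.set("spark.sql.unionOutputPartitioning", "false")

// Re-enable it (the default according to the diff above).
spark.conf.set("spark.sql.unionOutputPartitioning", "true")
```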
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit._
import scala.collection.mutable
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
@@ -31,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types.{LongType, StructType}
@@ -699,8 +701,42 @@ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
}
}

private lazy val childrenRDDs = children.map(_.execute())

override def outputPartitioning: Partitioning = {
if (conf.getConf(SQLConf.UNION_OUTPUT_PARTITIONING)) {
// Commands like `AppendDataExec` have side effects when creating RDDs, so we
// cannot call `execute` on them to determine the partitioning.
if (children.exists(_.containsPattern(COMMAND))) {
return super.outputPartitioning
}

try {
val nonEmptyRdds = childrenRDDs.filter(!_.partitions.isEmpty)
if (sparkContext.isPartitionerAwareUnion(nonEmptyRdds)) {
// `isPartitionerAwareUnion` ensures that at least one child is non-empty.
children.head.outputPartitioning
} else {
super.outputPartitioning
}
} catch {
// If any child operator doesn't support `execute`, we cannot determine the
// partitioning. For any other exception, we also simply fall back to the
// default partitioning. In such cases, the failing child operators will be
// replaced by Spark later in query planning, i.e., `execute` won't actually
// be called on them during the execution of this plan, so returning the
// default partitioning is safe. If it is a real exception, it will be thrown
// again when `doExecute` accesses `childrenRDDs`.
case e if NonFatal(e) => super.outputPartitioning
}
} else {
super.outputPartitioning
}
}

protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
sparkContext.union(childrenRDDs)

override def supportsColumnar: Boolean = children.forall(_.supportsColumnar)

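To make the intended effect of the new `outputPartitioning` concrete, here is a hedged end-to-end sketch (column names and data are illustrative, and whether the downstream shuffle is actually removed depends on the planner): when both inputs of a union are repartitioned on the same column, the union can now report that hash partitioning, so an aggregation on the same key may reuse it instead of introducing another exchange.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("union-output-partitioning")
  .getOrCreate()
import spark.implicits._

val df1 = Seq((1, "x"), (2, "y")).toDF("a", "b").repartition(col("a"))
val df2 = Seq((3, "z"), (4, "w")).toDF("a", "b").repartition(col("a"))

// With spark.sql.unionOutputPartitioning=true (and AQE disabled, as in the test
// suite below), the UnionExec node is expected to expose the children's hash
// partitioning on "a"; inspect the plan to see whether the aggregation reuses it.
df1.union(df2).groupBy("a").count().explain()
```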
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{LeafLike, TreePattern, UnaryLike}
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
@@ -110,6 +110,8 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode {
case class DataWritingCommandExec(cmd: DataWritingCommand, child: SparkPlan)
extends UnaryExecNode {

final override val nodePatterns: Seq[TreePattern.TreePattern] = Seq(TreePattern.COMMAND)

override lazy val metrics: Map[String, SQLMetric] = cmd.metrics

protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, GenericRowWithSchema}
import org.apache.spark.sql.catalyst.trees.LeafLike
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.SparkPlan

@@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan
* Any V2 commands that do not require triggering a spark job should extend this class.
*/
abstract class V2CommandExec extends SparkPlan {
final override val nodePatterns: Seq[TreePattern] = Seq(COMMAND)

/**
* Abstract method that each concrete command needs to implement to compute the result.
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Union
import org.apache.spark.sql.execution.{SparkPlan, UnionExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession, SQLTestData}
@@ -1508,6 +1509,28 @@ class DataFrameSetOperationsSuite extends QueryTest
}
}
}

test("union partitioning") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val df1 = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c")
val df2 = Seq((1, 2, 4), (1, 3, 5), (2, 2, 3), (2, 4, 5)).toDF("a", "b", "c")

val union = df1.repartition($"a").union(df2.repartition($"a"))
val unionExec = union.queryExecution.executedPlan.collect {
case u: UnionExec => u
}
assert(unionExec.size == 1)

val shuffle = df1.repartition($"a").queryExecution.executedPlan.collect {
case s: ShuffleExchangeExec => s
}
assert(shuffle.size == 1)

val childPartitioning = shuffle.head.outputPartitioning
val partitioning = unionExec.head.outputPartitioning
assert(partitioning == childPartitioning)
}
}
}

case class UnionClass1a(a: Int, b: Long, nested: UnionClass2)