feat: add CopyExec and move CopyExec handling to Spark #2001

Draft · wants to merge 5 commits into main
40 changes: 19 additions & 21 deletions native/core/src/execution/planner.rs
@@ -1107,15 +1107,8 @@ impl PhysicalPlanner {
             .collect();

         let fetch = sort.fetch.map(|num| num as usize);
-
-        // SortExec caches batches so we need to make a copy of incoming batches. Also,
-        // SortExec fails in some cases if we do not unpack dictionary-encoded arrays, and
-        // it would be more efficient if we could avoid that.
-        // https://github.com/apache/datafusion-comet/issues/963
-        let child_copied = Self::wrap_in_copy_exec(Arc::clone(&child.native_plan));
-
         let sort = Arc::new(
-            SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child_copied))
+            SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child.native_plan))
                 .with_fetch(fetch),
         );

@@ -1285,7 +1278,7 @@ impl PhysicalPlanner {
         }?;

         let shuffle_writer = Arc::new(ShuffleWriterExec::try_new(
-            Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)),
+            Arc::clone(&child.native_plan),
            partitioning,
            codec,
            writer.output_data_file.clone(),
@@ -1344,6 +1337,7 @@ impl PhysicalPlanner {
         // if the child operator is `ScanExec`, because other operators after `ScanExec`
         // will create new arrays for the output batch.
         let input = if can_reuse_input_batch(&child.native_plan) {
+            // FIXME: handle me in Spark Planner
             Arc::new(CopyExec::new(
                 Arc::clone(&child.native_plan),
                 CopyMode::UnpackOrDeepCopy,
@@ -1446,8 +1440,8 @@ impl PhysicalPlanner {
                 // to copy the input batch to avoid the data corruption from reusing the input
                 // batch. We also need to unpack dictionary arrays, because the join operators
                 // do not support them.
-                let left = Self::wrap_in_copy_exec(Arc::clone(&join_params.left.native_plan));
-                let right = Self::wrap_in_copy_exec(Arc::clone(&join_params.right.native_plan));
+                let left = Arc::clone(&join_params.left.native_plan);
+                let right = Arc::clone(&join_params.right.native_plan);

                 let hash_join = Arc::new(HashJoinExec::try_new(
                     left,
@@ -1535,6 +1529,20 @@ impl PhysicalPlanner {
                     Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
                 ))
             }
+            OpStruct::Copy(copy) => {
+                assert_eq!(children.len(), 1);
+                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let copy_mode = if copy.mode == 0 {
+                    CopyMode::UnpackOrDeepCopy
+                } else {
+                    CopyMode::UnpackOrClone
+                };
+                let copy = Arc::new(CopyExec::new(Arc::clone(&child.native_plan), copy_mode));
+                Ok((
+                    scans,
+                    Arc::new(SparkPlan::new(spark_plan.plan_id, copy, vec![child])),
+                ))
+            }
         }
     }

@@ -1679,16 +1687,6 @@ impl PhysicalPlanner {
         ))
     }

-    /// Wrap an ExecutionPlan in a CopyExec, which will unpack any dictionary-encoded arrays
-    /// and make a deep copy of other arrays if the plan re-uses batches.
-    fn wrap_in_copy_exec(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        if can_reuse_input_batch(&plan) {
-            Arc::new(CopyExec::new(plan, CopyMode::UnpackOrDeepCopy))
-        } else {
-            Arc::new(CopyExec::new(plan, CopyMode::UnpackOrClone))
-        }
-    }
-
     /// Create a DataFusion physical aggregate expression from Spark physical aggregate expression
     fn create_agg_expr(
         &self,
11 changes: 11 additions & 0 deletions native/proto/src/proto/operator.proto
@@ -47,6 +47,7 @@ message Operator {
     HashJoin hash_join = 109;
     Window window = 110;
     NativeScan native_scan = 111;
+    Copy copy = 112;
   }
 }

@@ -244,3 +245,13 @@ message Window {
   repeated spark.spark_expression.Expr partition_by_list = 3;
   Operator child = 4;
 }
+
+
+enum CopyMode {
+  UnpackOrDeepCopy = 0;
+  UnpackOrClone = 1;
+}
+
+message Copy {
+  CopyMode mode = 3;
+}
44 changes: 43 additions & 1 deletion spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -19,6 +19,7 @@

 package org.apache.comet.rules

+import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer

 import org.apache.spark.sql.SparkSession
@@ -338,6 +339,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
             op.right,
             SerializedPlan(None)))

+      case op: CopyExec if op.children.forall(isCometNative) =>
+        newPlanWithProto(
+          op,
+          CometCopyExec(_, op, op.output, op.copyMode, op.child, SerializedPlan(None)))
+
       case op: SortMergeJoinExec
           if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
             !op.children.forall(isCometNative) =>
@@ -671,7 +677,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       normalizePlan(plan)
     }

-    var newPlan = transform(normalizedPlan)
+    // FIXME: Should we move this to a separate rule?
+    var newPlan = transformAndAddCopyExec(normalizedPlan)
+    newPlan = transform(newPlan)

     // if the plan cannot be run fully natively then explain why (when appropriate
     // config is enabled)
@@ -751,6 +759,40 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }

+  private def transformAndAddCopyExec(plan: SparkPlan): SparkPlan = plan.transform {
+    case shj: ShuffledHashJoinExec =>
+      val newLeft = wrapInCopyExec(shj.left)
+      val newRight = wrapInCopyExec(shj.right)
+      shj.copy(left = newLeft, right = newRight)
+    case se: SortExec =>
+      val newChild = wrapInCopyExec(se.child)
+      se.copy(child = newChild)
+    case ee: ExpandExec =>
+      val newChild = wrapInCopyExec(ee.child)
+      ee.copy(child = newChild)
+  }
+
+  // Returns true if the given operator can return its input array as its output array
+  // without modification. This is used to determine whether we need to copy the input
+  // batch to avoid data corruption from reusing the input batch.
+  @tailrec
+  private def canReuseInputBatch(plan: SparkPlan): Boolean = {
+    if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[LocalLimitExec]) {
+      canReuseInputBatch(plan.children.head)
+    } else {
+      // FIXME
+      plan.isInstanceOf[CometScanExec]
+    }
+  }
+
+  private def wrapInCopyExec(plan: SparkPlan): SparkPlan = {
+    if (canReuseInputBatch(plan)) {
+      CopyExec(plan, UnpackOrDeepCopy)
+    } else {
+      CopyExec(plan, UnpackOrClone)
+    }
+  }
+
   /**
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
    * partial mode, it will return None.
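
Note: the `FIXME` in the apply method above asks whether this rewrite belongs in its own rule. A minimal sketch of that alternative with the same match arms, under the assumption that the helpers are moved along with it (the rule name `CometInsertCopyExecRule` is hypothetical):

```scala
import scala.annotation.tailrec

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet.{CometScanExec, CopyExec, UnpackOrClone, UnpackOrDeepCopy}
import org.apache.spark.sql.execution.{ExpandExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

// Hypothetical standalone rule: performs the same rewrite as
// transformAndAddCopyExec, run before CometExecRule's main transform.
case class CometInsertCopyExecRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transform {
    case shj: ShuffledHashJoinExec =>
      shj.copy(left = wrap(shj.left), right = wrap(shj.right))
    case se: SortExec => se.copy(child = wrap(se.child))
    case ee: ExpandExec => ee.copy(child = wrap(ee.child))
  }

  // Deep-copy when the child reuses its output batches, cheap clone otherwise.
  private def wrap(plan: SparkPlan): SparkPlan =
    if (canReuse(plan)) CopyExec(plan, UnpackOrDeepCopy)
    else CopyExec(plan, UnpackOrClone)

  // Mirrors canReuseInputBatch above: look through pass-through operators
  // down to the scan that owns the reused batch.
  @tailrec
  private def canReuse(plan: SparkPlan): Boolean =
    if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[LocalLimitExec]) {
      canReuse(plan.children.head)
    } else {
      plan.isInstanceOf[CometScanExec]
    }
}
```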
17 changes: 17 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2306,6 +2306,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }

+      case CopyExec(child, copyMode) => {
+        if (childOp.nonEmpty) {
+          val copyModeBuilder = if (copyMode == UnpackOrDeepCopy) {
+            OperatorOuterClass.CopyMode.UnpackOrDeepCopy
+          } else {
+            OperatorOuterClass.CopyMode.UnpackOrClone
+          }
+          val copyBuilder = OperatorOuterClass.Copy
+            .newBuilder()
+            .setMode(copyModeBuilder)
+          Some(result.setCopy(copyBuilder).build())
+        } else {
+          withInfo(op, child)
+          None
+        }
+      }
+
       case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
         val cond = exprToProto(condition, child.output)

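
Note: the nested if/else above is easy to invert by accident, since the Scala and proto enums share value names but are distinct types. A sketch of the same mapping as a small total function (the helper name `toProtoCopyMode` is hypothetical; `OperatorOuterClass` is the generated proto binding used above):

```scala
import org.apache.comet.serde.OperatorOuterClass
import org.apache.spark.sql.comet.{CopyMode, UnpackOrClone, UnpackOrDeepCopy}

// Exhaustive match: adding a new CopyMode variant fails compilation here
// instead of silently falling into a default branch.
def toProtoCopyMode(mode: CopyMode): OperatorOuterClass.CopyMode = mode match {
  case UnpackOrDeepCopy => OperatorOuterClass.CopyMode.UnpackOrDeepCopy
  case UnpackOrClone    => OperatorOuterClass.CopyMode.UnpackOrClone
}
```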
42 changes: 42 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/CopyExec.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+case class CopyExec(override val child: SparkPlan, copyMode: CopyMode) extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = {
+    // This method should never be invoked: CopyExec is an internal operator used during
+    // native execution offload to mark where record batches must be deep-copied or cloned.
+    // The actual execution happens in the native layer through CometExecNode.
+    throw new UnsupportedOperationException(
+      "CopyExec should not be executed directly; it is meant for internal use only")
+  }
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+}
+
+sealed abstract class CopyMode
+case object UnpackOrDeepCopy extends CopyMode
+case object UnpackOrClone extends CopyMode
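
Note: a small sketch of how the two modes are intended to be read; the `describe` helper is hypothetical, and the semantics mirror the native CopyExec behavior referenced in the planner comments:

```scala
import org.apache.spark.sql.comet.{CopyMode, UnpackOrClone, UnpackOrDeepCopy}

// UnpackOrDeepCopy is used when the child reuses its output batches
// (e.g. a scan), so array buffers must be deep-copied; UnpackOrClone
// only unpacks dictionary-encoded arrays and otherwise clones cheaply.
def describe(mode: CopyMode): String = mode match {
  case UnpackOrDeepCopy => "unpack dictionaries and deep-copy array buffers"
  case UnpackOrClone    => "unpack dictionaries, cheap (zero-copy) clone otherwise"
}
```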
19 changes: 19 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1000,3 +1000,22 @@ case class CometSinkPlaceHolder(

   override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, child)
 }
+
+case class CometCopyExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    override val output: Seq[Attribute],
+    copyMode: CopyMode,
+    child: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometUnaryExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def verboseStringWithOperatorId(): String = {
+    s"""
+       |$formattedNodeName
+       |${ExplainUtils.generateFieldString("Input", child.output)}
+       |""".stripMargin
+  }
+}