
Commit d0ca414

minor: Refactor to move some shuffle-related logic from QueryPlanSerde to CometExecRule (#2015)
1 parent 3654973 commit d0ca414

File tree

2 files changed: +152, -155 lines


spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 152 additions & 20 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
@@ -34,13 +35,15 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.types.{DoubleType, FloatType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
+import org.apache.comet.serde.QueryPlanSerde.emitWarning
 
 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -53,7 +56,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
+            nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")
 
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -65,7 +68,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // (if configured)
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
            !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
@@ -490,7 +493,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.nativeShuffleSupported(s)._1
+          nativeShuffleSupported(s)._1
 
         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -517,7 +520,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
         // convert it to CometColumnarShuffle,
         if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-          QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+          columnarShuffleSupported(s)._1 &&
           !isShuffleOperator(s.child)) {
 
           val newOp = QueryPlanSerde.operator2Proto(s)
@@ -547,18 +550,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
         val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
         val msg2 = createMessage(
-          isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-            .nativeShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
           "Native shuffle: " +
-            s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
-        val typeInfo = QueryPlanSerde
-          .columnarShuffleSupported(s)
-          ._2
+            s"${nativeShuffleSupported(s)._2}")
+        val typeInfo = columnarShuffleSupported(s)._2
         val msg3 = createMessage(
-          isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-            .columnarShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
           "JVM shuffle: " +
             s"$typeInfo")
         withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
@@ -578,7 +575,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       }
   }
 
-  def normalizePlan(plan: SparkPlan): SparkPlan = {
+  private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
         val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
@@ -595,7 +592,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   // because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
   // comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
   // and zero for comparison operators in Comet.
-  def normalize(expr: Expression): Expression = {
+  private def normalize(expr: Expression): Expression = {
     expr.transformUp {
       case EqualTo(left, right) =>
         EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
@@ -616,7 +613,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizeNaNAndZero(expr: Expression): Expression = {
+  private def normalizeNaNAndZero(expr: Expression): Expression = {
     expr match {
       case _: KnownFloatingPointNormalized => expr
       case FloatLiteral(f) if !f.equals(-0.0f) => expr
@@ -755,7 +752,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
    * partial mode, it will return None.
    */
-  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
@@ -770,12 +767,147 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   /**
    * Returns true if a given spark plan is Comet shuffle operator.
    */
-  def isShuffleOperator(op: SparkPlan): Boolean = {
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
     op match {
       case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
       case _: CometShuffleExchangeExec => true
       case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
       case _ => false
     }
   }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * Hash Partition Key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(ordering, _) =>
+        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $ordering"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Check if the datatypes of shuffle input are supported. This is used for Columnar shuffle
+   * which supports struct/array.
+   */
+  private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(orderings, _) =>
+        val supported =
+          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $orderings"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  private def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // Java Arrow stream reader cannot work on duplicate field name
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
 }
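A note on the largest addition above: `supportedShuffleDataType` is a recursive pattern match over Spark's `DataType` tree. Flat primitives pass, structs and arrays recurse into their children, and the nested shapes marked with TODOs (array of array, nested maps, struct or array map keys/values) are rejected outright. The following standalone sketch compresses the same idea so it can be run outside Comet. It assumes only spark-catalyst on the classpath; the object and method names are invented for illustration, and the primitive list is abridged relative to the real method.

import org.apache.spark.sql.types._

object ShuffleTypeCheckSketch {

  // Compressed analogue of supportedShuffleDataType: recurse through the
  // DataType tree, rejecting the nested shapes the commit marks as TODO.
  def check(dt: DataType): Boolean = dt match {
    // abridged: the real method also accepts Byte/Short/Float/Binary/
    // Timestamp/TimestampNTZ/Decimal/Date
    case _: BooleanType | _: IntegerType | _: LongType | _: DoubleType | _: StringType =>
      true
    case StructType(fields) =>
      // children must be supported, and field names must be unique because
      // the Java Arrow stream reader cannot handle duplicate field names
      fields.forall(f => check(f.dataType)) &&
        fields.map(_.name).distinct.length == fields.length
    case ArrayType(ArrayType(_, _), _) => false // nested array rejected
    case ArrayType(elementType, _) => check(elementType)
    case MapType(MapType(_, _, _), _, _) | MapType(_, MapType(_, _, _), _) =>
      false // nested map rejected (the real method also rejects struct/array keys and values)
    case MapType(keyType, valueType, _) => check(keyType) && check(valueType)
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    println(check(ArrayType(IntegerType))) // true
    println(check(ArrayType(ArrayType(IntegerType)))) // false: nested array
    println(check(MapType(StringType, LongType))) // true
  }
}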

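Separately, the `normalize`/`normalizeNaNAndZero` helpers that this commit makes private exist because, as the code comment in the diff says, Spark normalizes NaN and zero before comparisons while the comparison functions in arrow-rs do not. The IEEE-754 behavior that makes this necessary can be shown with the standard library alone; this snippet is purely illustrative and is not part of the commit.

object FloatNormalizationSketch {
  def main(args: Array[String]): Unit = {
    // -0.0 and 0.0 compare equal in Java/Scala semantics, yet their bit
    // patterns differ, so a bit-level comparator can distinguish them.
    println(-0.0 == 0.0) // true
    println(java.lang.Double.doubleToRawLongBits(-0.0) ==
      java.lang.Double.doubleToRawLongBits(0.0)) // false

    // NaN is never == to itself, but a total-order comparison (the behavior
    // Spark relies on, e.g. SQLOrderingUtil.compareFloats) treats NaN as equal to NaN.
    println(Double.NaN == Double.NaN) // false
    println(java.lang.Double.compare(Double.NaN, Double.NaN) == 0) // true
  }
}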