@@ -25,6 +25,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
@@ -34,13 +35,15 @@ import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregat
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
-import org.apache.spark.sql.types.{DoubleType, FloatType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType}
 
 import org.apache.comet.{CometConf, ExtendedExplainInfo}
 import org.apache.comet.CometConf.COMET_ANSI_MODE_ENABLED
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
+import org.apache.comet.serde.QueryPlanSerde.emitWarning
 
 /**
  * Spark physical optimizer rule for replacing Spark operators with Comet operators.
@@ -53,7 +56,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     plan.transformUp {
       case s: ShuffleExchangeExec
           if isCometPlan(s.child) && isCometNativeShuffleMode(conf) &&
-            QueryPlanSerde.nativeShuffleSupported(s)._1 =>
+            nativeShuffleSupported(s)._1 =>
         logInfo("Comet extension enabled for Native Shuffle")
 
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
@@ -65,7 +68,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       // (if configured)
       case s: ShuffleExchangeExec
           if (!s.child.supportsColumnar || isCometPlan(s.child)) && isCometJVMShuffleMode(conf) &&
-            QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+            columnarShuffleSupported(s)._1 &&
             !isShuffleOperator(s.child) =>
         logInfo("Comet extension enabled for JVM Columnar Shuffle")
         CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
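
The two branches above select between Comet's shuffle implementations: CometNativeShuffle when the native path supports the exchange, and CometColumnarShuffle for the JVM columnar path. A minimal session-setup sketch follows; the plugin class and config keys are taken from Comet's documentation, so treat the exact names and values as assumptions for your release:

    import org.apache.spark.sql.SparkSession

    // Hedged sketch: verify plugin class and config keys against your Comet version.
    val spark = SparkSession
      .builder()
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .config("spark.comet.exec.shuffle.enabled", "true")
      // "native" targets CometNativeShuffle, "jvm" targets CometColumnarShuffle,
      // "auto" lets the rule decide per exchange, as in the matches above
      .config("spark.comet.exec.shuffle.mode", "auto")
      .getOrCreate()
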
@@ -490,7 +493,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case s: ShuffleExchangeExec =>
         val nativePrecondition = isCometShuffleEnabled(conf) &&
           isCometNativeShuffleMode(conf) &&
-          QueryPlanSerde.nativeShuffleSupported(s)._1
+          nativeShuffleSupported(s)._1
 
         val nativeShuffle: Option[SparkPlan] =
           if (nativePrecondition) {
@@ -517,7 +520,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
         // convert it to CometColumnarShuffle,
         if (isCometShuffleEnabled(conf) && isCometJVMShuffleMode(conf) &&
-          QueryPlanSerde.columnarShuffleSupported(s)._1 &&
+          columnarShuffleSupported(s)._1 &&
           !isShuffleOperator(s.child)) {
 
           val newOp = QueryPlanSerde.operator2Proto(s)
@@ -547,18 +550,12 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         val msg1 = createMessage(!isShuffleEnabled, s"Comet shuffle is not enabled: $reason")
         val columnarShuffleEnabled = isCometJVMShuffleMode(conf)
         val msg2 = createMessage(
-          isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde
-            .nativeShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && !columnarShuffleEnabled && !nativeShuffleSupported(s)._1,
           "Native shuffle: " +
-            s"${QueryPlanSerde.nativeShuffleSupported(s)._2}")
-        val typeInfo = QueryPlanSerde
-          .columnarShuffleSupported(s)
-          ._2
+            s"${nativeShuffleSupported(s)._2}")
+        val typeInfo = columnarShuffleSupported(s)._2
         val msg3 = createMessage(
-          isShuffleEnabled && columnarShuffleEnabled && !QueryPlanSerde
-            .columnarShuffleSupported(s)
-            ._1,
+          isShuffleEnabled && columnarShuffleEnabled && !columnarShuffleSupported(s)._1,
           "JVM shuffle: " +
             s"$typeInfo")
         withInfo(s, Seq(msg1, msg2, msg3).flatten.mkString(","))
@@ -578,7 +575,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       }
     }
 
-  def normalizePlan(plan: SparkPlan): SparkPlan = {
+  private def normalizePlan(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case p: ProjectExec =>
         val newProjectList = p.projectList.map(normalize(_).asInstanceOf[NamedExpression])
@@ -595,7 +592,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   // because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). But the
   // comparison functions in arrow-rs do not normalize NaN and zero. So we need to normalize NaN
   // and zero for comparison operators in Comet.
-  def normalize(expr: Expression): Expression = {
+  private def normalize(expr: Expression): Expression = {
     expr.transformUp {
       case EqualTo(left, right) =>
         EqualTo(normalizeNaNAndZero(left), normalizeNaNAndZero(right))
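
To see why this normalization is needed, recall that Spark's comparators treat -0.0 and 0.0 as equal and give NaN a total order, while a bit-level comparison (as in arrow-rs) distinguishes them. A self-contained illustration:

    // -0.0f and 0.0f are equal under IEEE-754 comparison (Spark's semantics)...
    assert(-0.0f == 0.0f)
    // ...but their bit patterns differ, so a bit-based comparator separates them.
    assert(java.lang.Float.floatToIntBits(-0.0f) != java.lang.Float.floatToIntBits(0.0f))
    // NaN is never == to itself under IEEE semantics, whereas Spark's
    // SQLOrderingUtil.compareFloats orders two NaNs as equal.
    assert(Float.NaN != Float.NaN)
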
@@ -616,7 +613,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
-  def normalizeNaNAndZero(expr: Expression): Expression = {
+  private def normalizeNaNAndZero(expr: Expression): Expression = {
     expr match {
       case _: KnownFloatingPointNormalized => expr
       case FloatLiteral(f) if !f.equals(-0.0f) => expr
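
The guard `!f.equals(-0.0f)` is deliberate: boxed Float equality compares bit patterns, so it singles out exactly -0.0f (the one float literal that still needs normalizing), whereas primitive `==` would also match 0.0f. A two-line illustration:

    assert(0.0f == -0.0f)  // primitive comparison: equal
    // boxed equals uses floatToIntBits, so it distinguishes the two zeros
    assert(!java.lang.Float.valueOf(0.0f).equals(java.lang.Float.valueOf(-0.0f)))
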
@@ -755,7 +752,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
    * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
    * partial mode, it will return None.
    */
-  def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
+  private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
     plan.collectFirst {
       case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
         Some(agg)
@@ -770,12 +767,147 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   /**
    * Returns true if a given Spark plan is a Comet shuffle operator.
    */
-  def isShuffleOperator(op: SparkPlan): Boolean = {
+  private def isShuffleOperator(op: SparkPlan): Boolean = {
     op match {
       case op: ShuffleQueryStageExec if op.plan.isInstanceOf[CometShuffleExchangeExec] => true
       case _: CometShuffleExchangeExec => true
       case op: CometSinkPlaceHolder => isShuffleOperator(op.child)
       case _ => false
     }
   }
+
+  /**
+   * Whether the given Spark partitioning is supported by Comet native shuffle.
+   */
+  private def nativeShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+
+    /**
+     * Determine which data types are supported as hash-partition keys in native shuffle.
+     *
+     * The hash partition key determines how data should be collocated for operations like
+     * `groupByKey`, `reduceByKey` or `join`.
+     */
+    def supportedHashPartitionKeyDataType(dt: DataType): Boolean = dt match {
+      case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+          _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+          _: TimestampNTZType | _: DecimalType | _: DateType =>
+        true
+      case _ =>
+        false
+    }
+
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    val conf = SQLConf.get
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // native shuffle currently does not support complex types as partition keys
+        // due to lack of hashing support for those types
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedHashPartitionKeyDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+            CometConf.COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(ordering, _) =>
+        val supported = ordering.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+          inputs.forall(attr => supportedShuffleDataType(attr.dataType)) &&
+          CometConf.COMET_EXEC_SHUFFLE_WITH_RANGE_PARTITIONING_ENABLED.get(conf)
+        if (!supported) {
+          msg = s"unsupported Spark partitioning: $ordering"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Check if the data types of the shuffle input are supported. This is used for columnar
+   * shuffle, which supports struct/array.
+   */
+  private def columnarShuffleSupported(s: ShuffleExchangeExec): (Boolean, String) = {
+    val inputs = s.child.output
+    val partitioning = s.outputPartitioning
+    var msg = ""
+    val supported = partitioning match {
+      case HashPartitioning(expressions, _) =>
+        // columnar shuffle supports the same data types (including complex types) both for
+        // partition keys and for other columns
+        val supported =
+          expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            expressions.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $expressions"
+        }
+        supported
+      case SinglePartition =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RoundRobinPartitioning(_) =>
+        inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+      case RangePartitioning(orderings, _) =>
+        val supported =
+          orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
+            orderings.forall(e => supportedShuffleDataType(e.dataType)) &&
+            inputs.forall(attr => supportedShuffleDataType(attr.dataType))
+        if (!supported) {
+          msg = s"unsupported Spark partitioning expressions: $orderings"
+        }
+        supported
+      case _ =>
+        msg = s"unsupported Spark partitioning: ${partitioning.getClass.getName}"
+        false
+    }
+
+    if (!supported) {
+      emitWarning(msg)
+      (false, msg)
+    } else {
+      (true, null)
+    }
+  }
+
+  /**
+   * Determine which data types are supported in a shuffle.
+   */
+  private def supportedShuffleDataType(dt: DataType): Boolean = dt match {
+    case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
+        _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
+        _: TimestampNTZType | _: DecimalType | _: DateType =>
+      true
+    case StructType(fields) =>
+      fields.forall(f => supportedShuffleDataType(f.dataType)) &&
+        // the Java Arrow stream reader cannot handle duplicate field names
+        fields.map(f => f.name).distinct.length == fields.length
+    case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported
+    case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported
+    case ArrayType(elementType, _) =>
+      supportedShuffleDataType(elementType)
+    case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported
+    case MapType(_, MapType(_, _, _), _) => false
+    case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported
+    case MapType(_, StructType(_), _) => false
+    case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported
+    case MapType(_, ArrayType(_, _), _) => false
+    case MapType(keyType, valueType, _) =>
+      supportedShuffleDataType(keyType) && supportedShuffleDataType(valueType)
+    case _ =>
+      false
+  }
+
 }
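
To make the type rules above concrete, here is a standalone sketch that re-states a subset of `supportedShuffleDataType` outside the rule (the real method is private to CometExecRule; this mirror is for illustration only):

    import org.apache.spark.sql.types._

    // Mirrors the rules above: flat atomic types pass, structs need distinct
    // field names, and nested arrays/maps are rejected (TODO upstream).
    def shuffleTypeSupported(dt: DataType): Boolean = dt match {
      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType |
          DoubleType | StringType | BinaryType | TimestampType | DateType =>
        true
      case _: DecimalType => true
      case StructType(fields) =>
        fields.forall(f => shuffleTypeSupported(f.dataType)) &&
          fields.map(_.name).distinct.length == fields.length // duplicate names break the Java Arrow stream reader
      case ArrayType(_: ArrayType, _) | ArrayType(_: MapType, _) => false
      case ArrayType(elementType, _) => shuffleTypeSupported(elementType)
      case _ => false
    }

    assert(shuffleTypeSupported(ArrayType(IntegerType)))             // simple array: supported
    assert(!shuffleTypeSupported(ArrayType(ArrayType(IntegerType)))) // nested array: not supported
    assert(!shuffleTypeSupported(
      StructType(Seq(StructField("a", IntegerType), StructField("a", LongType))))) // duplicate field names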