
fix: Support auto scan mode with Spark 4.0.0 #1975


Draft · wants to merge 13 commits into main · showing changes from all commits
@@ -39,9 +39,10 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import org.apache.comet.shims.CometTypeShim
import org.apache.comet.vector.CometVector

object Utils {
object Utils extends CometTypeShim {
def getConfPath(confFileName: String): String = {
sys.env
.get("COMET_CONF_DIR")
@@ -124,7 +125,10 @@ object Utils {
case LongType => new ArrowType.Int(8 * 8, true)
case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
case StringType => ArrowType.Utf8.INSTANCE
case _: StringType => ArrowType.Utf8.INSTANCE
case dt if isStringCollationType(dt) =>
// TODO collation information is lost with this transformation
ArrowType.Utf8.INSTANCE
case BinaryType => ArrowType.Binary.INSTANCE
case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale, 128)
case DateType => new ArrowType.Date(DateUnit.DAY)
@@ -138,7 +142,8 @@
case TimestampNTZType =>
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
throw new UnsupportedOperationException(
s"Unsupported data type: [${dt.getClass.getName}] ${dt.catalogString}")
}

/** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
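The net effect of the new string handling, sketched below as a standalone function (illustrative only; in the real code isStringCollationType comes from the mixed-in CometTypeShim rather than a parameter): every Spark string type, collated or not, maps to Arrow's collation-free Utf8, which is why the TODO notes that collation information is lost.

    import org.apache.arrow.vector.types.pojo.ArrowType
    import org.apache.spark.sql.types.{DataType, StringType}

    // Illustrative sketch, not the PR's exact code: both the default StringType and any
    // collation-aware string type become Arrow Utf8, so the collation id does not survive
    // the Spark-to-Arrow conversion.
    def stringToArrow(dt: DataType, isStringCollationType: DataType => Boolean): ArrowType =
      dt match {
        case _: StringType => ArrowType.Utf8.INSTANCE
        case other if isStringCollationType(other) => ArrowType.Utf8.INSTANCE // collation lost
        case other =>
          throw new UnsupportedOperationException(
            s"Unsupported data type: [${other.getClass.getName}] ${other.catalogString}")
      }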
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.shims

import org.apache.spark.sql.types.DataType

trait CometTypeShim {
def isStringCollationType(dt: DataType): Boolean = false
}
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.shims

import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.DataType

trait CometTypeShim {
def isStringCollationType(dt: DataType): Boolean = dt.isInstanceOf[StringTypeWithCollation]
}
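The two files above declare the same CometTypeShim trait twice: the first variant (for Spark versions without string collation) always answers false, while the Spark 4.0 variant reports true for StringTypeWithCollation. Presumably the build wires in whichever shim source set matches the Spark profile, so callers get a version-agnostic check. A minimal, hypothetical usage sketch (the CollationScan helper is illustrative and not part of this PR):

    import org.apache.comet.shims.CometTypeShim
    import org.apache.spark.sql.types.StructType

    // Hypothetical consumer: mixes in whichever CometTypeShim the build selected and collects
    // the names of collated string fields; on pre-4.0 builds this always returns an empty Seq.
    object CollationScan extends CometTypeShim {
      def collatedFields(schema: StructType): Seq[String] =
        schema.fields.toSeq.collect {
          case f if isStringCollationType(f.dataType) => f.name
        }
    }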
50 changes: 38 additions & 12 deletions dev/diffs/4.0.0.diff
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
index 443d46a4302..3b8483173f1 100644
index 443d46a4302..63ec4784625 100644
--- a/pom.xml
+++ b/pom.xml
@@ -148,6 +148,8 @@
@@ -700,10 +700,10 @@ index 9c529d14221..069b7c5adeb 100644
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
new file mode 100644
index 00000000000..5eb3fa17ca8
index 00000000000..5691536c114
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IgnoreComet.scala
@@ -0,0 +1,43 @@
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
@@ -732,6 +732,8 @@ index 00000000000..5eb3fa17ca8
+ * Tests with this tag will be ignored when Comet is enabled (e.g., via `ENABLE_COMET`).
+ */
+case class IgnoreComet(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeIcebergCompat(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeDataFusion(reason: String) extends Tag("DisableComet")
+case class IgnoreCometNativeScan(reason: String) extends Tag("DisableComet")
+
+/**
@@ -3309,40 +3311,64 @@ index 86c4e49f6f6..2e639e5f38d 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index f0f3f94b811..486a436afb2 100644
index f0f3f94b811..d64e4e54e22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
@@ -27,13 +27,14 @@ import scala.jdk.CollectionConverters._
import scala.language.implicitConversions
import scala.util.control.NonFatal

+import org.apache.comet.CometConf
import org.apache.hadoop.fs.Path
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfterAll, Suite, Tag}
import org.scalatest.concurrent.Eventually

import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, Row}
+import org.apache.spark.sql.{AnalysisException, IgnoreComet, IgnoreCometNativeDataFusion, IgnoreCometNativeIcebergCompat, IgnoreCometNativeScan, Row}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -42,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.classic.{ClassicConversions, ColumnConversions, ColumnNodeToExpressionConverter, DataFrame, Dataset, SparkSession, SQLImplicits}
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.{CometFilterExec, CometProjectExec}
import org.apache.spark.sql.execution.FilterExec
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecution
import org.apache.spark.sql.execution.datasources.DataSourceUtils
@@ -128,7 +129,11 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
@@ -128,7 +130,28 @@ private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with
}
}
} else {
- super.test(testName, testTags: _*)(testFun)
+ if (isCometEnabled && testTags.exists(_.isInstanceOf[IgnoreComet])) {
+ ignore(testName + " (disabled when Comet is on)", testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ val cometScanImpl = CometConf.COMET_NATIVE_SCAN_IMPL.get(conf)
+ val isNativeIcebergCompat = cometScanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ val isNativeDataFusion = cometScanImpl == CometConf.SCAN_NATIVE_DATAFUSION ||
+ cometScanImpl == CometConf.SCAN_AUTO
+ if (isCometEnabled && isNativeIcebergCompat &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeIcebergCompat])) {
+ ignore(testName + " (disabled for NATIVE_ICEBERG_COMPAT)", testTags: _*)(testFun)
+ } else if (isCometEnabled && isNativeDataFusion &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeDataFusion])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION)", testTags: _*)(testFun)
+ } else if (isCometEnabled && (isNativeDataFusion || isNativeIcebergCompat) &&
+ testTags.exists(_.isInstanceOf[IgnoreCometNativeScan])) {
+ ignore(testName + " (disabled for NATIVE_DATAFUSION and NATIVE_ICEBERG_COMPAT)",
+ testTags: _*)(testFun)
+ } else {
+ super.test(testName, testTags: _*)(testFun)
+ }
+ }
}
}

@@ -248,8 +253,33 @@ private[sql] trait SQLTestUtilsBase
@@ -248,8 +271,33 @@ private[sql] trait SQLTestUtilsBase
override protected def converter: ColumnNodeToExpressionConverter = self.spark.converter
}

@@ -3376,7 +3402,7 @@ index f0f3f94b811..486a436afb2 100644
super.withSQLConf(pairs: _*)(f)
}

@@ -451,6 +481,8 @@ private[sql] trait SQLTestUtilsBase
@@ -451,6 +499,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
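The SQLTestUtils hunk above keys the new tags off CometConf.COMET_NATIVE_SCAN_IMPL: when the configured scan implementation is native_iceberg_compat, native_datafusion, or auto, tests carrying the matching tag are registered as ignored instead of run. A hedged sketch of how a Spark suite would use one of the tags once this patch is applied (the suite, test name, and reason string are invented for illustration):

    import org.apache.spark.sql.{IgnoreCometNativeScan, QueryTest, Row}
    import org.apache.spark.sql.test.SharedSparkSession

    // Illustrative only: the tag makes this test ignored when COMET_NATIVE_SCAN_IMPL is
    // native_datafusion, native_iceberg_compat, or auto, while it still runs with the
    // default native_comet scan and with Comet disabled.
    class ExampleSuite extends QueryTest with SharedSparkSession {
      test("hypothetical nested-schema query",
        IgnoreCometNativeScan("hypothetical: results differ under the native scans")) {
        checkAnswer(spark.sql("SELECT 1"), Row(1))
      }
    }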
29 changes: 16 additions & 13 deletions spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala
@@ -34,15 +34,16 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.{CometConf, CometSparkSessionExtensions, DataTypeSupport}
import org.apache.comet.{CometConf, DataTypeSupport}
import org.apache.comet.CometConf._
import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos}
import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
import org.apache.comet.shims.CometTypeShim

/**
* Spark physical optimizer rule for replacing Spark scans with Comet scans.
*/
case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with CometTypeShim {

private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()

@@ -261,10 +262,6 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {

val fallbackReasons = new ListBuffer[String]()

if (CometSparkSessionExtensions.isSpark40Plus) {
fallbackReasons += s"$SCAN_NATIVE_ICEBERG_COMPAT is not implemented for Spark 4.0.0"
}

// native_iceberg_compat only supports local filesystem and S3
if (!scanExec.relation.inputFiles
.forall(path => path.startsWith("file://") || path.startsWith("s3a://"))) {
@@ -282,21 +279,27 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
case _ => false
}

def hasMapsContainingStructs(dataType: DataType): Boolean = {
def hasUnsupportedType(dataType: DataType): Boolean = {
dataType match {
case s: StructType => s.exists(field => hasMapsContainingStructs(field.dataType))
case a: ArrayType => hasMapsContainingStructs(a.elementType)
case m: MapType => isComplexType(m.keyType) || isComplexType(m.valueType)
case s: StructType => s.exists(field => hasUnsupportedType(field.dataType))
case a: ArrayType => hasUnsupportedType(a.elementType)
case m: MapType =>
// maps containing complex types are not supported
isComplexType(m.keyType) || isComplexType(m.valueType)
case dt => isStringCollationType(dt)
case _: StringType =>
// we only support `case object StringType` and not other instances of `class StringType`
dataType != StringType
case _ => false
}
}

val knownIssues =
scanExec.requiredSchema.exists(field => hasMapsContainingStructs(field.dataType)) ||
partitionSchema.exists(field => hasMapsContainingStructs(field.dataType))
scanExec.requiredSchema.exists(field => hasUnsupportedType(field.dataType)) ||
partitionSchema.exists(field => hasUnsupportedType(field.dataType))

if (knownIssues) {
fallbackReasons += "There are known issues with maps containing structs when using " +
fallbackReasons += "Schema contains data types that are not supported by " +
s"$SCAN_NATIVE_ICEBERG_COMPAT"
}
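To make the broadened fallback concrete, here is a hypothetical schema (not taken from the PR or its tests) that the renamed hasUnsupportedType check would flag, causing native_iceberg_compat, and therefore auto scan mode, to fall back to Spark with the new reason message:

    import org.apache.spark.sql.types._

    // Hypothetical example: "tags" is a map whose value type is a struct, which the check above
    // treats as unsupported, so the scan falls back with
    // "Schema contains data types that are not supported by ...".
    object UnsupportedSchemaExample {
      val schema: StructType = StructType(Seq(
        StructField("id", LongType),
        StructField(
          "tags",
          MapType(StringType, StructType(Seq(StructField("v", IntegerType)))))))
    }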
