diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 4dfc590455..15d86bea18 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -136,7 +136,7 @@ object CometArrayAppend extends CometExpressionSerde with IncompatExpr {
   }
 }
 
-object CometArrayContains extends CometExpressionSerde with IncompatExpr {
+object CometArrayContains extends CometExpressionSerde {
   override def convert(
       expr: Expression,
       inputs: Seq[Attribute],
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 0be89c5124..60d495abd0 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -25,7 +25,7 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.functions.{array, col, expr, lit, udf}
+import org.apache.spark.sql.functions._
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.serde.CometArrayExcept
@@ -218,7 +218,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
-  test("array_contains") {
+  test("array_contains - int values") {
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
@@ -232,6 +232,78 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_contains - test all types (native Parquet reader)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = false,
+            generateStruct = false,
+            generateMap = false))
+      }
+      val table = spark.read.parquet(filename)
+      table.createOrReplaceTempView("t1")
+      for (field <- table.schema.fields) {
+        val fieldName = field.name
+        val typeName = field.dataType.typeName
+        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+          .createOrReplaceTempView("t2")
+        checkSparkAnswerAndOperator(sql("SELECT array_contains(a, b) FROM t2"))
+        checkSparkAnswerAndOperator(
+          sql(s"SELECT array_contains(a, cast(null as $typeName)) FROM t2"))
+        checkSparkAnswerAndOperator(
+          sql(s"SELECT array_contains(cast(null as array<$typeName>), b) FROM t2"))
+        checkSparkAnswerAndOperator(sql(
+          s"SELECT array_contains(cast(array() as array<$typeName>), cast(null as $typeName)) FROM t2"))
+        checkSparkAnswerAndOperator(sql(s"SELECT array_contains(array(), 1) FROM t2"))
+      }
+    }
+  }
+
+  test("array_contains - test all types (convert from Parquet)") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            allowNull = true,
+            generateNegativeZero = true,
+            generateArray = true,
+            generateStruct = true,
+            generateMap = false))
+      }
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        for (field <- table.schema.fields) {
+          val fieldName = field.name
+          val typeName = field.dataType.typeName
+          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+            .createOrReplaceTempView("t2")
+          checkSparkAnswer(sql("SELECT array_contains(a, b) FROM t2"))
+        }
+      }
+    }
+  }
+
   test("array_distinct") {
     withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
       Seq(true, false).foreach { dictionaryEnabled =>
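With `IncompatExpr` removed from `CometArrayContains`, `array_contains` is treated as Spark-compatible and is converted to a native expression without the `CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE` escape hatch that the renamed `array_contains - int values` test still sets. A minimal sketch of what this enables, assuming a Spark build with the Comet plugin on the classpath (the session configs follow the Comet docs; exact setup may vary by version, and `ArrayContainsNativeExample` is an illustrative name, not part of the patch):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_contains, col}

object ArrayContainsNativeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      // Typical Comet enablement per the project docs; assumption, not from the patch.
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .config("spark.comet.enabled", "true")
      // Note: no allow-incompatible flag is set here; after this change,
      // array_contains should no longer require it.
      .getOrCreate()
    import spark.implicits._

    val df = Seq(Seq(1, 2, 3), Seq(4, 5, 6)).toDF("a")
    // Expected to stay on the native Comet plan rather than falling back to Spark.
    df.select(array_contains(col("a"), 2)).show()
    spark.stop()
  }
}

The two new suites back this up: they exercise `array_contains` across all generated column types, both through the native Parquet reader and through the Spark-to-Arrow conversion path, including null elements, null arrays, and empty arrays, which is what justifies dropping the incompatible-expression gate.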