From b37cc387a0ba0f4086509008a1362db4b9d40fb1 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 2 Jul 2025 08:16:53 -0700 Subject: [PATCH 01/12] feat: support literal for ARRAY top level --- native/core/src/execution/planner.rs | 15 +++- native/proto/src/proto/expr.proto | 20 ++--- native/proto/src/proto/types.proto | 39 ++++++++++ .../apache/comet/expressions/CometCast.scala | 3 +- .../apache/comet/serde/QueryPlanSerde.scala | 75 ++++++++++++++++++- .../comet/CometArrayExpressionSuite.scala | 15 +++- 6 files changed, 149 insertions(+), 18 deletions(-) create mode 100644 native/proto/src/proto/types.proto diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 327e57bcfd..3d35afeada 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -85,6 +85,8 @@ use datafusion::physical_expr::window::WindowExpr; use datafusion::physical_expr::LexOrdering; use crate::parquet::parquet_exec::init_datasource_exec; +use arrow::array::Int32Array; +use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec; use datafusion_comet_proto::spark_operator::SparkFilePartition; @@ -440,6 +442,18 @@ impl PhysicalPlanner { ))) } } + }, + Value::ListVal(values) => { + dbg!(values); + //dbg!(literal.datatype.as_ref().unwrap()); + //dbg!(data_type); + match data_type { + DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => { + SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values.clone().int_values))) + .build_list_scalar() + } + _ => todo!() + } } } }; @@ -2273,7 +2287,6 @@ impl PhysicalPlanner { other => other, }; let func = self.session_ctx.udf(fun_name)?; - let coerced_types = func .coerce_types(&input_expr_types) .unwrap_or_else(|_| input_expr_types.clone()); diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8f4c875eec..7e04c7aa55 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -21,6 +21,8 @@ syntax = "proto3"; package spark.spark_expression; +import "types.proto"; + option java_package = "org.apache.comet.serde"; // The basic message representing a Spark expression. @@ -110,13 +112,13 @@ enum StatisticsType { } message Count { - repeated Expr children = 1; + repeated Expr children = 1; } message Sum { - Expr child = 1; - DataType datatype = 2; - bool fail_on_error = 3; + Expr child = 1; + DataType datatype = 2; + bool fail_on_error = 3; } message Min { @@ -213,10 +215,11 @@ message Literal { string string_val = 8; bytes bytes_val = 9; bytes decimal_val = 10; - } + ListLiteral list_val = 11; + } - DataType datatype = 11; - bool is_null = 12; + DataType datatype = 12; + bool is_null = 13; } message MathExpr { @@ -469,5 +472,4 @@ message DataType { } DataTypeInfo type_info = 2; -} - +} \ No newline at end of file diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto new file mode 100644 index 0000000000..1c277c2e3d --- /dev/null +++ b/native/proto/src/proto/types.proto @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + + +syntax = "proto3"; + +package spark.spark_expression; + +option java_package = "org.apache.comet.serde"; + +message ListLiteral { + // Only one of these fields should be populated based on the array type + repeated bool boolean_values = 1; + repeated int32 byte_values = 2; + repeated int32 short_values = 3; + repeated int32 int_values = 4; + repeated int64 long_values = 5; + repeated float float_values = 6; + repeated double double_values = 7; + repeated string string_values = 8; + repeated bytes bytes_values = 9; + repeated bytes decimal_values = 10; + repeated ListLiteral list_values = 11; +} \ No newline at end of file diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index e0e89b35fc..6d52824b56 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -19,7 +19,7 @@ package org.apache.comet.expressions -import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType} sealed trait SupportLevel @@ -62,6 +62,7 @@ object CometCast { } (fromType, toType) match { + case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible() case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => // https://github.com/apache/datafusion-comet/issues/378 toType match { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 970329b28f..4c574ef5a0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -56,6 +57,7 @@ import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType import org.apache.comet.serde.ExprOuterClass.DataType._ import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import 
org.apache.comet.serde.Types.ListLiteral import org.apache.comet.shims.CometExprShim /** @@ -142,6 +144,52 @@ object QueryPlanSerde extends Logging with CometExprShim { false } +// def convertArrayToProtoLiteral(array: Seq[Any], arrayType: ArrayType): Literal = { +// val elementType = arrayType.elementType +// val listLiteralBuilder = ListLiteral.newBuilder() +// +// elementType match { +// case BooleanType => +// listLiteralBuilder.addAllBooleanValues(array.map(_.asInstanceOf[Boolean]).asJava) +// +// case ByteType => +// listLiteralBuilder.addAllByteValues(array.map(_.asInstanceOf[Byte].toInt).asJava) +// +// case ShortType => +// listLiteralBuilder.addAllShortValues(array.map(_.asInstanceOf[Short].toInt).asJava) +// +// case IntegerType => +// listLiteralBuilder.addAllIntValues(array.map(_.asInstanceOf[Int]).asJava) +// +// case LongType => +// listLiteralBuilder.addAllLongValues(array.map(_.asInstanceOf[Long]).asJava) +// +// case FloatType => +// listLiteralBuilder.addAllFloatValues(array.map(_.asInstanceOf[Float]).asJava) +// +// case DoubleType => +// listLiteralBuilder.addAllDoubleValues(array.map(_.asInstanceOf[Double]).asJava) +// +// case StringType => +// listLiteralBuilder.addAllStringValues(array.map(_.asInstanceOf[String]).asJava) +// +// case BinaryType => +// listLiteralBuilder.addAllBytesValues + // (array.map(x => com.google.protobuf + // .ByteString.copyFrom(x.asInstanceOf[Array[Byte]])).asJava) +// +// case nested: ArrayType => +// val nestedListLiterals = array.map { +// case null => ListLiteral.newBuilder().build() // or handle nulls appropriately +// case seq: Seq[_] => convertArrayToProtoLiteral(seq, nested).getListVal +// } +// listLiteralBuilder.addAllListValues(nestedListLiterals.asJava) +// +// case _ => +// throw new UnsupportedOperationException(s"Unsupported element type: $elementType") +// } +// } + /** * Serializes Spark datatype to protobuf. 
Note that, a datatype can be serialized by this method * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return @@ -837,8 +885,7 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setLtEq(binaryExpr)) - case Literal(value, dataType) - if supportedDataType(dataType, allowComplex = value == null) => + case Literal(value, dataType) if supportedDataType(dataType, allowComplex = true) => val exprBuilder = ExprOuterClass.Literal.newBuilder() if (value == null) { @@ -867,6 +914,28 @@ object QueryPlanSerde extends Logging with CometExprShim { com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) exprBuilder.setBytesVal(byteStr) case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case a: ArrayType => + val listLiteralBuilder = ListLiteral.newBuilder() + val array = value.asInstanceOf[GenericArrayData].array + a.elementType match { + case BooleanType => + listLiteralBuilder.addAllBooleanValues( + array.map(_.asInstanceOf[java.lang.Boolean]).toIterable.asJava) + case ByteType => + listLiteralBuilder.addAllByteValues( + array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava) + case ShortType => + listLiteralBuilder.addAllShortValues( + array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava) + case IntegerType => + listLiteralBuilder.addAllIntValues( + array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava) + case LongType => + listLiteralBuilder.addAllLongValues( + array.map(_.asInstanceOf[java.lang.Long]).toIterable.asJava) + } + exprBuilder.setListVal(listLiteralBuilder.build()) + exprBuilder.setDatatype(serializeDataType(dataType).get) case dt => logWarning(s"Unexpected datatype '$dt' for literal value '$value'") } diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 0be89c5124..fef44fc06a 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -219,15 +219,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("array_contains") { - withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + withSQLConf( + CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true", + CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true" + // "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" + ) { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1"); +// checkSparkAnswerAndOperator( +// spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) +// checkSparkAnswerAndOperator( +// spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); checkSparkAnswerAndOperator( - spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); + spark.sql( + "SELECT array_contains((CASE WHEN _2 =_3 THEN array(1, 2, 3) END), _4) FROM t1")); } } } From 09d43c8d3c34dcbd93b12e2a33ca778c3963eb06 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 2 Jul 2025 08:38:18 -0700 Subject: [PATCH 02/12] feat: support literal for ARRAY top level --- 
.../org/apache/comet/exec/CometNativeReaderSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala index f33da3ba71..a42e8a6c4c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala @@ -436,4 +436,12 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper |""".stripMargin, "select c0['key1'].b from tbl") } + + test("native reader - support ARRAY literal INT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(1, 2, 3) from tbl") + } } From 3ce906d970b44c91755f5e3eb2abfc1285cc8265 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 2 Jul 2025 08:38:49 -0700 Subject: [PATCH 03/12] feat: support literal for ARRAY top level --- .../org/apache/comet/CometArrayExpressionSuite.scala | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index fef44fc06a..d94004c729 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -222,19 +222,15 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp withSQLConf( CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true", CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true" - // "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ) { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000) spark.read.parquet(path.toString).createOrReplaceTempView("t1"); -// checkSparkAnswerAndOperator( -// spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) -// checkSparkAnswerAndOperator( -// spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); checkSparkAnswerAndOperator( - spark.sql( - "SELECT array_contains((CASE WHEN _2 =_3 THEN array(1, 2, 3) END), _4) FROM t1")); + spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); } } } From 8aa045b81fcec22513f3a79a97984aca46ba38bd Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 2 Jul 2025 08:39:46 -0700 Subject: [PATCH 04/12] fixed size list --- native/core/src/execution/planner.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 3d35afeada..93cf6c4bac 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -444,13 +444,15 @@ impl PhysicalPlanner { } }, Value::ListVal(values) => { - dbg!(values); + //dbg!(values); //dbg!(literal.datatype.as_ref().unwrap()); //dbg!(data_type); match data_type { DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => { - SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values.clone().int_values))) - .build_list_scalar() + let vals = values.clone().int_values; + let len = &vals.len(); + SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(vals))) + .build_fixed_size_list_scalar(*len) } _ 
=> todo!() } From 3308c72a963f0f2bc087f77666831c1cf0655809 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 2 Jul 2025 08:42:02 -0700 Subject: [PATCH 05/12] feat: support literal for ARRAY top level --- .../scala/org/apache/comet/CometArrayExpressionSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index d94004c729..084434ca13 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -221,8 +221,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp test("array_contains") { withSQLConf( CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true", - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true" - ) { + CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000) From 6b70695317ed91d808ee8728a4e2ecc72a0b3fb6 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 9 Jul 2025 17:27:17 -0700 Subject: [PATCH 06/12] feat: support literal for ARRAY top level --- native/core/src/execution/planner.rs | 246 +++++++++++++++++- native/proto/src/lib.rs | 1 + native/proto/src/proto/types.proto | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 93 +++++-- .../comet/exec/CometNativeReaderSuite.scala | 114 +++++++- 5 files changed, 427 insertions(+), 29 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 93cf6c4bac..8a692b5452 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -85,7 +85,11 @@ use datafusion::physical_expr::window::WindowExpr; use datafusion::physical_expr::LexOrdering; use crate::parquet::parquet_exec::init_datasource_exec; -use arrow::array::Int32Array; +use arrow::array::{ + BinaryArray, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Array, Decimal128Builder, + Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, + NullArray, StringBuilder, TimestampMicrosecondBuilder, +}; use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec; @@ -444,17 +448,237 @@ impl PhysicalPlanner { } }, Value::ListVal(values) => { - //dbg!(values); - //dbg!(literal.datatype.as_ref().unwrap()); - //dbg!(data_type); - match data_type { - DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => { - let vals = values.clone().int_values; - let len = &vals.len(); - SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(vals))) - .build_fixed_size_list_scalar(*len) + if let DataType::List(f) = data_type { + match f.data_type() { + DataType::Null => { + SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len()))) + .build_list_scalar() + } + DataType::Boolean => { + let vals = values.clone(); + let len = vals.boolean_values.len(); + let mut arr = BooleanBuilder::with_capacity(len); + + for i in 0 .. 
len { + if !vals.null_mask[i] { + arr.append_value(vals.boolean_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int8 => { + let vals = values.clone(); + let len = vals.byte_values.len(); + let mut arr = Int8Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.byte_values[i] as i8); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int16 => { + let vals = values.clone(); + let len = vals.short_values.len(); + let mut arr = Int16Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.short_values[i] as i16); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int32 => { + let vals = values.clone(); + let len = vals.int_values.len(); + let mut arr = Int32Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.int_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int64 => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = Int64Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Float32 => { + let vals = values.clone(); + let len = vals.float_values.len(); + let mut arr = Float32Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.float_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Float64 => { + let vals = values.clone(); + let len = vals.double_values.len(); + let mut arr = Float64Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.double_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = TimestampMicrosecondBuilder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = TimestampMicrosecondBuilder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish().with_timezone(Arc::clone(tz)))) + .build_list_scalar() + } + DataType::Date32 => { + let vals = values.clone(); + let len = vals.int_values.len(); + let mut arr = Date32Builder::with_capacity(len); + + for i in 0 .. 
len { + if !vals.null_mask[i] { + arr.append_value(vals.int_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Binary => { + let vals = values.clone(); + let mut arr = BinaryBuilder::new(); + + for (i, v) in vals.bytes_values.into_iter().enumerate() { + if !vals.null_mask[i] { + arr.append_value(v); + } else { + arr.append_null(); + } + } + + let binary_array: BinaryArray = arr.finish(); + SingleRowListArrayBuilder::new(Arc::new(binary_array)) + .build_list_scalar() + } + DataType::Utf8 => { + let vals = values.clone(); + let len = vals.string_values.len(); + let mut arr = StringBuilder::with_capacity(len, len); + + for (i, v) in vals.string_values.into_iter().enumerate() { + if !vals.null_mask[i] { + arr.append_value(v); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Decimal128(p, s) => { + let vals = values.clone(); + let mut arr = Decimal128Builder::new().with_precision_and_scale(*p, *s)?; + + for (i, v) in vals.decimal_values.into_iter().enumerate() { + if !vals.null_mask[i] { + let big_integer = BigInt::from_signed_bytes_be(&v); + let integer = big_integer.to_i128().ok_or_else(|| { + GeneralError(format!( + "Cannot parse {big_integer:?} as i128 for Decimal literal" + )) + })?; + arr.append_value(integer); + } else { + arr.append_null(); + } + } + + let decimal_array: Decimal128Array = arr.finish(); + SingleRowListArrayBuilder::new(Arc::new(decimal_array)) + .build_list_scalar() + } + dt => { + return Err(GeneralError(format!( + "DataType::List literal does not support {dt:?} type" + ))) + } } - _ => todo!() + + } else { + return Err(GeneralError(format!( + "Expected DataType::List but got {data_type:?}" + ))) } } } diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index ed24440360..2c213c2514 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -21,6 +21,7 @@ // Include generated modules from .proto files. 
#[allow(missing_docs)] +#[allow(clippy::large_enum_variant)] pub mod spark_expression { include!(concat!("generated", "/spark.spark_expression.rs")); } diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto index 1c277c2e3d..cc163522b4 100644 --- a/native/proto/src/proto/types.proto +++ b/native/proto/src/proto/types.proto @@ -36,4 +36,6 @@ message ListLiteral { repeated bytes bytes_values = 9; repeated bytes decimal_values = 10; repeated ListLiteral list_values = 11; + + repeated bool null_mask = 12; } \ No newline at end of file diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4c574ef5a0..a6ff5b1067 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -49,6 +49,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import com.google.protobuf.ByteString + import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ @@ -896,14 +898,13 @@ object QueryPlanSerde extends Logging with CometExprShim { case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType | _: TimestampType | _: TimestampNTZType => + exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: TimestampNTZType => exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: DecimalType => // Pass decimal literal as bytes. 
             val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
@@ -913,26 +914,87 @@ object QueryPlanSerde extends Logging with CometExprShim {
               val byteStr =
                 com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
               exprBuilder.setBytesVal(byteStr)
-            case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
             case a: ArrayType =>
               val listLiteralBuilder = ListLiteral.newBuilder()
               val array = value.asInstanceOf[GenericArrayData].array
               a.elementType match {
+                // Values and the null mask are appended in lockstep: a null
+                // element contributes a placeholder value plus a mask entry,
+                // so indexes stay aligned when the native side rebuilds the array.
+                case NullType =>
+                  array.foreach(_ => listLiteralBuilder.addNullMask(true))
                 case BooleanType =>
-                  listLiteralBuilder.addAllBooleanValues(
-                    array.map(_.asInstanceOf[java.lang.Boolean]).toIterable.asJava)
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Boolean]
+                    listLiteralBuilder.addBooleanValues(casted != null && casted.booleanValue())
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
                 case ByteType =>
-                  listLiteralBuilder.addAllByteValues(
-                    array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Byte]
+                    listLiteralBuilder.addByteValues(if (casted != null) casted.intValue() else 0)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
                 case ShortType =>
-                  listLiteralBuilder.addAllShortValues(
-                    array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
-                case IntegerType =>
-                  listLiteralBuilder.addAllIntValues(
-                    array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
-                case LongType =>
-                  listLiteralBuilder.addAllLongValues(
-                    array.map(_.asInstanceOf[java.lang.Long]).toIterable.asJava)
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Short]
+                    listLiteralBuilder.addShortValues(
+                      if (casted != null) casted.intValue()
+                      else 0)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case IntegerType | DateType =>
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Integer]
+                    listLiteralBuilder.addIntValues(if (casted != null) casted.intValue() else 0)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case LongType | TimestampType | TimestampNTZType =>
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Long]
+                    listLiteralBuilder.addLongValues(if (casted != null) casted.longValue() else 0L)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case FloatType =>
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Float]
+                    listLiteralBuilder.addFloatValues(if (casted != null) casted.floatValue() else 0.0f)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case DoubleType =>
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[java.lang.Double]
+                    listLiteralBuilder.addDoubleValues(if (casted != null) casted.doubleValue() else 0.0d)
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case StringType =>
+                  array.foreach(v => {
+                    val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String]
+                    listLiteralBuilder.addStringValues(
+                      if (casted != null) casted.toString else "")
+                    listLiteralBuilder.addNullMask(casted == null)
+                  })
+                case _: DecimalType =>
+                  array
+                    .foreach(v => {
+                      val casted =
+                        v.asInstanceOf[Decimal]
+                      listLiteralBuilder.addDecimalValues(if (casted != null) {
+                        com.google.protobuf.ByteString
+                          .copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
+                      } else ByteString.EMPTY)
+                      listLiteralBuilder.addNullMask(casted == null)
+                    })
+                case _: BinaryType =>
+                  array
+                    .foreach(v => {
+                      val casted =
+                        v.asInstanceOf[Array[Byte]]
+                      listLiteralBuilder.addBytesValues(if (casted != null) {
+                        com.google.protobuf.ByteString.copyFrom(casted)
+                      } else ByteString.EMPTY)
+                      listLiteralBuilder.addNullMask(casted == null)
+                    })
               }
               exprBuilder.setListVal(listLiteralBuilder.build())
             exprBuilder.setDatatype(serializeDataType(dataType).get)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
index a42e8a6c4c..0750593c96 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala
@@ -442,6 +442,118 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper
       """
         |select 1 a
         |""".stripMargin,
-      "select array(1, 2, 3) from tbl")
+      "select array(1, 2, 3, null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal BOOL fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(true, false, null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal NULL fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(null) from tbl")
+  }
+
+  test("native reader - support empty ARRAY literal") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array() from tbl")
+  }
+
+  test("native reader - support ARRAY literal BYTE fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as byte), cast(2 as byte), cast(3 as byte), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal SHORT fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as short), cast(2 as short), cast(3 as short), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal DATE fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(CAST('2024-01-01' AS DATE), CAST('2024-02-01' AS DATE), CAST('2024-03-01' AS DATE), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal LONG fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as bigint), cast(2 as bigint), cast(3 as bigint), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal TIMESTAMP fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP), CAST('2024-01-02 10:00:00' AS TIMESTAMP), CAST('2024-01-03 10:00:00' AS TIMESTAMP), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal TIMESTAMP NTZ fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP_NTZ), CAST('2024-01-02 10:00:00' AS TIMESTAMP_NTZ), CAST('2024-01-03 10:00:00' AS TIMESTAMP_NTZ), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal FLOAT fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as float), cast(2 as float), cast(3 as float), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal DOUBLE fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as double), cast(2 as double), cast(3 as double), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal STRING fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array('a', 'bc', 'def', null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal DECIMAL fields") {
+    testSingleLineQuery(
+      """
+        |select 1 a
+        |""".stripMargin,
+      "select array(cast(1 as decimal(10, 2)), cast(2.5 as decimal(10, 2)), cast(3.75 as decimal(10, 2)), null) from tbl")
+  }
+
+  test("native reader - support ARRAY literal BINARY fields") {
+    testSingleLineQuery(
+ """ + |select 1 a + |""".stripMargin, + "select array(cast('a' as binary), cast('bc' as binary), cast('def' as binary), null) from tbl") } } From eb5f2bf2f2bb9c59b7ae4f05b6066ba51f7981dd Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 9 Jul 2025 17:40:55 -0700 Subject: [PATCH 07/12] feat: support literal for ARRAY top level --- .../apache/comet/serde/QueryPlanSerde.scala | 54 +++---------------- .../comet/CometArrayExpressionSuite.scala | 4 +- 2 files changed, 8 insertions(+), 50 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index a6ff5b1067..7d1da4e437 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -146,52 +146,6 @@ object QueryPlanSerde extends Logging with CometExprShim { false } -// def convertArrayToProtoLiteral(array: Seq[Any], arrayType: ArrayType): Literal = { -// val elementType = arrayType.elementType -// val listLiteralBuilder = ListLiteral.newBuilder() -// -// elementType match { -// case BooleanType => -// listLiteralBuilder.addAllBooleanValues(array.map(_.asInstanceOf[Boolean]).asJava) -// -// case ByteType => -// listLiteralBuilder.addAllByteValues(array.map(_.asInstanceOf[Byte].toInt).asJava) -// -// case ShortType => -// listLiteralBuilder.addAllShortValues(array.map(_.asInstanceOf[Short].toInt).asJava) -// -// case IntegerType => -// listLiteralBuilder.addAllIntValues(array.map(_.asInstanceOf[Int]).asJava) -// -// case LongType => -// listLiteralBuilder.addAllLongValues(array.map(_.asInstanceOf[Long]).asJava) -// -// case FloatType => -// listLiteralBuilder.addAllFloatValues(array.map(_.asInstanceOf[Float]).asJava) -// -// case DoubleType => -// listLiteralBuilder.addAllDoubleValues(array.map(_.asInstanceOf[Double]).asJava) -// -// case StringType => -// listLiteralBuilder.addAllStringValues(array.map(_.asInstanceOf[String]).asJava) -// -// case BinaryType => -// listLiteralBuilder.addAllBytesValues - // (array.map(x => com.google.protobuf - // .ByteString.copyFrom(x.asInstanceOf[Array[Byte]])).asJava) -// -// case nested: ArrayType => -// val nestedListLiterals = array.map { -// case null => ListLiteral.newBuilder().build() // or handle nulls appropriately -// case seq: Seq[_] => convertArrayToProtoLiteral(seq, nested).getListVal -// } -// listLiteralBuilder.addAllListValues(nestedListLiterals.asJava) -// -// case _ => -// throw new UnsupportedOperationException(s"Unsupported element type: $elementType") -// } -// } - /** * Serializes Spark datatype to protobuf. 
Note that, a datatype can be serialized by this method * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return @@ -887,7 +841,13 @@ object QueryPlanSerde extends Logging with CometExprShim { binding, (builder, binaryExpr) => builder.setLtEq(binaryExpr)) - case Literal(value, dataType) if supportedDataType(dataType, allowComplex = true) => + case Literal(value, dataType) + if supportedDataType( + dataType, + allowComplex = value == null || Seq( + CometConf.SCAN_NATIVE_ICEBERG_COMPAT, + CometConf.SCAN_NATIVE_DATAFUSION).contains( + CometConf.COMET_NATIVE_SCAN_IMPL.get())) => val exprBuilder = ExprOuterClass.Literal.newBuilder() if (value == null) { diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 084434ca13..0be89c5124 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -219,9 +219,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("array_contains") { - withSQLConf( - CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true", - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000) From d8029cb5c8edacbdd65af43d6947ab0fc1e34d1e Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 10 Jul 2025 14:28:09 -0700 Subject: [PATCH 08/12] feat: support literal for ARRAY top level --- native/spark-expr/src/conversion_funcs/cast.rs | 4 ++++ .../main/scala/org/apache/comet/expressions/CometCast.scala | 2 ++ 2 files changed, 6 insertions(+) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index e0bc5f39fc..fef6308306 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -20,6 +20,7 @@ use crate::utils::array_with_timezone; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; use arrow::array::{DictionaryArray, StringArray, StructArray}; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Schema}; use arrow::{ array::{ @@ -967,6 +968,9 @@ fn cast_array( to_type, cast_options, )?), + (List(_), List(_)) if can_cast_types(from_type, to_type) => { + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ } (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) if cast_options.allow_cast_unsigned_ints => { diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 6d52824b56..337eae11db 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -63,6 +63,8 @@ object CometCast { (fromType, toType) match { case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible() + case (dt: ArrayType, dt1: ArrayType) => + isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode) case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => // https://github.com/apache/datafusion-comet/issues/378 toType match { From 5c5522b9afe13360ef5ec2e312717f6eb930e8e0 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 10 Jul 2025 14:29:11 -0700 Subject: [PATCH 09/12] feat: support literal for ARRAY top level --- .../org/apache/comet/exec/CometNativeReaderSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala index 0750593c96..4c7ea3a202 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala @@ -556,4 +556,12 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper |""".stripMargin, "select array(cast('a' as binary), cast('bc' as binary), cast('def' as binary), null) from tbl") } + + test("native reader - array equality") { + testSingleLineQuery( + """ + | select array(1) a union all select array(2) + |""".stripMargin, + "select * from tbl where a = array(1L)") + } } From 2148f5d4ae3c708342fae820cb735d1544f7ef88 Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 11 Jul 2025 17:57:43 -0700 Subject: [PATCH 10/12] feat: support literal for ARRAY top level --- spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7d1da4e437..8748ab975b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ From 6aab4addd36c960a805cab45b2eca33899044076 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 17 Jul 2025 08:30:53 -0700 Subject: [PATCH 11/12] feat: support literal for ARRAY top level --- .../spark-expr/src/conversion_funcs/cast.rs | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 4 ++- .../comet/exec/CometNativeReaderSuite.scala | 25 +++++++------------ 3 files changed, 13 insertions(+), 18 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs 
b/native/spark-expr/src/conversion_funcs/cast.rs index fef6308306..09659b6ac3 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1021,7 +1021,7 @@ fn is_datafusion_spark_compatible( DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { // note that the cast from Int32/Int64 -> Decimal128 here is actually // not compatible with Spark (no overflow checks) but we have tests that - // rely on this cast working so we have to leave it here for now + // rely on this cast working, so we have to leave it here for now matches!( to_type, DataType::Boolean diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8748ab975b..c8f89310c9 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2180,7 +2180,9 @@ object QueryPlanSerde extends Logging with CometExprShim { op match { // Fully native scan for V1 - case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => + case scan: CometScanExec + if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION + || scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() nativeScanBuilder.setSource(op.simpleStringWithNodeId()) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala index 4c7ea3a202..d1b5cab262 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.Tag import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.{array, col} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -253,18 +254,11 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper } test("native reader - read a STRUCT subfield - field from second") { - withSQLConf( - CometConf.COMET_EXEC_ENABLED.key -> "true", - SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { - testSingleLineQuery( - """ + testSingleLineQuery( + """ |select 1 a, named_struct('a', 1, 'b', 'n') c0 |""".stripMargin, - "select c0.b from tbl") - } + "select c0.b from tbl") } test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") { @@ -557,11 +551,10 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper "select array(cast('a' as binary), cast('bc' as binary), cast('def' as binary), null) from tbl") } - test("native reader - array equality") { - testSingleLineQuery( - """ - | select array(1) a union all select array(2) - |""".stripMargin, - "select * from tbl where a = array(1L)") + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array(col("id")).as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } } } From 4373909496d0fdfba04f151f6c50aac35575d2e4 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 17 Jul 
2025 09:32:33 -0700 Subject: [PATCH 12/12] feat: support literal for ARRAY top level --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c8f89310c9..8748ab975b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2180,9 +2180,7 @@ object QueryPlanSerde extends Logging with CometExprShim { op match { // Fully native scan for V1 - case scan: CometScanExec - if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION - || scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT => + case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION => val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder() nativeScanBuilder.setSource(op.simpleStringWithNodeId())