diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 327e57bcfd..8a692b5452 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -85,6 +85,12 @@ use datafusion::physical_expr::window::WindowExpr; use datafusion::physical_expr::LexOrdering; use crate::parquet::parquet_exec::init_datasource_exec; +use arrow::array::{ + BinaryArray, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Array, Decimal128Builder, + Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, + NullArray, StringBuilder, TimestampMicrosecondBuilder, +}; +use datafusion::common::utils::SingleRowListArrayBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec; use datafusion_comet_proto::spark_operator::SparkFilePartition; @@ -440,6 +446,240 @@ impl PhysicalPlanner { ))) } } + }, + Value::ListVal(values) => { + if let DataType::List(f) = data_type { + match f.data_type() { + DataType::Null => { + SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len()))) + .build_list_scalar() + } + DataType::Boolean => { + let vals = values.clone(); + let len = vals.boolean_values.len(); + let mut arr = BooleanBuilder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.boolean_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int8 => { + let vals = values.clone(); + let len = vals.byte_values.len(); + let mut arr = Int8Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.byte_values[i] as i8); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int16 => { + let vals = values.clone(); + let len = vals.short_values.len(); + let mut arr = Int16Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.short_values[i] as i16); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int32 => { + let vals = values.clone(); + let len = vals.int_values.len(); + let mut arr = Int32Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.int_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Int64 => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = Int64Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Float32 => { + let vals = values.clone(); + let len = vals.float_values.len(); + let mut arr = Float32Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.float_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Float64 => { + let vals = values.clone(); + let len = vals.double_values.len(); + let mut arr = Float64Builder::with_capacity(len); + + for i in 0 .. 
len { + if !vals.null_mask[i] { + arr.append_value(vals.double_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Timestamp(TimeUnit::Microsecond, None) => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = TimestampMicrosecondBuilder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let vals = values.clone(); + let len = vals.long_values.len(); + let mut arr = TimestampMicrosecondBuilder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.long_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish().with_timezone(Arc::clone(tz)))) + .build_list_scalar() + } + DataType::Date32 => { + let vals = values.clone(); + let len = vals.int_values.len(); + let mut arr = Date32Builder::with_capacity(len); + + for i in 0 .. len { + if !vals.null_mask[i] { + arr.append_value(vals.int_values[i]); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Binary => { + let vals = values.clone(); + let mut arr = BinaryBuilder::new(); + + for (i, v) in vals.bytes_values.into_iter().enumerate() { + if !vals.null_mask[i] { + arr.append_value(v); + } else { + arr.append_null(); + } + } + + let binary_array: BinaryArray = arr.finish(); + SingleRowListArrayBuilder::new(Arc::new(binary_array)) + .build_list_scalar() + } + DataType::Utf8 => { + let vals = values.clone(); + let len = vals.string_values.len(); + let mut arr = StringBuilder::with_capacity(len, len); + + for (i, v) in vals.string_values.into_iter().enumerate() { + if !vals.null_mask[i] { + arr.append_value(v); + } else { + arr.append_null(); + } + } + + SingleRowListArrayBuilder::new(Arc::new(arr.finish())) + .build_list_scalar() + } + DataType::Decimal128(p, s) => { + let vals = values.clone(); + let mut arr = Decimal128Builder::new().with_precision_and_scale(*p, *s)?; + + for (i, v) in vals.decimal_values.into_iter().enumerate() { + if !vals.null_mask[i] { + let big_integer = BigInt::from_signed_bytes_be(&v); + let integer = big_integer.to_i128().ok_or_else(|| { + GeneralError(format!( + "Cannot parse {big_integer:?} as i128 for Decimal literal" + )) + })?; + arr.append_value(integer); + } else { + arr.append_null(); + } + } + + let decimal_array: Decimal128Array = arr.finish(); + SingleRowListArrayBuilder::new(Arc::new(decimal_array)) + .build_list_scalar() + } + dt => { + return Err(GeneralError(format!( + "DataType::List literal does not support {dt:?} type" + ))) + } + } + + } else { + return Err(GeneralError(format!( + "Expected DataType::List but got {data_type:?}" + ))) + } } } }; @@ -2273,7 +2513,6 @@ impl PhysicalPlanner { other => other, }; let func = self.session_ctx.udf(fun_name)?; - let coerced_types = func .coerce_types(&input_expr_types) .unwrap_or_else(|_| input_expr_types.clone()); diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index ed24440360..2c213c2514 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -21,6 +21,7 @@ // Include generated modules from .proto files. 
#[allow(missing_docs)] +#[allow(clippy::large_enum_variant)] pub mod spark_expression { include!(concat!("generated", "/spark.spark_expression.rs")); } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 8f4c875eec..7e04c7aa55 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -21,6 +21,8 @@ syntax = "proto3"; package spark.spark_expression; +import "types.proto"; + option java_package = "org.apache.comet.serde"; // The basic message representing a Spark expression. @@ -110,13 +112,13 @@ enum StatisticsType { } message Count { - repeated Expr children = 1; + repeated Expr children = 1; } message Sum { - Expr child = 1; - DataType datatype = 2; - bool fail_on_error = 3; + Expr child = 1; + DataType datatype = 2; + bool fail_on_error = 3; } message Min { @@ -213,10 +215,11 @@ message Literal { string string_val = 8; bytes bytes_val = 9; bytes decimal_val = 10; - } + ListLiteral list_val = 11; + } - DataType datatype = 11; - bool is_null = 12; + DataType datatype = 12; + bool is_null = 13; } message MathExpr { @@ -469,5 +472,4 @@ message DataType { } DataTypeInfo type_info = 2; -} - +} \ No newline at end of file diff --git a/native/proto/src/proto/types.proto b/native/proto/src/proto/types.proto new file mode 100644 index 0000000000..cc163522b4 --- /dev/null +++ b/native/proto/src/proto/types.proto @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ + + +syntax = "proto3"; + +package spark.spark_expression; + +option java_package = "org.apache.comet.serde"; + +message ListLiteral { + // Only one of these fields should be populated based on the array type + repeated bool boolean_values = 1; + repeated int32 byte_values = 2; + repeated int32 short_values = 3; + repeated int32 int_values = 4; + repeated int64 long_values = 5; + repeated float float_values = 6; + repeated double double_values = 7; + repeated string string_values = 8; + repeated bytes bytes_values = 9; + repeated bytes decimal_values = 10; + repeated ListLiteral list_values = 11; + + repeated bool null_mask = 12; +} \ No newline at end of file diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index e0bc5f39fc..09659b6ac3 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -20,6 +20,7 @@ use crate::utils::array_with_timezone; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; use arrow::array::{DictionaryArray, StringArray, StructArray}; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Schema}; use arrow::{ array::{ @@ -967,6 +968,9 @@ fn cast_array( to_type, cast_options, )?), + (List(_), List(_)) if can_cast_types(from_type, to_type) => { + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) + } (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) if cast_options.allow_cast_unsigned_ints => { @@ -1017,7 +1021,7 @@ fn is_datafusion_spark_compatible( DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { // note that the cast from Int32/Int64 -> Decimal128 here is actually // not compatible with Spark (no overflow checks) but we have tests that - // rely on this cast working so we have to leave it here for now + // rely on this cast working, so we have to leave it here for now matches!( to_type, DataType::Boolean diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index e0e89b35fc..337eae11db 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -19,7 +19,7 @@ package org.apache.comet.expressions -import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType} sealed trait SupportLevel @@ -62,6 +62,9 @@ object CometCast { } (fromType, toType) match { + case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible() + case (dt: ArrayType, dt1: ArrayType) => + isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode) case (dt: DataType, _) if dt.typeName == "timestamp_ntz" => // https://github.com/apache/datafusion-comet/issues/378 toType match { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 970329b28f..8748ab975b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero} import 
org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -48,6 +48,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import com.google.protobuf.ByteString + import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ @@ -56,6 +58,7 @@ import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType import org.apache.comet.serde.ExprOuterClass.DataType._ import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} +import org.apache.comet.serde.Types.ListLiteral import org.apache.comet.shims.CometExprShim /** @@ -838,7 +841,12 @@ object QueryPlanSerde extends Logging with CometExprShim { (builder, binaryExpr) => builder.setLtEq(binaryExpr)) case Literal(value, dataType) - if supportedDataType(dataType, allowComplex = value == null) => + if supportedDataType( + dataType, + allowComplex = value == null || Seq( + CometConf.SCAN_NATIVE_ICEBERG_COMPAT, + CometConf.SCAN_NATIVE_DATAFUSION).contains( + CometConf.COMET_NATIVE_SCAN_IMPL.get())) => val exprBuilder = ExprOuterClass.Literal.newBuilder() if (value == null) { @@ -849,14 +857,13 @@ object QueryPlanSerde extends Logging with CometExprShim { case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: IntegerType | _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType | _: TimestampType | _: TimestampNTZType => + exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: TimestampNTZType => exprBuilder.setLongVal(value.asInstanceOf[Long]) case _: DecimalType => // Pass decimal literal as bytes. 
val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue @@ -866,7 +873,87 @@ object QueryPlanSerde extends Logging with CometExprShim { val byteStr = com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) exprBuilder.setBytesVal(byteStr) - case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case a: ArrayType => + val listLiteralBuilder = ListLiteral.newBuilder() + val array = value.asInstanceOf[GenericArrayData].array + a.elementType match { + case NullType => + array.foreach(_ => listLiteralBuilder.addNullMask(true)) + case BooleanType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Boolean] + listLiteralBuilder.addBooleanValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case ByteType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Integer] + listLiteralBuilder.addByteValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case ShortType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Short] + listLiteralBuilder.addShortValues( + if (casted != null) casted.intValue() + else null.asInstanceOf[java.lang.Integer]) + listLiteralBuilder.addNullMask(casted == null) + }) + case IntegerType | DateType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Integer] + listLiteralBuilder.addIntValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case LongType | TimestampType | TimestampNTZType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Long] + listLiteralBuilder.addLongValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case FloatType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Float] + listLiteralBuilder.addFloatValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case DoubleType => + array.foreach(v => { + val casted = v.asInstanceOf[java.lang.Double] + listLiteralBuilder.addDoubleValues(casted) + listLiteralBuilder.addNullMask(casted == null) + }) + case StringType => + array.foreach(v => { + val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String] + listLiteralBuilder.addStringValues( + if (casted != null) casted.toString else "") + listLiteralBuilder.addNullMask(casted == null) + }) + case _: DecimalType => + array + .foreach(v => { + val casted = + v.asInstanceOf[Decimal] + listLiteralBuilder.addDecimalValues(if (casted != null) { + com.google.protobuf.ByteString + .copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray) + } else ByteString.EMPTY) + listLiteralBuilder.addNullMask(casted == null) + }) + case _: BinaryType => + array + .foreach(v => { + val casted = + v.asInstanceOf[Array[Byte]] + listLiteralBuilder.addBytesValues(if (casted != null) { + com.google.protobuf.ByteString.copyFrom(casted) + } else ByteString.EMPTY) + listLiteralBuilder.addNullMask(casted == null) + }) + } + exprBuilder.setListVal(listLiteralBuilder.build()) + exprBuilder.setDatatype(serializeDataType(dataType).get) case dt => logWarning(s"Unexpected datatype '$dt' for literal value '$value'") } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala index f33da3ba71..d1b5cab262 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeReaderSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.Tag import org.apache.spark.sql.CometTestBase import 
org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.{array, col} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -253,18 +254,11 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper } test("native reader - read a STRUCT subfield - field from second") { - withSQLConf( - CometConf.COMET_EXEC_ENABLED.key -> "true", - SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "false", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion") { - testSingleLineQuery( - """ + testSingleLineQuery( + """ |select 1 a, named_struct('a', 1, 'b', 'n') c0 |""".stripMargin, - "select c0.b from tbl") - } + "select c0.b from tbl") } test("native reader - read a STRUCT subfield from ARRAY of STRUCTS - field from first") { @@ -436,4 +430,131 @@ class CometNativeReaderSuite extends CometTestBase with AdaptiveSparkPlanHelper |""".stripMargin, "select c0['key1'].b from tbl") } + + test("native reader - support ARRAY literal INT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(1, 2, 3, null) from tbl") + } + + test("native reader - support ARRAY literal BOOL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(true, false, null) from tbl") + } + + test("native reader - support ARRAY literal NULL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(null) from tbl") + } + + test("native reader - support empty ARRAY literal") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array() from tbl") + } + + test("native reader - support ARRAY literal BYTE fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(1, 2, 3, null) from tbl") + } + + test("native reader - support ARRAY literal SHORT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as short), cast(2 as short), cast(3 as short), null) from tbl") + } + + test("native reader - support ARRAY literal DATE fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01' AS DATE), CAST('2024-02-01' AS DATE), CAST('2024-03-01' AS DATE), null) from tbl") + } + + test("native reader - support ARRAY literal LONG fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as bigint), cast(2 as bigint), cast(3 as bigint), null) from tbl") + } + + test("native reader - support ARRAY literal TIMESTAMP fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP), CAST('2024-01-02 10:00:00' AS TIMESTAMP), CAST('2024-01-03 10:00:00' AS TIMESTAMP), null) from tbl") + } + + test("native reader - support ARRAY literal TIMESTAMP TZ fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(CAST('2024-01-01 10:00:00' AS TIMESTAMP_NTZ), CAST('2024-01-02 10:00:00' AS TIMESTAMP_NTZ), CAST('2024-01-03 10:00:00' AS TIMESTAMP_NTZ), null) from tbl") + } + + test("native reader - support ARRAY literal FLOAT fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as float), cast(2 as float), cast(3 as float), null) from tbl") + } + + test("native reader - support ARRAY literal DOUBLE fields") { + testSingleLineQuery( + """ + 
|select 1 a + |""".stripMargin, + "select array(cast(1 as double), cast(2 as double), cast(3 as double), null) from tbl") + } + + test("native reader - support ARRAY literal STRING fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array('a', 'bc', 'def', null) from tbl") + } + + test("native reader - support ARRAY literal DECIMAL fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast(1 as decimal(10, 2)), cast(2.5 as decimal(10, 2)),cast(3.75 as decimal(10, 2)), null) from tbl") + } + + test("native reader - support ARRAY literal BINARY fields") { + testSingleLineQuery( + """ + |select 1 a + |""".stripMargin, + "select array(cast('a' as binary), cast('bc' as binary), cast('def' as binary), null) from tbl") + } + + test("SPARK-18053: ARRAY equality is broken") { + withTable("array_tbl") { + spark.range(10).select(array(col("id")).as("arr")).write.saveAsTable("array_tbl") + assert(sql("SELECT * FROM array_tbl where arr = ARRAY(1L)").count == 1) + } + } }
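
A note on the deserialization pattern introduced above (not part of the patch): every new Value::ListVal arm in planner.rs has the same shape -- replay the proto's per-type values against null_mask into the matching Arrow builder, then wrap the finished element array in SingleRowListArrayBuilder to obtain a one-row list ScalarValue. Below is a minimal, standalone Rust sketch of that shape for Int32; the helper name int_list_literal is hypothetical and only for illustration.

use std::sync::Arc;

use arrow::array::Int32Builder;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;

// Rebuilds a ListLiteral-style (values, null_mask) pair into a single-row
// list scalar, mirroring the Int32 arm added to planner.rs above.
fn int_list_literal(int_values: &[i32], null_mask: &[bool]) -> ScalarValue {
    // null_mask[i] == true marks element i as null; the paired value is ignored,
    // exactly as in the match arms above.
    let mut builder = Int32Builder::with_capacity(int_values.len());
    for (v, is_null) in int_values.iter().zip(null_mask) {
        if *is_null {
            builder.append_null();
        } else {
            builder.append_value(*v);
        }
    }
    // Wrap the element array into a one-row ListArray and expose it as a scalar,
    // which is what the planner hands to DataFusion for an ARRAY literal.
    SingleRowListArrayBuilder::new(Arc::new(builder.finish())).build_list_scalar()
}

fn main() {
    // Roughly the native-side equivalent of the Spark literal array(1, 2, 3, null)
    // exercised by the new CometNativeReaderSuite tests; 0 is a placeholder value
    // for the null slot, since only null_mask decides nullability.
    let scalar = int_list_literal(&[1, 2, 3, 0], &[false, false, false, true]);
    println!("{scalar:?}");
}

The Scala serde side (QueryPlanSerde.scala above) produces the matching (values, null_mask) pair per element type, so the two halves of the patch meet at the ListLiteral message defined in the new types.proto.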