Commit b37cc38

feat: support literal for ARRAY top level

1 parent b2ed0ed commit b37cc38

6 files changed: +149 −18 lines

native/core/src/execution/planner.rs
Lines changed: 14 additions & 1 deletion

@@ -85,6 +85,8 @@ use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;

 use crate::parquet::parquet_exec::init_datasource_exec;
+use arrow::array::Int32Array;
+use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -440,6 +442,18 @@ impl PhysicalPlanner {
                     )))
                 }
             }
+            },
+            Value::ListVal(values) => {
+                dbg!(values);
+                //dbg!(literal.datatype.as_ref().unwrap());
+                //dbg!(data_type);
+                match data_type {
+                    DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
+                        SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values.clone().int_values)))
+                            .build_list_scalar()
+                    }
+                    _ => todo!()
+                }
             }
         }
     };
@@ -2273,7 +2287,6 @@ impl PhysicalPlanner {
             other => other,
         };
         let func = self.session_ctx.udf(fun_name)?;
-
         let coerced_types = func
            .coerce_types(&input_expr_types)
            .unwrap_or_else(|_| input_expr_types.clone());
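
Note: the new Value::ListVal branch relies on DataFusion's SingleRowListArrayBuilder, which wraps an Arrow array into a one-row ListArray and exposes it as a list scalar. A minimal standalone sketch of that conversion, assuming the arrow and datafusion crates as imported in the diff (int_list_scalar is a hypothetical helper name, not part of the commit):

use std::sync::Arc;

use arrow::array::Int32Array;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;

// Hypothetical helper mirroring the Value::ListVal branch above: wrap the
// flat int_values from the protobuf literal in an Int32Array, then build a
// single-row ListArray around it and return it as a ScalarValue::List.
fn int_list_scalar(int_values: Vec<i32>) -> ScalarValue {
    SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(int_values)))
        .build_list_scalar()
}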

native/proto/src/proto/expr.proto
Lines changed: 11 additions & 9 deletions

@@ -21,6 +21,8 @@ syntax = "proto3";

 package spark.spark_expression;

+import "types.proto";
+
 option java_package = "org.apache.comet.serde";

 // The basic message representing a Spark expression.
@@ -110,13 +112,13 @@ enum StatisticsType {
 }

 message Count {
-    repeated Expr children = 1;
+  repeated Expr children = 1;
 }

 message Sum {
-    Expr child = 1;
-    DataType datatype = 2;
-    bool fail_on_error = 3;
+  Expr child = 1;
+  DataType datatype = 2;
+  bool fail_on_error = 3;
 }

 message Min {
@@ -213,10 +215,11 @@ message Literal {
     string string_val = 8;
     bytes bytes_val = 9;
     bytes decimal_val = 10;
-  }
+    ListLiteral list_val = 11;
+  }

-  DataType datatype = 11;
-  bool is_null = 12;
+  DataType datatype = 12;
+  bool is_null = 13;
 }

 message MathExpr {
@@ -469,5 +472,4 @@ message DataType {
   }

   DataTypeInfo type_info = 2;
-}
-
+}

native/proto/src/proto/types.proto
Lines changed: 39 additions & 0 deletions (new file)

@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+syntax = "proto3";
+
+package spark.spark_expression;
+
+option java_package = "org.apache.comet.serde";
+
+message ListLiteral {
+  // Only one of these fields should be populated based on the array type
+  repeated bool boolean_values = 1;
+  repeated int32 byte_values = 2;
+  repeated int32 short_values = 3;
+  repeated int32 int_values = 4;
+  repeated int64 long_values = 5;
+  repeated float float_values = 6;
+  repeated double double_values = 7;
+  repeated string string_values = 8;
+  repeated bytes bytes_values = 9;
+  repeated bytes decimal_values = 10;
+  repeated ListLiteral list_values = 11;
+}
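
For illustration: with the prost-generated bindings used on the native side, an ARRAY<INT> literal such as array(1, 2, 3) arrives as a ListLiteral with only int_values populated. The struct below is a hand-written stand-in for the generated datafusion_comet_proto type, not the generated code itself (derives and the full field set are abbreviated):

// Stand-in for the prost-generated ListLiteral struct; field names follow
// types.proto above, and the real type lives in datafusion_comet_proto.
#[derive(Debug, Default)]
struct ListLiteral {
    boolean_values: Vec<bool>,
    byte_values: Vec<i32>,
    short_values: Vec<i32>,
    int_values: Vec<i32>,
    long_values: Vec<i64>,
}

fn main() {
    // Only the field matching the array's element type is populated;
    // every other repeated field stays empty.
    let list = ListLiteral {
        int_values: vec![1, 2, 3],
        ..Default::default()
    };
    assert!(list.boolean_values.is_empty());
    assert_eq!(list.int_values, vec![1, 2, 3]);
}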

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@

 package org.apache.comet.expressions

-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}

 sealed trait SupportLevel

@@ -62,6 +62,7 @@ object CometCast {
     }

     (fromType, toType) match {
+      case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
       case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
         // https://github.com/apache/datafusion-comet/issues/378
         toType match {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Lines changed: 72 additions & 3 deletions

@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
@@ -56,6 +57,7 @@ import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType
 import org.apache.comet.serde.ExprOuterClass.DataType._
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto}
+import org.apache.comet.serde.Types.ListLiteral
 import org.apache.comet.shims.CometExprShim

 /**
@@ -142,6 +144,52 @@ object QueryPlanSerde extends Logging with CometExprShim {
     false
   }

+  // def convertArrayToProtoLiteral(array: Seq[Any], arrayType: ArrayType): Literal = {
+  //   val elementType = arrayType.elementType
+  //   val listLiteralBuilder = ListLiteral.newBuilder()
+  //
+  //   elementType match {
+  //     case BooleanType =>
+  //       listLiteralBuilder.addAllBooleanValues(array.map(_.asInstanceOf[Boolean]).asJava)
+  //
+  //     case ByteType =>
+  //       listLiteralBuilder.addAllByteValues(array.map(_.asInstanceOf[Byte].toInt).asJava)
+  //
+  //     case ShortType =>
+  //       listLiteralBuilder.addAllShortValues(array.map(_.asInstanceOf[Short].toInt).asJava)
+  //
+  //     case IntegerType =>
+  //       listLiteralBuilder.addAllIntValues(array.map(_.asInstanceOf[Int]).asJava)
+  //
+  //     case LongType =>
+  //       listLiteralBuilder.addAllLongValues(array.map(_.asInstanceOf[Long]).asJava)
+  //
+  //     case FloatType =>
+  //       listLiteralBuilder.addAllFloatValues(array.map(_.asInstanceOf[Float]).asJava)
+  //
+  //     case DoubleType =>
+  //       listLiteralBuilder.addAllDoubleValues(array.map(_.asInstanceOf[Double]).asJava)
+  //
+  //     case StringType =>
+  //       listLiteralBuilder.addAllStringValues(array.map(_.asInstanceOf[String]).asJava)
+  //
+  //     case BinaryType =>
+  //       listLiteralBuilder.addAllBytesValues
+  //       (array.map(x => com.google.protobuf
+  //         .ByteString.copyFrom(x.asInstanceOf[Array[Byte]])).asJava)
+  //
+  //     case nested: ArrayType =>
+  //       val nestedListLiterals = array.map {
+  //         case null => ListLiteral.newBuilder().build() // or handle nulls appropriately
+  //         case seq: Seq[_] => convertArrayToProtoLiteral(seq, nested).getListVal
+  //       }
+  //       listLiteralBuilder.addAllListValues(nestedListLiterals.asJava)
+  //
+  //     case _ =>
+  //       throw new UnsupportedOperationException(s"Unsupported element type: $elementType")
+  //   }
+  // }
+
   /**
    * Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
    * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
@@ -837,8 +885,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
           binding,
           (builder, binaryExpr) => builder.setLtEq(binaryExpr))

-      case Literal(value, dataType)
-          if supportedDataType(dataType, allowComplex = value == null) =>
+      case Literal(value, dataType) if supportedDataType(dataType, allowComplex = true) =>
         val exprBuilder = ExprOuterClass.Literal.newBuilder()

         if (value == null) {
@@ -867,6 +914,28 @@ object QueryPlanSerde extends Logging with CometExprShim {
               com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
             exprBuilder.setBytesVal(byteStr)
           case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+          case a: ArrayType =>
+            val listLiteralBuilder = ListLiteral.newBuilder()
+            val array = value.asInstanceOf[GenericArrayData].array
+            a.elementType match {
+              case BooleanType =>
+                listLiteralBuilder.addAllBooleanValues(
+                  array.map(_.asInstanceOf[java.lang.Boolean]).toIterable.asJava)
+              case ByteType =>
+                listLiteralBuilder.addAllByteValues(
+                  array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+              case ShortType =>
+                listLiteralBuilder.addAllShortValues(
+                  array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+              case IntegerType =>
+                listLiteralBuilder.addAllIntValues(
+                  array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+              case LongType =>
+                listLiteralBuilder.addAllLongValues(
+                  array.map(_.asInstanceOf[java.lang.Long]).toIterable.asJava)
+            }
+            exprBuilder.setListVal(listLiteralBuilder.build())
+            exprBuilder.setDatatype(serializeDataType(dataType).get)
           case dt =>
             logWarning(s"Unexpected datatype '$dt' for literal value '$value'")
         }

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
Lines changed: 11 additions & 4 deletions

@@ -219,15 +219,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
   }

   test("array_contains") {
-    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+    withSQLConf(
+      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true"
+      // "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding"
+    ) {
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
         makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
         spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
         checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          spark.sql(
+            "SELECT array_contains((CASE WHEN _2 =_3 THEN array(1, 2, 3) END), _4) FROM t1"));
       }
     }
   }
