
Commit 8bc9e16

feat: support literal for ARRAY top level

Parent: 75a53f5

6 files changed: +148 -18 lines changed

6 files changed

+148
-18
lines changed

native/core/src/execution/planner.rs

Lines changed: 14 additions & 1 deletion
@@ -88,6 +88,8 @@ use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;
 
 use crate::parquet::parquet_exec::init_datasource_exec;
+use arrow::array::Int32Array;
+use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -434,6 +436,18 @@ impl PhysicalPlanner {
                         )))
                     }
                 }
+            },
+            Value::ListVal(values) => {
+                dbg!(values);
+                //dbg!(literal.datatype.as_ref().unwrap());
+                //dbg!(data_type);
+                match data_type {
+                    DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
+                        SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values.clone().int_values)))
+                            .build_list_scalar()
+                    }
+                    _ => todo!()
+                }
             }
         }
     };
@@ -2281,7 +2295,6 @@ impl PhysicalPlanner {
             other => other,
         };
         let func = self.session_ctx.udf(fun_name)?;
-
         let coerced_types = func
             .coerce_types(&input_expr_types)
             .unwrap_or_else(|_| input_expr_types.clone());
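Note: a minimal hand-written sketch (not part of the commit; the helper name int_list_scalar is ours) of what the new Value::ListVal arm produces for an ARRAY<INT> literal. SingleRowListArrayBuilder, from datafusion::common::utils as imported above, wraps a values array into a one-row ListArray, and build_list_scalar() returns it as a list ScalarValue.

use std::sync::Arc;

use arrow::array::Int32Array;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;

// Mirror of the DataType::List(Int32) branch above: build a ScalarValue
// holding the single list [1, 2, 3].
fn int_list_scalar(values: Vec<i32>) -> ScalarValue {
    SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values)))
        .build_list_scalar()
}

fn main() {
    // Produces a ScalarValue::List wrapping a one-row ListArray.
    println!("{:?}", int_list_scalar(vec![1, 2, 3]));
}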

native/proto/src/proto/expr.proto

Lines changed: 11 additions & 9 deletions
@@ -21,6 +21,8 @@ syntax = "proto3";
 
 package spark.spark_expression;
 
+import "types.proto";
+
 option java_package = "org.apache.comet.serde";
 
 // The basic message representing a Spark expression.
@@ -113,13 +115,13 @@ enum StatisticsType {
 }
 
 message Count {
-    repeated Expr children = 1;
+  repeated Expr children = 1;
 }
 
 message Sum {
-    Expr child = 1;
-    DataType datatype = 2;
-    bool fail_on_error = 3;
+  Expr child = 1;
+  DataType datatype = 2;
+  bool fail_on_error = 3;
 }
 
 message Min {
@@ -216,10 +218,11 @@ message Literal {
     string string_val = 8;
     bytes bytes_val = 9;
     bytes decimal_val = 10;
-  }
+    ListLiteral list_val = 11;
+  }
 
-  DataType datatype = 11;
-  bool is_null = 12;
+  DataType datatype = 12;
+  bool is_null = 13;
 }
 
 message MathExpr {
@@ -472,5 +475,4 @@ message DataType {
   }
 
   DataTypeInfo type_info = 2;
-}
-
+}

native/proto/src/proto/types.proto

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+
+syntax = "proto3";
+
+package spark.spark_expression;
+
+option java_package = "org.apache.comet.serde";
+
+message ListLiteral {
+  // Only one of these fields should be populated based on the array type
+  repeated bool boolean_values = 1;
+  repeated int32 byte_values = 2;
+  repeated int32 short_values = 3;
+  repeated int32 int_values = 4;
+  repeated int64 long_values = 5;
+  repeated float float_values = 6;
+  repeated double double_values = 7;
+  repeated string string_values = 8;
+  repeated bytes bytes_values = 9;
+  repeated bytes decimal_values = 10;
+  repeated ListLiteral list_values = 11;
+}
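Usage sketch for the generated builder (hand-written illustration, not part of the commit; assumes Scala 2.13's scala.jdk.CollectionConverters and the Java classes protoc emits for this file): serializing the array literal [1, 2, 3] populates only the repeated field that matches the element type.

import scala.jdk.CollectionConverters._

import org.apache.comet.serde.Types.ListLiteral

// Only int_values is populated for an ARRAY<INT> literal; all other
// repeated fields stay empty, per the comment in the message definition.
val listLiteral: ListLiteral = ListLiteral
  .newBuilder()
  .addAllIntValues(Seq(1, 2, 3).map(Int.box).asJava)
  .build()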

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 package org.apache.comet.expressions
 
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
 
 sealed trait SupportLevel
 
@@ -62,6 +62,7 @@ object CometCast {
     }
 
     (fromType, toType) match {
+      case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
      case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
        // https://github.com/apache/datafusion-comet/issues/378
        toType match {
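Note (an illustration, not from the commit): the new arm covers source arrays whose element type is NullType, which is what Spark assigns to an empty array() literal by default, so casting such an array to any concrete array type is reported Compatible.

import org.apache.spark.sql.types._

// The shape the new case matches: casting ARRAY<NULL> to ARRAY<INT>.
val fromType = ArrayType(NullType)
val toType = ArrayType(IntegerType)
assert(fromType.elementType == NullType) // guard of the new Compatible() arm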

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 71 additions & 3 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
@@ -56,6 +56,7 @@ import org.apache.comet.objectstore.NativeConfig
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType._
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
+import org.apache.comet.serde.Types.ListLiteral
 import org.apache.comet.shims.CometExprShim
 
 /**
@@ -118,6 +119,52 @@ object QueryPlanSerde extends Logging with CometExprShim {
     false
   }
 
+//  def convertArrayToProtoLiteral(array: Seq[Any], arrayType: ArrayType): Literal = {
+//    val elementType = arrayType.elementType
+//    val listLiteralBuilder = ListLiteral.newBuilder()
+//
+//    elementType match {
+//      case BooleanType =>
+//        listLiteralBuilder.addAllBooleanValues(array.map(_.asInstanceOf[Boolean]).asJava)
+//
+//      case ByteType =>
+//        listLiteralBuilder.addAllByteValues(array.map(_.asInstanceOf[Byte].toInt).asJava)
+//
+//      case ShortType =>
+//        listLiteralBuilder.addAllShortValues(array.map(_.asInstanceOf[Short].toInt).asJava)
+//
+//      case IntegerType =>
+//        listLiteralBuilder.addAllIntValues(array.map(_.asInstanceOf[Int]).asJava)
+//
+//      case LongType =>
+//        listLiteralBuilder.addAllLongValues(array.map(_.asInstanceOf[Long]).asJava)
+//
+//      case FloatType =>
+//        listLiteralBuilder.addAllFloatValues(array.map(_.asInstanceOf[Float]).asJava)
+//
+//      case DoubleType =>
+//        listLiteralBuilder.addAllDoubleValues(array.map(_.asInstanceOf[Double]).asJava)
+//
+//      case StringType =>
+//        listLiteralBuilder.addAllStringValues(array.map(_.asInstanceOf[String]).asJava)
+//
+//      case BinaryType =>
+//        listLiteralBuilder.addAllBytesValues
+//        (array.map(x => com.google.protobuf
+//          .ByteString.copyFrom(x.asInstanceOf[Array[Byte]])).asJava)
+//
+//      case nested: ArrayType =>
+//        val nestedListLiterals = array.map {
+//          case null => ListLiteral.newBuilder().build() // or handle nulls appropriately
+//          case seq: Seq[_] => convertArrayToProtoLiteral(seq, nested).getListVal
+//        }
+//        listLiteralBuilder.addAllListValues(nestedListLiterals.asJava)
+//
+//      case _ =>
+//        throw new UnsupportedOperationException(s"Unsupported element type: $elementType")
+//    }
+//  }
+
   /**
    * Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
    * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
@@ -815,8 +862,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
         binding,
         (builder, binaryExpr) => builder.setLtEq(binaryExpr))
 
-    case Literal(value, dataType)
-        if supportedDataType(dataType, allowComplex = value == null) =>
+    case Literal(value, dataType) if supportedDataType(dataType, allowComplex = true) =>
       val exprBuilder = ExprOuterClass.Literal.newBuilder()
 
       if (value == null) {
@@ -845,6 +891,28 @@ object QueryPlanSerde extends Logging with CometExprShim {
             com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
           exprBuilder.setBytesVal(byteStr)
         case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+        case a: ArrayType =>
+          val listLiteralBuilder = ListLiteral.newBuilder()
+          val array = value.asInstanceOf[GenericArrayData].array
+          a.elementType match {
+            case BooleanType =>
+              listLiteralBuilder.addAllBooleanValues(
+                array.map(_.asInstanceOf[java.lang.Boolean]).toIterable.asJava)
+            case ByteType =>
+              listLiteralBuilder.addAllByteValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case ShortType =>
+              listLiteralBuilder.addAllShortValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case IntegerType =>
+              listLiteralBuilder.addAllIntValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case LongType =>
+              listLiteralBuilder.addAllLongValues(
+                array.map(_.asInstanceOf[java.lang.Long]).toIterable.asJava)
+          }
+          exprBuilder.setListVal(listLiteralBuilder.build())
+          exprBuilder.setDatatype(serializeDataType(dataType).get)
         case dt =>
           logWarning(s"Unexpected datatype '$dt' for literal value '$value'")
       }
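To see the new ArrayType branch in isolation, a hand-written sketch (not part of the commit; assumes Scala 2.13's scala.jdk.CollectionConverters, and the val names are ours): a folded array(1, 2, 3) reaches serialization as a Literal wrapping GenericArrayData, and the IntegerType case copies its elements into int_values.

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.comet.serde.Types.ListLiteral

// What the IntegerType case does with the literal's GenericArrayData value.
val value = new GenericArrayData(Array(1, 2, 3))
val listLiteral = ListLiteral
  .newBuilder()
  .addAllIntValues(value.array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
  .build()
assert(listLiteral.getIntValuesList.asScala.map(_.toInt) == Seq(1, 2, 3))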

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 11 additions & 4 deletions
@@ -219,15 +219,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("array_contains") {
-    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+    withSQLConf(
+      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true"
+      // "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding"
+    ) {
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
         makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
         spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
         checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          spark.sql(
+            "SELECT array_contains((CASE WHEN _2 =_3 THEN array(1, 2, 3) END), _4) FROM t1"));
       }
     }
  }
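Why the assertion now uses array(1, 2, 3), an inference from the commented-out excludedRules line above, stated as such: with all-constant elements, Spark's ConstantFolding rule collapses the array(...) call into a single top-level ArrayType literal before the plan reaches Comet, which is precisely the literal shape this commit adds support for. A hypothetical spark-shell probe:

// Hand-written check, not part of the commit: inspect the optimized plan to
// confirm array(1, 2, 3) was folded into one ArrayType literal.
val df = spark.sql("SELECT array(1, 2, 3) AS arr")
df.queryExecution.optimizedPlan.expressions.foreach(println)
// expected to print something like: [1,2,3] AS arr#0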
