Commit dba2f06

feat: support literal for ARRAY top level

1 parent d9b4792

6 files changed: 148 additions & 18 deletions

native/core/src/execution/planner.rs

Lines changed: 14 additions & 1 deletion
@@ -88,6 +88,8 @@ use datafusion::physical_expr::window::WindowExpr;
 use datafusion::physical_expr::LexOrdering;
 
 use crate::parquet::parquet_exec::init_datasource_exec;
+use arrow::array::Int32Array;
+use datafusion::common::utils::SingleRowListArrayBuilder;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
 use datafusion_comet_proto::spark_operator::SparkFilePartition;
@@ -436,6 +438,18 @@ impl PhysicalPlanner {
                     )))
                 }
             }
+            },
+            Value::ListVal(values) => {
+                dbg!(values);
+                //dbg!(literal.datatype.as_ref().unwrap());
+                //dbg!(data_type);
+                match data_type {
+                    DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
+                        SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(values.clone().int_values)))
+                            .build_list_scalar()
+                    }
+                    _ => todo!()
+                }
             }
         }
     };
@@ -2283,7 +2297,6 @@ impl PhysicalPlanner {
             other => other,
         };
         let func = self.session_ctx.udf(fun_name)?;
-
         let coerced_types = func
             .coerce_types(&input_expr_types)
             .unwrap_or_else(|_| input_expr_types.clone());
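
The new `Value::ListVal` arm only covers `Int32` list elements so far; everything else falls through to `todo!()`. As a minimal standalone sketch of what this arm produces — using only the APIs the diff itself imports, plus the `ScalarValue` re-export from `datafusion::common` — `build_list_scalar` wraps a flat primitive array into a one-row `ListArray` and returns it as a single list scalar:

// Minimal sketch of the new Value::ListVal arm; `ints` stands in for the
// decoded `values.int_values` from the ListLiteral proto message.
use std::sync::Arc;

use arrow::array::Int32Array;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;

fn int_list_literal(ints: Vec<i32>) -> ScalarValue {
    // Wrap the flat Int32 values into a single-row ListArray and convert it
    // into a list ScalarValue, i.e. the literal `[1, 2, 3]` as one scalar.
    SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(ints))).build_list_scalar()
}

fn main() {
    println!("{:?}", int_list_literal(vec![1, 2, 3]));
}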

native/proto/src/proto/expr.proto

Lines changed: 11 additions & 9 deletions
@@ -21,6 +21,8 @@ syntax = "proto3";
 
 package spark.spark_expression;
 
+import "types.proto";
+
 option java_package = "org.apache.comet.serde";
 
 // The basic message representing a Spark expression.
@@ -112,13 +114,13 @@ enum StatisticsType {
 }
 
 message Count {
-  repeated Expr children = 1;
+  repeated Expr children = 1;
 }
 
 message Sum {
-  Expr child = 1;
-  DataType datatype = 2;
-  bool fail_on_error = 3;
+  Expr child = 1;
+  DataType datatype = 2;
+  bool fail_on_error = 3;
 }
 
 message Min {
@@ -215,10 +217,11 @@ message Literal {
     string string_val = 8;
     bytes bytes_val = 9;
     bytes decimal_val = 10;
-  }
+    ListLiteral list_val = 11;
+  }
 
-  DataType datatype = 11;
-  bool is_null = 12;
+  DataType datatype = 12;
+  bool is_null = 13;
 }
 
 message MathExpr {
@@ -471,5 +474,4 @@ message DataType {
   }
 
   DataTypeInfo type_info = 2;
-}
-
+}

native/proto/src/proto/types.proto

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+
+
+syntax = "proto3";
+
+package spark.spark_expression;
+
+option java_package = "org.apache.comet.serde";
+
+message ListLiteral {
+  // Only one of these fields should be populated based on the array type
+  repeated bool boolean_values = 1;
+  repeated int32 byte_values = 2;
+  repeated int32 short_values = 3;
+  repeated int32 int_values = 4;
+  repeated int64 long_values = 5;
+  repeated float float_values = 6;
+  repeated double double_values = 7;
+  repeated string string_values = 8;
+  repeated bytes bytes_values = 9;
+  repeated bytes decimal_values = 10;
+  repeated ListLiteral list_values = 11;
+}
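
On the native side this message is consumed through prost-generated bindings. A hedged sketch of the producer side in Rust — the module path follows the `datafusion_comet_proto` crate seen in planner.rs, and the struct shape assumes the usual prost codegen where each `repeated` field becomes a `Vec` and messages derive `Default`:

// Hypothetical construction of a ListLiteral for an ARRAY<INT> literal; field
// names come from types.proto, the rest is a prost-codegen assumption.
use datafusion_comet_proto::spark_expression::ListLiteral;

fn int_array_literal(ints: Vec<i32>) -> ListLiteral {
    ListLiteral {
        // Per the comment in types.proto, only the field matching the array's
        // element type is populated; all other repeated fields stay empty.
        int_values: ints,
        ..Default::default()
    }
}

On the planner side this arrives as the literal's oneof value, `Value::ListVal(ListLiteral { int_values, .. })`, which is exactly what the new match arm in planner.rs unpacks.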

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,7 @@
 
 package org.apache.comet.expressions
 
-import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, NullType, StructType}
 
 sealed trait SupportLevel
 
@@ -62,6 +62,7 @@ object CometCast {
   }
 
   (fromType, toType) match {
+    case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
     case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
       // https://github.com/apache/datafusion-comet/issues/378
       toType match {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 71 additions & 3 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
+import org.apache.spark.sql.catalyst.util.{CharVarcharCodegenUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
@@ -55,6 +55,7 @@ import org.apache.comet.objectstore.NativeConfig
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType._
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
+import org.apache.comet.serde.Types.ListLiteral
 import org.apache.comet.shims.CometExprShim
 
 /**
@@ -115,6 +116,52 @@ object QueryPlanSerde extends Logging with CometExprShim {
     false
   }
 
+  // def convertArrayToProtoLiteral(array: Seq[Any], arrayType: ArrayType): Literal = {
+  //   val elementType = arrayType.elementType
+  //   val listLiteralBuilder = ListLiteral.newBuilder()
+  //
+  //   elementType match {
+  //     case BooleanType =>
+  //       listLiteralBuilder.addAllBooleanValues(array.map(_.asInstanceOf[Boolean]).asJava)
+  //
+  //     case ByteType =>
+  //       listLiteralBuilder.addAllByteValues(array.map(_.asInstanceOf[Byte].toInt).asJava)
+  //
+  //     case ShortType =>
+  //       listLiteralBuilder.addAllShortValues(array.map(_.asInstanceOf[Short].toInt).asJava)
+  //
+  //     case IntegerType =>
+  //       listLiteralBuilder.addAllIntValues(array.map(_.asInstanceOf[Int]).asJava)
+  //
+  //     case LongType =>
+  //       listLiteralBuilder.addAllLongValues(array.map(_.asInstanceOf[Long]).asJava)
+  //
+  //     case FloatType =>
+  //       listLiteralBuilder.addAllFloatValues(array.map(_.asInstanceOf[Float]).asJava)
+  //
+  //     case DoubleType =>
+  //       listLiteralBuilder.addAllDoubleValues(array.map(_.asInstanceOf[Double]).asJava)
+  //
+  //     case StringType =>
+  //       listLiteralBuilder.addAllStringValues(array.map(_.asInstanceOf[String]).asJava)
+  //
+  //     case BinaryType =>
+  //       listLiteralBuilder.addAllBytesValues
+  //       (array.map(x => com.google.protobuf
+  //         .ByteString.copyFrom(x.asInstanceOf[Array[Byte]])).asJava)
+  //
+  //     case nested: ArrayType =>
+  //       val nestedListLiterals = array.map {
+  //         case null => ListLiteral.newBuilder().build() // or handle nulls appropriately
+  //         case seq: Seq[_] => convertArrayToProtoLiteral(seq, nested).getListVal
+  //       }
+  //       listLiteralBuilder.addAllListValues(nestedListLiterals.asJava)
+  //
+  //     case _ =>
+  //       throw new UnsupportedOperationException(s"Unsupported element type: $elementType")
+  //   }
+  // }
+
 /**
  * Serializes Spark datatype to protobuf. Note that, a datatype can be serialized by this method
  * doesn't mean it is supported by Comet native execution, i.e., `supportedDataType` may return
@@ -812,8 +859,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
         binding,
         (builder, binaryExpr) => builder.setLtEq(binaryExpr))
 
-      case Literal(value, dataType)
-          if supportedDataType(dataType, allowComplex = value == null) =>
+      case Literal(value, dataType) if supportedDataType(dataType, allowComplex = true) =>
         val exprBuilder = ExprOuterClass.Literal.newBuilder()
 
         if (value == null) {
@@ -842,6 +888,28 @@ object QueryPlanSerde extends Logging with CometExprShim {
             com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
           exprBuilder.setBytesVal(byteStr)
         case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+        case a: ArrayType =>
+          val listLiteralBuilder = ListLiteral.newBuilder()
+          val array = value.asInstanceOf[GenericArrayData].array
+          a.elementType match {
+            case BooleanType =>
+              listLiteralBuilder.addAllBooleanValues(
+                array.map(_.asInstanceOf[java.lang.Boolean]).toIterable.asJava)
+            case ByteType =>
+              listLiteralBuilder.addAllByteValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case ShortType =>
+              listLiteralBuilder.addAllShortValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case IntegerType =>
+              listLiteralBuilder.addAllIntValues(
+                array.map(_.asInstanceOf[java.lang.Integer]).toIterable.asJava)
+            case LongType =>
+              listLiteralBuilder.addAllLongValues(
+                array.map(_.asInstanceOf[java.lang.Long]).toIterable.asJava)
+          }
+          exprBuilder.setListVal(listLiteralBuilder.build())
+          exprBuilder.setDatatype(serializeDataType(dataType).get)
         case dt =>
           logWarning(s"Unexpected datatype '$dt' for literal value '$value'")
       }
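
The Scala match above serializes Boolean, Byte, Short, Integer, and Long element types into the corresponding repeated fields, while the native `Value::ListVal` arm currently deserializes only Int32 and hits `todo!()` for everything else. One possible shape for closing that gap on the Rust side — a sketch under stated assumptions, not part of this commit; `list_literal_to_scalar` is a hypothetical helper:

// Hypothetical extension of the planner's todo!() arm: one case per list
// element type, mirroring the repeated fields the Scala side populates.
use std::sync::Arc;

use arrow::array::{BooleanArray, Int32Array, Int64Array};
use arrow::datatypes::DataType;
use datafusion::common::utils::SingleRowListArrayBuilder;
use datafusion::common::ScalarValue;
use datafusion_comet_proto::spark_expression::ListLiteral;

fn list_literal_to_scalar(values: &ListLiteral, elem_type: &DataType) -> ScalarValue {
    match elem_type {
        DataType::Boolean => SingleRowListArrayBuilder::new(Arc::new(BooleanArray::from(
            values.boolean_values.clone(),
        )))
        .build_list_scalar(),
        // byte_values/short_values travel as int32 on the wire (proto has no
        // narrower integer type), so Int8/Int16 lists would need a cast here.
        DataType::Int32 => SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(
            values.int_values.clone(),
        )))
        .build_list_scalar(),
        DataType::Int64 => SingleRowListArrayBuilder::new(Arc::new(Int64Array::from(
            values.long_values.clone(),
        )))
        .build_list_scalar(),
        other => unimplemented!("list literal element type {other:?}"),
    }
}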

spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala

Lines changed: 11 additions & 4 deletions
@@ -219,15 +219,22 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("array_contains") {
-    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
+    withSQLConf(
+      CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true",
+      CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true"
+      // "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding"
+    ) {
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
         makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 10000)
         spark.read.parquet(path.toString).createOrReplaceTempView("t1");
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+        // checkSparkAnswerAndOperator(
+        //   spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
         checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
-        checkSparkAnswerAndOperator(
-          spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
+          spark.sql(
+            "SELECT array_contains((CASE WHEN _2 =_3 THEN array(1, 2, 3) END), _4) FROM t1"));
       }
     }
   }
