Skip to content

Commit 625f49c

Browse files
committed
feat: support literal for ARRAY top level
1 parent 92b3d0e commit 625f49c

File tree

5 files changed

+427
-29
lines changed

5 files changed

+427
-29
lines changed

native/core/src/execution/planner.rs

Lines changed: 235 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,11 @@ use datafusion::physical_expr::window::WindowExpr;
8888
use datafusion::physical_expr::LexOrdering;
8989

9090
use crate::parquet::parquet_exec::init_datasource_exec;
91-
use arrow::array::Int32Array;
91+
use arrow::array::{
92+
BinaryArray, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Array, Decimal128Builder,
93+
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
94+
NullArray, StringBuilder, TimestampMicrosecondBuilder,
95+
};
9296
use datafusion::common::utils::SingleRowListArrayBuilder;
9397
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
9498
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
@@ -438,17 +442,237 @@ impl PhysicalPlanner {
438442
}
439443
},
440444
Value::ListVal(values) => {
441-
//dbg!(values);
442-
//dbg!(literal.datatype.as_ref().unwrap());
443-
//dbg!(data_type);
444-
match data_type {
445-
DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
446-
let vals = values.clone().int_values;
447-
let len = &vals.len();
448-
SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(vals)))
449-
.build_fixed_size_list_scalar(*len)
445+
if let DataType::List(f) = data_type {
446+
match f.data_type() {
447+
DataType::Null => {
448+
SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
449+
.build_list_scalar()
450+
}
451+
DataType::Boolean => {
452+
let vals = values.clone();
453+
let len = vals.boolean_values.len();
454+
let mut arr = BooleanBuilder::with_capacity(len);
455+
456+
for i in 0 .. len {
457+
if !vals.null_mask[i] {
458+
arr.append_value(vals.boolean_values[i]);
459+
} else {
460+
arr.append_null();
461+
}
462+
}
463+
464+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
465+
.build_list_scalar()
466+
}
467+
DataType::Int8 => {
468+
let vals = values.clone();
469+
let len = vals.byte_values.len();
470+
let mut arr = Int8Builder::with_capacity(len);
471+
472+
for i in 0 .. len {
473+
if !vals.null_mask[i] {
474+
arr.append_value(vals.byte_values[i] as i8);
475+
} else {
476+
arr.append_null();
477+
}
478+
}
479+
480+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
481+
.build_list_scalar()
482+
}
483+
DataType::Int16 => {
484+
let vals = values.clone();
485+
let len = vals.short_values.len();
486+
let mut arr = Int16Builder::with_capacity(len);
487+
488+
for i in 0 .. len {
489+
if !vals.null_mask[i] {
490+
arr.append_value(vals.short_values[i] as i16);
491+
} else {
492+
arr.append_null();
493+
}
494+
}
495+
496+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
497+
.build_list_scalar()
498+
}
499+
DataType::Int32 => {
500+
let vals = values.clone();
501+
let len = vals.int_values.len();
502+
let mut arr = Int32Builder::with_capacity(len);
503+
504+
for i in 0 .. len {
505+
if !vals.null_mask[i] {
506+
arr.append_value(vals.int_values[i]);
507+
} else {
508+
arr.append_null();
509+
}
510+
}
511+
512+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
513+
.build_list_scalar()
514+
}
515+
DataType::Int64 => {
516+
let vals = values.clone();
517+
let len = vals.long_values.len();
518+
let mut arr = Int64Builder::with_capacity(len);
519+
520+
for i in 0 .. len {
521+
if !vals.null_mask[i] {
522+
arr.append_value(vals.long_values[i]);
523+
} else {
524+
arr.append_null();
525+
}
526+
}
527+
528+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
529+
.build_list_scalar()
530+
}
531+
DataType::Float32 => {
532+
let vals = values.clone();
533+
let len = vals.float_values.len();
534+
let mut arr = Float32Builder::with_capacity(len);
535+
536+
for i in 0 .. len {
537+
if !vals.null_mask[i] {
538+
arr.append_value(vals.float_values[i]);
539+
} else {
540+
arr.append_null();
541+
}
542+
}
543+
544+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
545+
.build_list_scalar()
546+
}
547+
DataType::Float64 => {
548+
let vals = values.clone();
549+
let len = vals.double_values.len();
550+
let mut arr = Float64Builder::with_capacity(len);
551+
552+
for i in 0 .. len {
553+
if !vals.null_mask[i] {
554+
arr.append_value(vals.double_values[i]);
555+
} else {
556+
arr.append_null();
557+
}
558+
}
559+
560+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
561+
.build_list_scalar()
562+
}
563+
DataType::Timestamp(TimeUnit::Microsecond, None) => {
564+
let vals = values.clone();
565+
let len = vals.long_values.len();
566+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
567+
568+
for i in 0 .. len {
569+
if !vals.null_mask[i] {
570+
arr.append_value(vals.long_values[i]);
571+
} else {
572+
arr.append_null();
573+
}
574+
}
575+
576+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
577+
.build_list_scalar()
578+
}
579+
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
580+
let vals = values.clone();
581+
let len = vals.long_values.len();
582+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
583+
584+
for i in 0 .. len {
585+
if !vals.null_mask[i] {
586+
arr.append_value(vals.long_values[i]);
587+
} else {
588+
arr.append_null();
589+
}
590+
}
591+
592+
SingleRowListArrayBuilder::new(Arc::new(arr.finish().with_timezone(Arc::clone(tz))))
593+
.build_list_scalar()
594+
}
595+
DataType::Date32 => {
596+
let vals = values.clone();
597+
let len = vals.int_values.len();
598+
let mut arr = Date32Builder::with_capacity(len);
599+
600+
for i in 0 .. len {
601+
if !vals.null_mask[i] {
602+
arr.append_value(vals.int_values[i]);
603+
} else {
604+
arr.append_null();
605+
}
606+
}
607+
608+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
609+
.build_list_scalar()
610+
}
611+
DataType::Binary => {
612+
let vals = values.clone();
613+
let mut arr = BinaryBuilder::new();
614+
615+
for (i, v) in vals.bytes_values.into_iter().enumerate() {
616+
if !vals.null_mask[i] {
617+
arr.append_value(v);
618+
} else {
619+
arr.append_null();
620+
}
621+
}
622+
623+
let binary_array: BinaryArray = arr.finish();
624+
SingleRowListArrayBuilder::new(Arc::new(binary_array))
625+
.build_list_scalar()
626+
}
627+
DataType::Utf8 => {
628+
let vals = values.clone();
629+
let len = vals.string_values.len();
630+
let mut arr = StringBuilder::with_capacity(len, len);
631+
632+
for (i, v) in vals.string_values.into_iter().enumerate() {
633+
if !vals.null_mask[i] {
634+
arr.append_value(v);
635+
} else {
636+
arr.append_null();
637+
}
638+
}
639+
640+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
641+
.build_list_scalar()
642+
}
643+
DataType::Decimal128(p, s) => {
644+
let vals = values.clone();
645+
let mut arr = Decimal128Builder::new().with_precision_and_scale(*p, *s)?;
646+
647+
for (i, v) in vals.decimal_values.into_iter().enumerate() {
648+
if !vals.null_mask[i] {
649+
let big_integer = BigInt::from_signed_bytes_be(&v);
650+
let integer = big_integer.to_i128().ok_or_else(|| {
651+
GeneralError(format!(
652+
"Cannot parse {big_integer:?} as i128 for Decimal literal"
653+
))
654+
})?;
655+
arr.append_value(integer);
656+
} else {
657+
arr.append_null();
658+
}
659+
}
660+
661+
let decimal_array: Decimal128Array = arr.finish();
662+
SingleRowListArrayBuilder::new(Arc::new(decimal_array))
663+
.build_list_scalar()
664+
}
665+
dt => {
666+
return Err(GeneralError(format!(
667+
"DataType::List literal does not support {dt:?} type"
668+
)))
669+
}
450670
}
451-
_ => todo!()
671+
672+
} else {
673+
return Err(GeneralError(format!(
674+
"Expected DataType::List but got {data_type:?}"
675+
)))
452676
}
453677
}
454678
}

native/proto/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121

2222
// Include generated modules from .proto files.
2323
#[allow(missing_docs)]
24+
// This module's contents are pulled in from a generated file via `include!`
// below, so the lint is suppressed at module level.
// NOTE(review): presumably a generated enum has one variant much larger than
// the rest — confirm which type triggers `large_enum_variant`.
#[allow(clippy::large_enum_variant)]
2425
pub mod spark_expression {
2526
include!(concat!("generated", "/spark.spark_expression.rs"));
2627
}

native/proto/src/proto/types.proto

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,4 +36,6 @@ message ListLiteral {
3636
repeated bytes bytes_values = 9;
3737
repeated bytes decimal_values = 10;
3838
repeated ListLiteral list_values = 11;
39+
40+
repeated bool null_mask = 12;
3941
}

0 commit comments

Comments
 (0)