Skip to content

Commit 6445a92

Browse files
committed
feat: support literal for ARRAY top level
1 parent 894b319 commit 6445a92

File tree

5 files changed

+427
-29
lines changed

5 files changed

+427
-29
lines changed

native/core/src/execution/planner.rs

Lines changed: 235 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,11 @@ use datafusion::physical_expr::window::WindowExpr;
8686
use datafusion::physical_expr::LexOrdering;
8787

8888
use crate::parquet::parquet_exec::init_datasource_exec;
89-
use arrow::array::Int32Array;
89+
use arrow::array::{
90+
BinaryArray, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Array, Decimal128Builder,
91+
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
92+
NullArray, StringBuilder, TimestampMicrosecondBuilder,
93+
};
9094
use datafusion::common::utils::SingleRowListArrayBuilder;
9195
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
9296
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
@@ -436,17 +440,237 @@ impl PhysicalPlanner {
436440
}
437441
},
438442
Value::ListVal(values) => {
439-
//dbg!(values);
440-
//dbg!(literal.datatype.as_ref().unwrap());
441-
//dbg!(data_type);
442-
match data_type {
443-
DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
444-
let vals = values.clone().int_values;
445-
let len = &vals.len();
446-
SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(vals)))
447-
.build_fixed_size_list_scalar(*len)
443+
if let DataType::List(f) = data_type {
444+
match f.data_type() {
445+
DataType::Null => {
446+
SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
447+
.build_list_scalar()
448+
}
449+
DataType::Boolean => {
450+
let vals = values.clone();
451+
let len = vals.boolean_values.len();
452+
let mut arr = BooleanBuilder::with_capacity(len);
453+
454+
for i in 0 .. len {
455+
if !vals.null_mask[i] {
456+
arr.append_value(vals.boolean_values[i]);
457+
} else {
458+
arr.append_null();
459+
}
460+
}
461+
462+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
463+
.build_list_scalar()
464+
}
465+
DataType::Int8 => {
466+
let vals = values.clone();
467+
let len = vals.byte_values.len();
468+
let mut arr = Int8Builder::with_capacity(len);
469+
470+
for i in 0 .. len {
471+
if !vals.null_mask[i] {
472+
arr.append_value(vals.byte_values[i] as i8);
473+
} else {
474+
arr.append_null();
475+
}
476+
}
477+
478+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
479+
.build_list_scalar()
480+
}
481+
DataType::Int16 => {
482+
let vals = values.clone();
483+
let len = vals.short_values.len();
484+
let mut arr = Int16Builder::with_capacity(len);
485+
486+
for i in 0 .. len {
487+
if !vals.null_mask[i] {
488+
arr.append_value(vals.short_values[i] as i16);
489+
} else {
490+
arr.append_null();
491+
}
492+
}
493+
494+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
495+
.build_list_scalar()
496+
}
497+
DataType::Int32 => {
498+
let vals = values.clone();
499+
let len = vals.int_values.len();
500+
let mut arr = Int32Builder::with_capacity(len);
501+
502+
for i in 0 .. len {
503+
if !vals.null_mask[i] {
504+
arr.append_value(vals.int_values[i]);
505+
} else {
506+
arr.append_null();
507+
}
508+
}
509+
510+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
511+
.build_list_scalar()
512+
}
513+
DataType::Int64 => {
514+
let vals = values.clone();
515+
let len = vals.long_values.len();
516+
let mut arr = Int64Builder::with_capacity(len);
517+
518+
for i in 0 .. len {
519+
if !vals.null_mask[i] {
520+
arr.append_value(vals.long_values[i]);
521+
} else {
522+
arr.append_null();
523+
}
524+
}
525+
526+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
527+
.build_list_scalar()
528+
}
529+
DataType::Float32 => {
530+
let vals = values.clone();
531+
let len = vals.float_values.len();
532+
let mut arr = Float32Builder::with_capacity(len);
533+
534+
for i in 0 .. len {
535+
if !vals.null_mask[i] {
536+
arr.append_value(vals.float_values[i]);
537+
} else {
538+
arr.append_null();
539+
}
540+
}
541+
542+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
543+
.build_list_scalar()
544+
}
545+
DataType::Float64 => {
546+
let vals = values.clone();
547+
let len = vals.double_values.len();
548+
let mut arr = Float64Builder::with_capacity(len);
549+
550+
for i in 0 .. len {
551+
if !vals.null_mask[i] {
552+
arr.append_value(vals.double_values[i]);
553+
} else {
554+
arr.append_null();
555+
}
556+
}
557+
558+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
559+
.build_list_scalar()
560+
}
561+
DataType::Timestamp(TimeUnit::Microsecond, None) => {
562+
let vals = values.clone();
563+
let len = vals.long_values.len();
564+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
565+
566+
for i in 0 .. len {
567+
if !vals.null_mask[i] {
568+
arr.append_value(vals.long_values[i]);
569+
} else {
570+
arr.append_null();
571+
}
572+
}
573+
574+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
575+
.build_list_scalar()
576+
}
577+
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
578+
let vals = values.clone();
579+
let len = vals.long_values.len();
580+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
581+
582+
for i in 0 .. len {
583+
if !vals.null_mask[i] {
584+
arr.append_value(vals.long_values[i]);
585+
} else {
586+
arr.append_null();
587+
}
588+
}
589+
590+
SingleRowListArrayBuilder::new(Arc::new(arr.finish().with_timezone(Arc::clone(tz))))
591+
.build_list_scalar()
592+
}
593+
DataType::Date32 => {
594+
let vals = values.clone();
595+
let len = vals.int_values.len();
596+
let mut arr = Date32Builder::with_capacity(len);
597+
598+
for i in 0 .. len {
599+
if !vals.null_mask[i] {
600+
arr.append_value(vals.int_values[i]);
601+
} else {
602+
arr.append_null();
603+
}
604+
}
605+
606+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
607+
.build_list_scalar()
608+
}
609+
DataType::Binary => {
610+
let vals = values.clone();
611+
let mut arr = BinaryBuilder::new();
612+
613+
for (i, v) in vals.bytes_values.into_iter().enumerate() {
614+
if !vals.null_mask[i] {
615+
arr.append_value(v);
616+
} else {
617+
arr.append_null();
618+
}
619+
}
620+
621+
let binary_array: BinaryArray = arr.finish();
622+
SingleRowListArrayBuilder::new(Arc::new(binary_array))
623+
.build_list_scalar()
624+
}
625+
DataType::Utf8 => {
626+
let vals = values.clone();
627+
let len = vals.string_values.len();
628+
let mut arr = StringBuilder::with_capacity(len, len);
629+
630+
for (i, v) in vals.string_values.into_iter().enumerate() {
631+
if !vals.null_mask[i] {
632+
arr.append_value(v);
633+
} else {
634+
arr.append_null();
635+
}
636+
}
637+
638+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
639+
.build_list_scalar()
640+
}
641+
DataType::Decimal128(p, s) => {
642+
let vals = values.clone();
643+
let mut arr = Decimal128Builder::new().with_precision_and_scale(*p, *s)?;
644+
645+
for (i, v) in vals.decimal_values.into_iter().enumerate() {
646+
if !vals.null_mask[i] {
647+
let big_integer = BigInt::from_signed_bytes_be(&v);
648+
let integer = big_integer.to_i128().ok_or_else(|| {
649+
GeneralError(format!(
650+
"Cannot parse {big_integer:?} as i128 for Decimal literal"
651+
))
652+
})?;
653+
arr.append_value(integer);
654+
} else {
655+
arr.append_null();
656+
}
657+
}
658+
659+
let decimal_array: Decimal128Array = arr.finish();
660+
SingleRowListArrayBuilder::new(Arc::new(decimal_array))
661+
.build_list_scalar()
662+
}
663+
dt => {
664+
return Err(GeneralError(format!(
665+
"DataType::List literal does not support {dt:?} type"
666+
)))
667+
}
448668
}
449-
_ => todo!()
669+
670+
} else {
671+
return Err(GeneralError(format!(
672+
"Expected DataType::List but got {data_type:?}"
673+
)))
450674
}
451675
}
452676
}

native/proto/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
// Include generated modules from .proto files.
2323
#[allow(missing_docs)]
24+
#[allow(clippy::large_enum_variant)]
2425
pub mod spark_expression {
2526
include!(concat!("generated", "/spark.spark_expression.rs"));
2627
}

native/proto/src/proto/types.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,6 @@ message ListLiteral {
3636
repeated bytes bytes_values = 9;
3737
repeated bytes decimal_values = 10;
3838
repeated ListLiteral list_values = 11;
39+
40+
repeated bool null_mask = 12;
3941
}

0 commit comments

Comments
 (0)