Skip to content

Commit 6b70695

Browse files
committed
feat: support literal for ARRAY top level
1 parent 3308c72 commit 6b70695

File tree

5 files changed

+427
-29
lines changed

5 files changed

+427
-29
lines changed

native/core/src/execution/planner.rs

Lines changed: 235 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ use datafusion::physical_expr::window::WindowExpr;
8585
use datafusion::physical_expr::LexOrdering;
8686

8787
use crate::parquet::parquet_exec::init_datasource_exec;
88-
use arrow::array::Int32Array;
88+
use arrow::array::{
89+
BinaryArray, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Array, Decimal128Builder,
90+
Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
91+
NullArray, StringBuilder, TimestampMicrosecondBuilder,
92+
};
8993
use datafusion::common::utils::SingleRowListArrayBuilder;
9094
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
9195
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
@@ -444,17 +448,237 @@ impl PhysicalPlanner {
444448
}
445449
},
446450
Value::ListVal(values) => {
447-
//dbg!(values);
448-
//dbg!(literal.datatype.as_ref().unwrap());
449-
//dbg!(data_type);
450-
match data_type {
451-
DataType::List(f) if f.data_type().equals_datatype(&DataType::Int32) => {
452-
let vals = values.clone().int_values;
453-
let len = &vals.len();
454-
SingleRowListArrayBuilder::new(Arc::new(Int32Array::from(vals)))
455-
.build_fixed_size_list_scalar(*len)
451+
if let DataType::List(f) = data_type {
452+
match f.data_type() {
453+
DataType::Null => {
454+
SingleRowListArrayBuilder::new(Arc::new(NullArray::new(values.clone().null_mask.len())))
455+
.build_list_scalar()
456+
}
457+
DataType::Boolean => {
458+
let vals = values.clone();
459+
let len = vals.boolean_values.len();
460+
let mut arr = BooleanBuilder::with_capacity(len);
461+
462+
for i in 0 .. len {
463+
if !vals.null_mask[i] {
464+
arr.append_value(vals.boolean_values[i]);
465+
} else {
466+
arr.append_null();
467+
}
468+
}
469+
470+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
471+
.build_list_scalar()
472+
}
473+
DataType::Int8 => {
474+
let vals = values.clone();
475+
let len = vals.byte_values.len();
476+
let mut arr = Int8Builder::with_capacity(len);
477+
478+
for i in 0 .. len {
479+
if !vals.null_mask[i] {
480+
arr.append_value(vals.byte_values[i] as i8);
481+
} else {
482+
arr.append_null();
483+
}
484+
}
485+
486+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
487+
.build_list_scalar()
488+
}
489+
DataType::Int16 => {
490+
let vals = values.clone();
491+
let len = vals.short_values.len();
492+
let mut arr = Int16Builder::with_capacity(len);
493+
494+
for i in 0 .. len {
495+
if !vals.null_mask[i] {
496+
arr.append_value(vals.short_values[i] as i16);
497+
} else {
498+
arr.append_null();
499+
}
500+
}
501+
502+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
503+
.build_list_scalar()
504+
}
505+
DataType::Int32 => {
506+
let vals = values.clone();
507+
let len = vals.int_values.len();
508+
let mut arr = Int32Builder::with_capacity(len);
509+
510+
for i in 0 .. len {
511+
if !vals.null_mask[i] {
512+
arr.append_value(vals.int_values[i]);
513+
} else {
514+
arr.append_null();
515+
}
516+
}
517+
518+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
519+
.build_list_scalar()
520+
}
521+
DataType::Int64 => {
522+
let vals = values.clone();
523+
let len = vals.long_values.len();
524+
let mut arr = Int64Builder::with_capacity(len);
525+
526+
for i in 0 .. len {
527+
if !vals.null_mask[i] {
528+
arr.append_value(vals.long_values[i]);
529+
} else {
530+
arr.append_null();
531+
}
532+
}
533+
534+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
535+
.build_list_scalar()
536+
}
537+
DataType::Float32 => {
538+
let vals = values.clone();
539+
let len = vals.float_values.len();
540+
let mut arr = Float32Builder::with_capacity(len);
541+
542+
for i in 0 .. len {
543+
if !vals.null_mask[i] {
544+
arr.append_value(vals.float_values[i]);
545+
} else {
546+
arr.append_null();
547+
}
548+
}
549+
550+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
551+
.build_list_scalar()
552+
}
553+
DataType::Float64 => {
554+
let vals = values.clone();
555+
let len = vals.double_values.len();
556+
let mut arr = Float64Builder::with_capacity(len);
557+
558+
for i in 0 .. len {
559+
if !vals.null_mask[i] {
560+
arr.append_value(vals.double_values[i]);
561+
} else {
562+
arr.append_null();
563+
}
564+
}
565+
566+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
567+
.build_list_scalar()
568+
}
569+
DataType::Timestamp(TimeUnit::Microsecond, None) => {
570+
let vals = values.clone();
571+
let len = vals.long_values.len();
572+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
573+
574+
for i in 0 .. len {
575+
if !vals.null_mask[i] {
576+
arr.append_value(vals.long_values[i]);
577+
} else {
578+
arr.append_null();
579+
}
580+
}
581+
582+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
583+
.build_list_scalar()
584+
}
585+
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
586+
let vals = values.clone();
587+
let len = vals.long_values.len();
588+
let mut arr = TimestampMicrosecondBuilder::with_capacity(len);
589+
590+
for i in 0 .. len {
591+
if !vals.null_mask[i] {
592+
arr.append_value(vals.long_values[i]);
593+
} else {
594+
arr.append_null();
595+
}
596+
}
597+
598+
SingleRowListArrayBuilder::new(Arc::new(arr.finish().with_timezone(Arc::clone(tz))))
599+
.build_list_scalar()
600+
}
601+
DataType::Date32 => {
602+
let vals = values.clone();
603+
let len = vals.int_values.len();
604+
let mut arr = Date32Builder::with_capacity(len);
605+
606+
for i in 0 .. len {
607+
if !vals.null_mask[i] {
608+
arr.append_value(vals.int_values[i]);
609+
} else {
610+
arr.append_null();
611+
}
612+
}
613+
614+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
615+
.build_list_scalar()
616+
}
617+
DataType::Binary => {
618+
let vals = values.clone();
619+
let mut arr = BinaryBuilder::new();
620+
621+
for (i, v) in vals.bytes_values.into_iter().enumerate() {
622+
if !vals.null_mask[i] {
623+
arr.append_value(v);
624+
} else {
625+
arr.append_null();
626+
}
627+
}
628+
629+
let binary_array: BinaryArray = arr.finish();
630+
SingleRowListArrayBuilder::new(Arc::new(binary_array))
631+
.build_list_scalar()
632+
}
633+
DataType::Utf8 => {
634+
let vals = values.clone();
635+
let len = vals.string_values.len();
636+
let mut arr = StringBuilder::with_capacity(len, len);
637+
638+
for (i, v) in vals.string_values.into_iter().enumerate() {
639+
if !vals.null_mask[i] {
640+
arr.append_value(v);
641+
} else {
642+
arr.append_null();
643+
}
644+
}
645+
646+
SingleRowListArrayBuilder::new(Arc::new(arr.finish()))
647+
.build_list_scalar()
648+
}
649+
DataType::Decimal128(p, s) => {
650+
let vals = values.clone();
651+
let mut arr = Decimal128Builder::new().with_precision_and_scale(*p, *s)?;
652+
653+
for (i, v) in vals.decimal_values.into_iter().enumerate() {
654+
if !vals.null_mask[i] {
655+
let big_integer = BigInt::from_signed_bytes_be(&v);
656+
let integer = big_integer.to_i128().ok_or_else(|| {
657+
GeneralError(format!(
658+
"Cannot parse {big_integer:?} as i128 for Decimal literal"
659+
))
660+
})?;
661+
arr.append_value(integer);
662+
} else {
663+
arr.append_null();
664+
}
665+
}
666+
667+
let decimal_array: Decimal128Array = arr.finish();
668+
SingleRowListArrayBuilder::new(Arc::new(decimal_array))
669+
.build_list_scalar()
670+
}
671+
dt => {
672+
return Err(GeneralError(format!(
673+
"DataType::List literal does not support {dt:?} type"
674+
)))
675+
}
456676
}
457-
_ => todo!()
677+
678+
} else {
679+
return Err(GeneralError(format!(
680+
"Expected DataType::List but got {data_type:?}"
681+
)))
458682
}
459683
}
460684
}

native/proto/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
// Include generated modules from .proto files.
2323
#[allow(missing_docs)]
24+
#[allow(clippy::large_enum_variant)]
2425
pub mod spark_expression {
2526
include!(concat!("generated", "/spark.spark_expression.rs"));
2627
}

native/proto/src/proto/types.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,6 @@ message ListLiteral {
3636
repeated bytes bytes_values = 9;
3737
repeated bytes decimal_values = 10;
3838
repeated ListLiteral list_values = 11;
39+
40+
repeated bool null_mask = 12;
3941
}

0 commit comments

Comments
 (0)