Skip to content

Commit 43863f4

Browse files
authored
fix(functions): improve array_prepend and array_append (#15437)
* fix(functions): improve array_prepend and array_append
* fix
* fix
* fix
* fix
1 parent b232ff1 commit 43863f4

File tree

5 files changed

+137
-206
lines changed

5 files changed

+137
-206
lines changed

src/common/tracing/src/structlog.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -164,8 +164,8 @@ fn build_trees(spans: &[&SpanRecord]) -> Vec<TreeNode> {
164164

165165
let roots = raw.keys().filter(|id| !span_ids.contains(id)).cloned();
166166
roots
167-
.flat_map(|root| build_sub_tree(root, &raw).pop())
168-
.collect_vec()
167+
.filter_map(|root| build_sub_tree(root, &raw).pop())
168+
.collect()
169169
}
170170

171171
fn build_sub_tree(parent_id: SpanId, raw: &HashMap<SpanId, Vec<&SpanRecord>>) -> Vec<TreeNode> {

src/query/functions/src/scalars/array.rs

Lines changed: 56 additions & 183 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,6 @@ use std::hash::Hash;
1616
use std::ops::Range;
1717
use std::sync::Arc;
1818

19-
use databend_common_expression::type_check::common_super_type;
20-
use databend_common_expression::types::array::ArrayColumn;
2119
use databend_common_expression::types::array::ArrayColumnBuilder;
2220
use databend_common_expression::types::boolean::BooleanDomain;
2321
use databend_common_expression::types::nullable::NullableDomain;
@@ -72,7 +70,6 @@ use siphasher::sip128::SipHasher24;
7270

7371
use crate::aggregates::eval_aggr;
7472
use crate::AggregateFunctionFactory;
75-
use crate::BUILTIN_FUNCTIONS;
7673

7774
const ARRAY_AGGREGATE_FUNCTIONS: &[(&str, &str); 14] = &[
7875
("array_avg", "avg"),
@@ -243,10 +240,10 @@ pub fn register(registry: &mut FunctionRegistry) {
243240
),
244241
);
245242

246-
registry.register_2_arg_core::<NullableType<EmptyArrayType>, NullableType<EmptyArrayType>, EmptyArrayType, _, _>(
243+
registry.register_2_arg::<EmptyArrayType, EmptyArrayType, EmptyArrayType, _, _>(
247244
"array_concat",
248245
|_, _, _| FunctionDomain::Full,
249-
|_, _, _| Value::Scalar(()),
246+
|_, _, _| (),
250247
);
251248

252249
registry.register_passthrough_nullable_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, _, _>(
@@ -431,183 +428,57 @@ pub fn register(registry: &mut FunctionRegistry) {
431428
),
432429
);
433430

434-
registry.register_function_factory("array_prepend", |_, args_type| {
435-
if args_type.len() != 2 {
436-
return None;
437-
}
438-
let (common_type, return_type) = match args_type[1].remove_nullable() {
439-
DataType::EmptyArray => (
440-
args_type[0].clone(),
441-
DataType::Array(Box::new(args_type[0].clone())),
442-
),
443-
DataType::Array(box inner_type) => {
444-
let common_type = common_super_type(
445-
inner_type.clone(),
446-
args_type[0].clone(),
447-
&BUILTIN_FUNCTIONS.default_cast_rules,
448-
)?;
449-
(common_type.clone(), DataType::Array(Box::new(common_type)))
450-
}
451-
_ => {
452-
return None;
453-
}
454-
};
455-
let args_type = vec![
456-
common_type,
457-
if args_type[1].is_nullable() {
458-
return_type.wrap_nullable()
459-
} else {
460-
return_type.clone()
461-
},
462-
];
463-
Some(Arc::new(Function {
464-
signature: FunctionSignature {
465-
name: "array_prepend".to_string(),
466-
args_type,
467-
return_type: return_type.clone(),
468-
},
469-
eval: FunctionEval::Scalar {
470-
calc_domain: Box::new(|_, args_domain| {
471-
let array_domain = match &args_domain[1] {
472-
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
473-
other => Some(Box::new(other.clone())),
474-
};
475-
let inner_domain = match array_domain {
476-
Some(box Domain::Array(Some(box inner_domain))) => {
477-
inner_domain.merge(&args_domain[0])
478-
}
479-
_ => args_domain[0].clone(),
480-
};
481-
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
482-
}),
483-
eval: Box::new(move |args, _| {
484-
let len = args.iter().find_map(|arg| match arg {
485-
ValueRef::Column(col) => Some(col.len()),
486-
_ => None,
487-
});
488-
489-
let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
490-
offsets.push(0);
491-
let inner_type = return_type.as_array().unwrap();
492-
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));
493-
494-
for idx in 0..(len.unwrap_or(1)) {
495-
let val = match &args[0] {
496-
ValueRef::Scalar(scalar) => scalar.clone(),
497-
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
498-
};
499-
builder.push(val.clone());
500-
let array_col = match &args[1] {
501-
ValueRef::Scalar(scalar) => scalar.clone(),
502-
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
503-
};
504-
if let ScalarRef::Array(col) = array_col {
505-
for val in col.iter() {
506-
builder.push(val.clone());
507-
}
508-
}
509-
offsets.push(builder.len() as u64);
510-
}
511-
match len {
512-
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
513-
values: builder.build(),
514-
offsets: offsets.into(),
515-
}))),
516-
None => Value::Scalar(Scalar::Array(builder.build())),
431+
registry.register_2_arg_core::<GenericType<0>, NullableType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>, _, _>(
432+
"array_prepend",
433+
|_, item_domain, array_domain| {
434+
let domain = array_domain
435+
.value
436+
.as_ref()
437+
.map(|box inner_domain| {
438+
inner_domain
439+
.as_ref()
440+
.map(|inner_domain| inner_domain.merge(item_domain))
441+
.unwrap_or(item_domain.clone())
442+
});
443+
FunctionDomain::Domain(domain)
444+
},
445+
vectorize_with_builder_2_arg::<GenericType<0>, NullableType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>>(
446+
|val, arr, output, _| {
447+
output.put_item(val);
448+
if let Some(arr) = arr {
449+
for item in arr.iter() {
450+
output.put_item(item);
517451
}
518-
}),
519-
},
520-
}))
521-
});
522-
523-
registry.register_function_factory("array_append", |_, args_type| {
524-
if args_type.len() != 2 {
525-
return None;
526-
}
527-
let (common_type, return_type) = match args_type[0].remove_nullable() {
528-
DataType::EmptyArray => (
529-
args_type[1].clone(),
530-
DataType::Array(Box::new(args_type[1].clone())),
531-
),
532-
DataType::Array(box inner_type) => {
533-
let common_type = common_super_type(
534-
inner_type.clone(),
535-
args_type[1].clone(),
536-
&BUILTIN_FUNCTIONS.default_cast_rules,
537-
)?;
538-
(common_type.clone(), DataType::Array(Box::new(common_type)))
539-
}
540-
_ => {
541-
return None;
542-
}
543-
};
544-
let args_type = vec![
545-
if args_type[0].is_nullable() {
546-
return_type.wrap_nullable()
547-
} else {
548-
return_type.clone()
549-
},
550-
common_type,
551-
];
552-
Some(Arc::new(Function {
553-
signature: FunctionSignature {
554-
name: "array_append".to_string(),
555-
args_type,
556-
return_type: return_type.clone(),
557-
},
558-
eval: FunctionEval::Scalar {
559-
calc_domain: Box::new(|_, args_domain| {
560-
let array_domain = match &args_domain[0] {
561-
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
562-
other => Some(Box::new(other.clone())),
563-
};
564-
let inner_domain = match array_domain {
565-
Some(box Domain::Array(Some(box inner_domain))) => {
566-
inner_domain.merge(&args_domain[1])
567-
}
568-
_ => args_domain[1].clone(),
569-
};
570-
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
571-
}),
572-
eval: Box::new(move |args, _| {
573-
let len = args.iter().find_map(|arg| match arg {
574-
ValueRef::Column(col) => Some(col.len()),
575-
_ => None,
576-
});
577-
578-
let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
579-
offsets.push(0);
580-
let inner_type = return_type.as_array().unwrap();
581-
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));
452+
}
453+
output.commit_row()
454+
})
455+
);
582456

583-
for idx in 0..(len.unwrap_or(1)) {
584-
let array_col = match &args[0] {
585-
ValueRef::Scalar(scalar) => scalar.clone(),
586-
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
587-
};
588-
if let ScalarRef::Array(col) = array_col {
589-
for val in col.iter() {
590-
builder.push(val.clone());
591-
}
592-
}
593-
let val = match &args[1] {
594-
ValueRef::Scalar(scalar) => scalar.clone(),
595-
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
596-
};
597-
builder.push(val.clone());
598-
offsets.push(builder.len() as u64);
599-
}
600-
match len {
601-
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
602-
values: builder.build(),
603-
offsets: offsets.into(),
604-
}))),
605-
None => Value::Scalar(Scalar::Array(builder.build())),
457+
registry.register_2_arg_core::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, ArrayType<GenericType<0>>, _, _>(
458+
"array_append",
459+
|_, array_domain, item_domain| {
460+
let domain = array_domain
461+
.value
462+
.as_ref()
463+
.map(|box inner_domain| {
464+
inner_domain
465+
.as_ref()
466+
.map(|inner_domain| inner_domain.merge(item_domain))
467+
.unwrap_or(item_domain.clone())
468+
});
469+
FunctionDomain::Domain(domain)
470+
},
471+
vectorize_with_builder_2_arg::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, ArrayType<GenericType<0>>>(
472+
|arr, val, output, _| {
473+
if let Some(arr) = arr {
474+
for item in arr.iter() {
475+
output.put_item(item);
606476
}
607-
}),
608-
},
609-
}))
610-
});
477+
}
478+
output.put_item(val);
479+
output.commit_row()
480+
})
481+
);
611482

612483
fn eval_contains<T: ArgType>(
613484
lhs: ValueRef<ArrayType<T>>,
@@ -791,12 +662,14 @@ pub fn register(registry: &mut FunctionRegistry) {
791662
}
792663
);
793664

794-
registry.register_2_arg_core::<ArrayType<GenericType<0>>, GenericType<0>, BooleanType, _, _>(
665+
registry.register_2_arg_core::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, BooleanType, _, _>(
795666
"contains",
796667
|_, _, _| FunctionDomain::Full,
797-
vectorize_2_arg::<ArrayType<GenericType<0>>, GenericType<0>, BooleanType>(|lhs, rhs, _| {
798-
lhs.iter().contains(&rhs)
799-
}),
668+
vectorize_2_arg::<NullableType<ArrayType<GenericType<0>>>, GenericType<0>, BooleanType>(
669+
|lhs, rhs, _| {
670+
lhs.map(|col| col.iter().contains(&rhs)).unwrap_or(false)
671+
}
672+
)
800673
);
801674

802675
registry.register_passthrough_nullable_1_arg::<EmptyArrayType, UInt64Type, _, _>(

src/query/functions/tests/it/scalars/array.rs

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -226,6 +226,16 @@ fn test_array_prepend(file: &mut impl Write) {
226226
run_ast(file, "array_prepend(1, [])", &[]);
227227
run_ast(file, "array_prepend(1, [2, 3, NULL, 4])", &[]);
228228
run_ast(file, "array_prepend('a', ['b', NULL, NULL, 'c', 'd'])", &[]);
229+
run_ast(
230+
file,
231+
"array_prepend(NULL, CAST([2, 3] AS Array(INT8 NULL) NULL))",
232+
&[],
233+
);
234+
run_ast(
235+
file,
236+
"array_prepend(1, CAST([2, 3] AS Array(INT8 NULL) NULL))",
237+
&[],
238+
);
229239
run_ast(file, "array_prepend(a, [b, c])", &[
230240
("a", Int16Type::from_data(vec![0i16, 1, 2])),
231241
("b", Int16Type::from_data(vec![3i16, 4, 5])),
@@ -237,6 +247,16 @@ fn test_array_append(file: &mut impl Write) {
237247
run_ast(file, "array_append([], 1)", &[]);
238248
run_ast(file, "array_append([2, 3, NULL, 4], 5)", &[]);
239249
run_ast(file, "array_append(['b', NULL, NULL, 'c', 'd'], 'e')", &[]);
250+
run_ast(
251+
file,
252+
"array_append(CAST([1, 2] AS Array(INT8 NULL) NULL), NULL)",
253+
&[],
254+
);
255+
run_ast(
256+
file,
257+
"array_append(CAST([1, 2] AS Array(INT8 NULL) NULL), 3)",
258+
&[],
259+
);
240260
run_ast(file, "array_append([b, c], a)", &[
241261
("a", Int16Type::from_data(vec![0i16, 1, 2])),
242262
("b", Int16Type::from_data(vec![3i16, 4, 5])),

0 commit comments

Comments
 (0)