Skip to content

Commit 29514e7

Browse files
authored
fix(query): fix array_concat function domain panic (#15424)
1 parent d799038 commit 29514e7

File tree

4 files changed

+326
-105
lines changed

4 files changed

+326
-105
lines changed

src/query/functions/src/scalars/array.rs

Lines changed: 225 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use std::hash::Hash;
1616
use std::ops::Range;
1717
use std::sync::Arc;
1818

19+
use databend_common_expression::type_check::common_super_type;
20+
use databend_common_expression::types::array::ArrayColumn;
1921
use databend_common_expression::types::array::ArrayColumnBuilder;
2022
use databend_common_expression::types::boolean::BooleanDomain;
2123
use databend_common_expression::types::nullable::NullableDomain;
@@ -70,6 +72,7 @@ use siphasher::sip128::SipHasher24;
7072

7173
use crate::aggregates::eval_aggr;
7274
use crate::AggregateFunctionFactory;
75+
use crate::BUILTIN_FUNCTIONS;
7376

7477
const ARRAY_AGGREGATE_FUNCTIONS: &[(&str, &str); 14] = &[
7578
("array_avg", "avg"),
@@ -248,9 +251,24 @@ pub fn register(registry: &mut FunctionRegistry) {
248251

249252
registry.register_passthrough_nullable_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, _, _>(
250253
"array_concat",
251-
|_, _, _| FunctionDomain::Full,
254+
|_, domain1, domain2| {
255+
FunctionDomain::Domain(
256+
match (domain1, domain2) {
257+
(Some(domain1), Some(domain2)) => Some(domain1.merge(domain2)),
258+
(Some(domain1), None) => Some(domain1).cloned(),
259+
(None, Some(domain2)) => Some(domain2).cloned(),
260+
(None, None) => None,
261+
}
262+
)
263+
},
252264
vectorize_with_builder_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>>(
253-
|lhs, rhs, output, _| {
265+
|lhs, rhs, output, ctx| {
266+
if let Some(validity) = &ctx.validity {
267+
if !validity.get_bit(output.len()) {
268+
output.commit_row();
269+
return;
270+
}
271+
}
254272
output.builder.append_column(&lhs);
255273
output.builder.append_column(&rhs);
256274
output.commit_row()
@@ -261,15 +279,15 @@ pub fn register(registry: &mut FunctionRegistry) {
261279
registry
262280
.register_passthrough_nullable_1_arg::<ArrayType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>, _, _>(
263281
"array_flatten",
264-
|_, _| FunctionDomain::Full,
282+
|_, domain| FunctionDomain::Domain(domain.clone().flatten()),
265283
vectorize_1_arg::<ArrayType<ArrayType<GenericType<0>>>, ArrayType<GenericType<0>>>(
266-
|a, b| {
267-
let mut builder = ColumnBuilder::with_capacity(&b.generics[0], a.len());
268-
for a in a.iter() {
269-
builder.append_column(&a);
284+
|arr, ctx| {
285+
let mut builder = ColumnBuilder::with_capacity(&ctx.generics[0], arr.len());
286+
for v in arr.iter() {
287+
builder.append_column(&v);
288+
}
289+
builder.build()
270290
}
271-
builder.build()
272-
}
273291
),
274292
);
275293

@@ -278,7 +296,13 @@ pub fn register(registry: &mut FunctionRegistry) {
278296
"array_to_string",
279297
|_, _, _| FunctionDomain::Full,
280298
vectorize_with_builder_2_arg::<ArrayType<StringType>, StringType, StringType>(
281-
|lhs, rhs, output, _| {
299+
|lhs, rhs, output, ctx| {
300+
if let Some(validity) = &ctx.validity {
301+
if !validity.get_bit(output.len()) {
302+
output.commit_row();
303+
return;
304+
}
305+
}
282306
for (i, d) in lhs.iter().enumerate() {
283307
if i != 0 {
284308
output.put_str(rhs);
@@ -407,29 +431,183 @@ pub fn register(registry: &mut FunctionRegistry) {
407431
),
408432
);
409433

410-
registry.register_2_arg_core::<GenericType<0>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, _, _>(
411-
"array_prepend",
412-
|_, _, _| FunctionDomain::Full,
413-
vectorize_2_arg::<GenericType<0>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>>>(|val, arr, _| {
414-
let data_type = arr.data_type();
415-
let mut builder = ColumnBuilder::with_capacity(&data_type, arr.len() + 1);
416-
builder.push(val);
417-
builder.append_column(&arr);
418-
builder.build()
419-
}),
420-
);
434+
registry.register_function_factory("array_prepend", |_, args_type| {
435+
if args_type.len() != 2 {
436+
return None;
437+
}
438+
let (common_type, return_type) = match args_type[1].remove_nullable() {
439+
DataType::EmptyArray => (
440+
args_type[0].clone(),
441+
DataType::Array(Box::new(args_type[0].clone())),
442+
),
443+
DataType::Array(box inner_type) => {
444+
let common_type = common_super_type(
445+
inner_type.clone(),
446+
args_type[0].clone(),
447+
&BUILTIN_FUNCTIONS.default_cast_rules,
448+
)?;
449+
(common_type.clone(), DataType::Array(Box::new(common_type)))
450+
}
451+
_ => {
452+
return None;
453+
}
454+
};
455+
let args_type = vec![
456+
common_type,
457+
if args_type[1].is_nullable() {
458+
return_type.wrap_nullable()
459+
} else {
460+
return_type.clone()
461+
},
462+
];
463+
Some(Arc::new(Function {
464+
signature: FunctionSignature {
465+
name: "array_prepend".to_string(),
466+
args_type,
467+
return_type: return_type.clone(),
468+
},
469+
eval: FunctionEval::Scalar {
470+
calc_domain: Box::new(|_, args_domain| {
471+
let array_domain = match &args_domain[1] {
472+
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
473+
other => Some(Box::new(other.clone())),
474+
};
475+
let inner_domain = match array_domain {
476+
Some(box Domain::Array(Some(box inner_domain))) => {
477+
inner_domain.merge(&args_domain[0])
478+
}
479+
_ => args_domain[0].clone(),
480+
};
481+
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
482+
}),
483+
eval: Box::new(move |args, _| {
484+
let len = args.iter().find_map(|arg| match arg {
485+
ValueRef::Column(col) => Some(col.len()),
486+
_ => None,
487+
});
421488

422-
registry.register_2_arg_core::<ArrayType<GenericType<0>>, GenericType<0>, ArrayType<GenericType<0>>, _, _>(
423-
"array_append",
424-
|_, _, _| FunctionDomain::Full,
425-
vectorize_2_arg::<ArrayType<GenericType<0>>, GenericType<0>, ArrayType<GenericType<0>>>(|arr, val, _| {
426-
let data_type = arr.data_type();
427-
let mut builder = ColumnBuilder::with_capacity(&data_type, arr.len() + 1);
428-
builder.append_column(&arr);
429-
builder.push(val);
430-
builder.build()
431-
}),
432-
);
489+
let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
490+
offsets.push(0);
491+
let inner_type = return_type.as_array().unwrap();
492+
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));
493+
494+
for idx in 0..(len.unwrap_or(1)) {
495+
let val = match &args[0] {
496+
ValueRef::Scalar(scalar) => scalar.clone(),
497+
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
498+
};
499+
builder.push(val.clone());
500+
let array_col = match &args[1] {
501+
ValueRef::Scalar(scalar) => scalar.clone(),
502+
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
503+
};
504+
if let ScalarRef::Array(col) = array_col {
505+
for val in col.iter() {
506+
builder.push(val.clone());
507+
}
508+
}
509+
offsets.push(builder.len() as u64);
510+
}
511+
match len {
512+
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
513+
values: builder.build(),
514+
offsets: offsets.into(),
515+
}))),
516+
None => Value::Scalar(Scalar::Array(builder.build())),
517+
}
518+
}),
519+
},
520+
}))
521+
});
522+
523+
registry.register_function_factory("array_append", |_, args_type| {
524+
if args_type.len() != 2 {
525+
return None;
526+
}
527+
let (common_type, return_type) = match args_type[0].remove_nullable() {
528+
DataType::EmptyArray => (
529+
args_type[1].clone(),
530+
DataType::Array(Box::new(args_type[1].clone())),
531+
),
532+
DataType::Array(box inner_type) => {
533+
let common_type = common_super_type(
534+
inner_type.clone(),
535+
args_type[1].clone(),
536+
&BUILTIN_FUNCTIONS.default_cast_rules,
537+
)?;
538+
(common_type.clone(), DataType::Array(Box::new(common_type)))
539+
}
540+
_ => {
541+
return None;
542+
}
543+
};
544+
let args_type = vec![
545+
if args_type[0].is_nullable() {
546+
return_type.wrap_nullable()
547+
} else {
548+
return_type.clone()
549+
},
550+
common_type,
551+
];
552+
Some(Arc::new(Function {
553+
signature: FunctionSignature {
554+
name: "array_append".to_string(),
555+
args_type,
556+
return_type: return_type.clone(),
557+
},
558+
eval: FunctionEval::Scalar {
559+
calc_domain: Box::new(|_, args_domain| {
560+
let array_domain = match &args_domain[0] {
561+
Domain::Nullable(nullable_domain) => nullable_domain.value.clone(),
562+
other => Some(Box::new(other.clone())),
563+
};
564+
let inner_domain = match array_domain {
565+
Some(box Domain::Array(Some(box inner_domain))) => {
566+
inner_domain.merge(&args_domain[1])
567+
}
568+
_ => args_domain[1].clone(),
569+
};
570+
FunctionDomain::Domain(Domain::Array(Some(Box::new(inner_domain))))
571+
}),
572+
eval: Box::new(move |args, _| {
573+
let len = args.iter().find_map(|arg| match arg {
574+
ValueRef::Column(col) => Some(col.len()),
575+
_ => None,
576+
});
577+
578+
let mut offsets = Vec::with_capacity(len.unwrap_or(1) + 1);
579+
offsets.push(0);
580+
let inner_type = return_type.as_array().unwrap();
581+
let mut builder = ColumnBuilder::with_capacity(inner_type, len.unwrap_or(1));
582+
583+
for idx in 0..(len.unwrap_or(1)) {
584+
let array_col = match &args[0] {
585+
ValueRef::Scalar(scalar) => scalar.clone(),
586+
ValueRef::Column(col) => unsafe { col.index_unchecked(idx).clone() },
587+
};
588+
if let ScalarRef::Array(col) = array_col {
589+
for val in col.iter() {
590+
builder.push(val.clone());
591+
}
592+
}
593+
let val = match &args[1] {
594+
ValueRef::Scalar(scalar) => scalar.clone(),
595+
ValueRef::Column(col) => unsafe { col.index_unchecked(idx) },
596+
};
597+
builder.push(val.clone());
598+
offsets.push(builder.len() as u64);
599+
}
600+
match len {
601+
Some(_) => Value::Column(Column::Array(Box::new(ArrayColumn {
602+
values: builder.build(),
603+
offsets: offsets.into(),
604+
}))),
605+
None => Value::Scalar(Scalar::Array(builder.build())),
606+
}
607+
}),
608+
},
609+
}))
610+
});
433611

434612
fn eval_contains<T: ArgType>(
435613
lhs: ValueRef<ArrayType<T>>,
@@ -494,11 +672,11 @@ pub fn register(registry: &mut FunctionRegistry) {
494672

495673
registry.register_passthrough_nullable_2_arg::<ArrayType<StringType>, StringType, BooleanType, _, _>(
496674
"contains",
497-
|_, lhs, rhs| {
498-
lhs.as_ref().map(|lhs| {
499-
lhs.domain_contains(rhs)
500-
}).unwrap_or(FunctionDomain::Full)
501-
},
675+
|_, lhs, rhs| {
676+
lhs.as_ref().map(|lhs| {
677+
lhs.domain_contains(rhs)
678+
}).unwrap_or(FunctionDomain::Full)
679+
},
502680
|lhs, rhs, _| {
503681
match lhs {
504682
ValueRef::Scalar(array) => {
@@ -551,15 +729,15 @@ pub fn register(registry: &mut FunctionRegistry) {
551729
);
552730

553731
registry.register_passthrough_nullable_2_arg::<ArrayType<TimestampType>, TimestampType, BooleanType, _, _>(
554-
"contains",
555-
|_, lhs, rhs| {
556-
let has_true = lhs.is_some_and(|lhs| !(lhs.min > rhs.max || lhs.max < rhs.min));
557-
FunctionDomain::Domain(BooleanDomain {
558-
has_false: true,
559-
has_true,
560-
})
561-
},
562-
|lhs, rhs, _| eval_contains::<TimestampType>(lhs, rhs)
732+
"contains",
733+
|_, lhs, rhs| {
734+
let has_true = lhs.is_some_and(|lhs| !(lhs.min > rhs.max || lhs.max < rhs.min));
735+
FunctionDomain::Domain(BooleanDomain {
736+
has_false: true,
737+
has_true,
738+
})
739+
},
740+
|lhs, rhs, _| eval_contains::<TimestampType>(lhs, rhs)
563741
);
564742

565743
registry.register_passthrough_nullable_2_arg::<ArrayType<BooleanType>, BooleanType, BooleanType, _, _>(
@@ -657,7 +835,7 @@ pub fn register(registry: &mut FunctionRegistry) {
657835

658836
registry.register_passthrough_nullable_1_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, _, _>(
659837
"array_distinct",
660-
|_, _| FunctionDomain::Full,
838+
|_, domain| FunctionDomain::Domain(domain.clone()),
661839
vectorize_1_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>>(|arr, _| {
662840
if arr.len() > 0 {
663841
let data_type = arr.data_type();

0 commit comments

Comments
 (0)