Skip to content

Commit d9d5247

Browse files
committed
chore(query): add combinators
1 parent f11b2b9 commit d9d5247

File tree

11 files changed

+225
-102
lines changed

11 files changed

+225
-102
lines changed

src/query/functions-v2/src/aggregates/adaptors/aggregate_null_unary_adaptor.rs

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -114,18 +114,15 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
114114
let validity = column_merge_validity(col, validity.cloned());
115115
let not_null_column = col.remove_nullable();
116116

117-
self.nested.accumulate(
118-
place,
119-
&[not_null_column.clone()],
120-
validity.as_ref(),
121-
input_rows,
122-
)?;
123-
124-
match validity {
125-
Some(v) if v.unset_bits() != input_rows => {
126-
self.set_flag(place, 1);
127-
}
128-
_ => self.set_flag(place, 1),
117+
self.nested
118+
.accumulate(place, &[not_null_column], validity.as_ref(), input_rows)?;
119+
120+
if validity
121+
.as_ref()
122+
.map(|c| c.unset_bits() != input_rows)
123+
.unwrap_or(true)
124+
{
125+
self.set_flag(place, 1);
129126
}
130127
Ok(())
131128
}

src/query/functions-v2/src/aggregates/adaptors/aggregate_null_variadic_adaptor.rs

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -123,11 +123,12 @@ impl<const NULLABLE_RESULT: bool, const STKIP_NULL: bool> AggregateFunction
123123
self.nested
124124
.accumulate(place, &not_null_columns, validity.as_ref(), input_rows)?;
125125

126-
match validity {
127-
Some(v) if v.unset_bits() != input_rows => {
128-
self.set_flag(place, 1);
129-
}
130-
_ => self.set_flag(place, 1),
126+
if validity
127+
.as_ref()
128+
.map(|c| c.unset_bits() != input_rows)
129+
.unwrap_or(true)
130+
{
131+
self.set_flag(place, 1);
131132
}
132133
Ok(())
133134
}

src/query/functions-v2/src/aggregates/adaptors/aggregate_ornull_adaptor.rs

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -106,19 +106,24 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
106106
}
107107

108108
let if_cond = self.inner.get_if_condition(columns);
109-
if let Some(bm) = if_cond {
110-
if bm.unset_bits() == input_rows {
111-
return Ok(());
112-
}
113-
}
114109

115-
if let Some(bm) = validity {
116-
if bm.unset_bits() == input_rows {
117-
return Ok(());
118-
}
110+
let validity = match (if_cond, validity) {
111+
(None, None) => None,
112+
(None, Some(b)) => Some(b.clone()),
113+
(Some(a), None) => Some(a),
114+
(Some(a), Some(b)) => Some(&a & b),
115+
};
116+
117+
if validity
118+
.as_ref()
119+
.map(|c| c.unset_bits() != input_rows)
120+
.unwrap_or(true)
121+
{
122+
self.set_flag(place, 1);
123+
self.inner
124+
.accumulate(place, columns, validity.as_ref(), input_rows)?;
119125
}
120-
self.set_flag(place, 1);
121-
self.inner.accumulate(place, &columns, validity, input_rows)
126+
Ok(())
122127
}
123128

124129
fn accumulate_keys(

src/query/functions-v2/src/aggregates/aggregate_combinator_if.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,11 +109,11 @@ impl AggregateFunction for AggregateIfCombinator {
109109
) -> Result<()> {
110110
let predicate: Bitmap =
111111
BooleanType::try_downcast_column(&columns[self.argument_len - 1]).unwrap();
112+
112113
let bitmap = match validity {
113114
Some(validity) => validity & (&predicate),
114115
None => predicate,
115116
};
116-
117117
self.nested.accumulate(
118118
place,
119119
&columns[0..self.argument_len - 1],

src/query/functions-v2/src/aggregates/aggregate_count.rs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -80,12 +80,13 @@ impl AggregateFunction for AggregateCountFunction {
8080
Layout::new::<AggregateCountState>()
8181
}
8282

83-
// we have own adaptor, so the validity is not needed
83+
// columns may be nullable
84+
// if not we use validity as the null signs
8485
fn accumulate(
8586
&self,
8687
place: StateAddr,
8788
columns: &[Column],
88-
_validity: Option<&Bitmap>,
89+
validity: Option<&Bitmap>,
8990
input_rows: usize,
9091
) -> Result<()> {
9192
let state = place.get::<AggregateCountState>();
@@ -94,7 +95,7 @@ impl AggregateFunction for AggregateCountFunction {
9495
} else {
9596
match &columns[0] {
9697
Column::Nullable(v) => v.validity.unset_bits(),
97-
_ => 0,
98+
_ => validity.map(|v| v.unset_bits()).unwrap_or(0),
9899
}
99100
};
100101
state.count += (input_rows - nulls) as u64;

src/query/functions-v2/src/aggregates/aggregate_distinct_state.rs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -133,8 +133,9 @@ impl DistinctStateFunc<DataGroupValue> for AggregateDistinctState {
133133
fn build_columns(&mut self, types: &[DataType]) -> Result<Vec<Column>> {
134134
let mut builders: Vec<ColumnBuilder> = types
135135
.iter()
136-
.map(|ty| ColumnBuilder::with_capacity(ty, 0))
136+
.map(|ty| ColumnBuilder::with_capacity(ty, self.set.len()))
137137
.collect();
138+
138139
for data in self.set.iter() {
139140
let mut slice = data.as_slice();
140141
let scalars: Vec<Scalar> = deserialize_from_slice(&mut slice)?;
@@ -261,8 +262,7 @@ where T: Number + Serialize + DeserializeOwned + HashTableKeyable
261262
fn serialize(&self, writer: &mut BytesMut) -> Result<()> {
262263
writer.write_uvarint(self.set.len() as u64)?;
263264
for value in self.set.iter() {
264-
let t: T = value.get_key().clone().into();
265-
serialize_into_buf(writer, &t)?
265+
serialize_into_buf(writer, value.get_key())?
266266
}
267267
Ok(())
268268
}
@@ -323,7 +323,7 @@ where T: Number + Serialize + DeserializeOwned + HashTableKeyable
323323
}
324324

325325
fn build_columns(&mut self, _types: &[DataType]) -> Result<Vec<Column>> {
326-
let values: Buffer<T> = self.set.iter().map(|e| e.get_key().clone()).collect();
326+
let values: Buffer<T> = self.set.iter().map(|e| *e.get_key()).collect();
327327
Ok(vec![NumberType::<T>::upcast_column(values)])
328328
}
329329
}

src/query/functions-v2/src/aggregates/aggregate_function_factory.rs

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,6 @@ impl AggregateFunctionFactory {
170170
nested,
171171
features.clone(),
172172
)?;
173-
174173
if or_null {
175174
return AggregateFunctionOrNullAdaptor::create(agg, features);
176175
} else {
@@ -179,7 +178,6 @@ impl AggregateFunctionFactory {
179178
}
180179

181180
let agg = self.get_impl(name, params, arguments, &mut features)?;
182-
183181
if or_null {
184182
AggregateFunctionOrNullAdaptor::create(agg, features)
185183
} else {

src/query/functions-v2/tests/it/aggregates/agg.rs

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,8 @@ fn test_agg() {
3131
test_sum(file);
3232
test_avg(file);
3333
test_uniq(file);
34+
test_agg_if(file);
35+
test_agg_distinct(file);
3436
}
3537

3638
fn get_example() -> Vec<(&'static str, DataType, Column)> {
@@ -87,3 +89,19 @@ fn test_uniq(file: &mut impl Write) {
8789
run_agg_ast(file, "uniq(c)", get_example().as_slice());
8890
run_agg_ast(file, "uniq(x_null)", get_example().as_slice());
8991
}
92+
93+
/// Exercises the `_if` aggregate combinator through `run_agg_ast`.
///
/// Runs `count_if` over a literal and `sum_if` over columns `a` and `b`,
/// each with the predicate `x_null is null`. Every call gets a fresh
/// `get_example()` column set and passes `file` through to `run_agg_ast`.
fn test_agg_if(file: &mut impl Write) {
94+
run_agg_ast(
95+
file,
96+
"count_if(1, x_null is null)",
97+
get_example().as_slice(),
98+
);
99+
run_agg_ast(file, "sum_if(a, x_null is null)", get_example().as_slice());
100+
run_agg_ast(file, "sum_if(b, x_null is null)", get_example().as_slice());
101+
}
102+
103+
/// Exercises the `_distinct` aggregate combinator through `run_agg_ast`.
///
/// Runs `sum_distinct` over columns `a`, `c` and `x_null`. Every call gets
/// a fresh `get_example()` column set and passes `file` through to
/// `run_agg_ast`.
fn test_agg_distinct(file: &mut impl Write) {
104+
run_agg_ast(file, "sum_distinct(a)", get_example().as_slice());
105+
run_agg_ast(file, "sum_distinct(c)", get_example().as_slice());
106+
run_agg_ast(file, "sum_distinct(x_null)", get_example().as_slice());
107+
}

src/query/functions-v2/tests/it/aggregates/mod.rs

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@ use common_expression::RawExpr;
2929
use common_expression::Scalar;
3030
use common_expression::Value;
3131
use common_functions_v2::aggregates::eval_aggr;
32-
use common_functions_v2::aggregates::AggregateFunctionFactory;
3332
use common_functions_v2::scalars::builtin_functions;
3433

3534
use super::scalars::parser;
@@ -53,6 +52,8 @@ pub fn run_agg_ast(file: &mut impl Write, text: &str, columns: &[(&str, DataType
5352
num_rows,
5453
);
5554

55+
let column_ids = collect_columns(&raw_expr);
56+
5657
// For test only, we just support agg function call here
5758
let result: common_exception::Result<(Column, DataType)> = try {
5859
match raw_expr {
@@ -82,27 +83,39 @@ pub fn run_agg_ast(file: &mut impl Write, text: &str, columns: &[(&str, DataType
8283
})
8384
.collect();
8485

85-
let result = eval_aggr(
86+
eval_aggr(
8687
name.as_str(),
8788
params,
8889
&arg_columns,
8990
&arg_types,
9091
chunk.num_rows(),
91-
)?;
92-
result
92+
)?
9393
}
9494
_ => unimplemented!(),
9595
}
9696
};
9797

9898
match result {
9999
Ok((column, _)) => {
100-
writeln!(file, "ast : {text}").unwrap();
100+
writeln!(file, "ast: {text}").unwrap();
101101
{
102102
let mut table = Table::new();
103103
table.load_preset("||--+-++| ++++++");
104104
table.set_header(&["Column", "Data"]);
105-
for (name, _, col) in columns.iter() {
105+
106+
let ids = match column_ids.is_empty() {
107+
true => {
108+
if columns.is_empty() {
109+
vec![]
110+
} else {
111+
vec![0]
112+
}
113+
}
114+
false => column_ids,
115+
};
116+
117+
for id in ids.iter() {
118+
let (name, _, col) = &columns[*id];
106119
table.add_row(&[name.to_string(), format!("{col:?}")]);
107120
}
108121
table.add_row(["Output".to_string(), format!("{column:?}")]);
@@ -122,7 +135,15 @@ pub fn run_scalar_expr(
122135
) -> common_expression::Result<(Value<AnyType>, DataType)> {
123136
let fn_registry = builtin_functions();
124137
let (expr, output_ty) = type_check::check(raw_expr, &fn_registry)?;
125-
let evaluator = Evaluator::new(&chunk, FunctionContext::default());
138+
let evaluator = Evaluator::new(chunk, FunctionContext::default());
126139
let result = evaluator.run(&expr)?;
127140
Ok((result, output_ty))
128141
}
142+
143+
/// Collects the ids of all column references found in `raw_expr`.
///
/// A `RawExpr::ColumnRef` yields its own `id`; a `RawExpr::FunctionCall`
/// yields the concatenation of the ids found recursively in each of its
/// `args`, in argument order. Every other variant yields an empty vector.
/// Ids are not deduplicated, so a column referenced twice appears twice.
fn collect_columns(raw_expr: &RawExpr) -> Vec<usize> {
144+
match raw_expr {
145+
RawExpr::ColumnRef { id, .. } => vec![*id],
146+
RawExpr::FunctionCall { args, .. } => args.iter().flat_map(collect_columns).collect(),
147+
_ => vec![],
148+
}
149+
}

0 commit comments

Comments
 (0)