
Commit 2d76884

Commit message: ee
1 parent: dfec718

7 files changed: +113 additions, -36 deletions

src/query/expression/src/aggregate/aggregate_function.rs
Lines changed: 6 additions & 0 deletions

@@ -29,6 +29,7 @@ use crate::ProjectedBlock;
 use crate::Scalar;
 use crate::ScalarRef;
 use crate::StateSerdeItem;
+use crate::StateSerdeType;

 pub type AggregateFunctionRef = Arc<dyn AggregateFunction>;

@@ -71,6 +72,11 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {

     fn serialize_type(&self) -> Vec<StateSerdeItem>;

+    fn serialize_data_type(&self) -> DataType {
+        let serde_type = StateSerdeType::new(self.serialize_type());
+        serde_type.data_type()
+    }
+
     fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> {
         let binary_builder = builders[0].as_binary_mut().unwrap();
         self.serialize_binary(place, &mut binary_builder.data)?;
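
The new default `serialize_data_type` derives a single column type from the function's serde layout. A rough sketch of the expected mapping, inferred from the hand-written conversion this commit removes from physical_aggregate_partial.rs (the helper name below is illustrative, not part of the commit):

    // Illustrative sketch: each StateSerdeItem::DataType keeps its concrete type,
    // each Binary slot maps to DataType::Binary, and the whole state is exposed
    // as one DataType::Tuple, mirroring the mapping previously done in
    // physical_aggregate_partial.rs.
    fn expected_state_type(items: &[StateSerdeItem]) -> DataType {
        DataType::Tuple(
            items
                .iter()
                .map(|item| match item {
                    StateSerdeItem::DataType(data_type) => data_type.clone(),
                    StateSerdeItem::Binary(_) => DataType::Binary,
                })
                .collect(),
        )
    }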

src/query/expression/src/aggregate/aggregate_function_state.rs
Lines changed: 4 additions & 0 deletions

@@ -203,6 +203,10 @@ pub enum StateSerdeItem {
 pub struct StateSerdeType(Box<[StateSerdeItem]>);

 impl StateSerdeType {
+    pub fn new(items: impl Into<Box<[StateSerdeItem]>>) -> Self {
+        StateSerdeType(items.into())
+    }
+
     pub fn data_type(&self) -> DataType {
         DataType::Tuple(
             self.0
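
A minimal usage sketch of the new constructor, assuming the `StateSerdeItem::Binary` variant shown elsewhere in this diff and the Binary-to-DataType::Binary mapping previously hand-rolled in physical_aggregate_partial.rs:

    // Illustrative only: a single binary state field is expected to surface as a
    // one-element tuple type.
    fn binary_state_type() -> DataType {
        StateSerdeType::new(vec![StateSerdeItem::Binary(None)]).data_type()
        // expected: DataType::Tuple(vec![DataType::Binary])
    }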

src/query/functions/src/aggregates/aggregate_combinator_state.rs
Lines changed: 17 additions & 11 deletions

@@ -22,6 +22,7 @@ use databend_common_expression::AggrStateRegistry;
 use databend_common_expression::ColumnBuilder;
 use databend_common_expression::ProjectedBlock;
 use databend_common_expression::Scalar;
+use databend_common_expression::ScalarRef;
 use databend_common_expression::StateSerdeItem;

 use super::AggregateFunctionFactory;
@@ -71,7 +72,7 @@ impl AggregateFunction for AggregateStateCombinator {
     }

     fn return_type(&self) -> Result<DataType> {
-        Ok(DataType::Binary)
+        Ok(self.nested.serialize_data_type())
     }

     fn init_state(&self, place: AggrState) {
@@ -108,27 +109,32 @@ impl AggregateFunction for AggregateStateCombinator {
     }

     fn serialize_type(&self) -> Vec<StateSerdeItem> {
-        vec![StateSerdeItem::Binary(None)]
+        self.nested.serialize_type()
     }

-    fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
-        self.nested.serialize_binary(place, writer)
+    fn serialize(&self, place: AggrState, builders: &mut [ColumnBuilder]) -> Result<()> {
+        self.nested.serialize(place, builders)
     }

-    #[inline]
-    fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
-        self.nested.merge_binary(place, reader)
+    fn serialize_binary(&self, _: AggrState, _: &mut Vec<u8>) -> Result<()> {
+        unreachable!()
+    }
+
+    fn merge(&self, place: AggrState, data: &[ScalarRef]) -> Result<()> {
+        self.nested.merge(place, data)
+    }
+
+    fn merge_binary(&self, _: AggrState, _: &mut &[u8]) -> Result<()> {
+        unreachable!()
     }

     fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {
         self.nested.merge_states(place, rhs)
     }

     fn merge_result(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> {
-        let builder = builder.as_binary_mut().unwrap();
-        self.nested.serialize_binary(place, &mut builder.data)?;
-        builder.commit_row();
-        Ok(())
+        let builders = builder.as_tuple_mut().unwrap().as_mut_slice();
+        self.nested.serialize(place, builders)
     }

     fn need_manual_drop_state(&self) -> bool {
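
With the `_state` combinator now delegating to the nested function's column-based serde, callers go through tuple column builders instead of a raw byte buffer. A rough calling-pattern sketch, assuming `ColumnBuilder::with_capacity` is available with this signature in databend_common_expression (the helper name is illustrative):

    // Illustrative sketch of the column-based path the combinator now delegates to.
    // The state type is a tuple; serialize() appends one state row across its field builders.
    fn serialize_one_state(func: &dyn AggregateFunction, place: AggrState) -> Result<ColumnBuilder> {
        let state_type = func.serialize_data_type(); // DataType::Tuple(...)
        let mut builder = ColumnBuilder::with_capacity(&state_type, 1); // assumed constructor
        let builders = builder.as_tuple_mut().unwrap().as_mut_slice();
        func.serialize(place, builders)?; // replaces the old serialize_binary byte path
        Ok(builder)
    }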

src/query/sql/src/executor/physical_plans/physical_aggregate_partial.rs
Lines changed: 1 addition & 11 deletions

@@ -17,7 +17,6 @@ use databend_common_expression::types::DataType;
 use databend_common_expression::DataField;
 use databend_common_expression::DataSchemaRef;
 use databend_common_expression::DataSchemaRefExt;
-use databend_common_expression::StateSerdeItem;
 use databend_common_functions::aggregates::AggregateFunctionFactory;

 use super::SortDesc;
@@ -69,16 +68,7 @@ impl AggregatePartial {
                 )
                 .unwrap();

-            let tuple = func
-                .serialize_type()
-                .iter()
-                .map(|serde_type| match serde_type {
-                    StateSerdeItem::DataType(data_type) => data_type.clone(),
-                    StateSerdeItem::Binary(_) => DataType::Binary,
-                })
-                .collect();
-
-            fields.push(DataField::new(&name, DataType::Tuple(tuple)))
+            fields.push(DataField::new(&name, func.serialize_data_type()))
         }

         for (idx, field) in self.group_by.iter().zip(

src/query/sql/src/planner/optimizer/optimizers/operator/aggregate/stats_aggregate.rs
Lines changed: 5 additions & 0 deletions

@@ -129,6 +129,11 @@ impl RuleStatsAggregateOptimizer {
         for (need_rewrite_agg, agg) in
             need_rewrite_aggs.iter().zip(agg.aggregate_functions.iter())
         {
+            if matches!(agg.scalar, ScalarExpr::UDAFCall(_)) {
+                agg_results.push(agg.clone());
+                continue;
+            }
+
             let agg_func = AggregateFunction::try_from(agg.scalar.clone())?;

             if let Some((col_id, name)) = need_rewrite_agg {

src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/agg_index/query_rewrite.rs
Lines changed: 44 additions & 13 deletions

@@ -25,6 +25,7 @@ use databend_common_expression::types::DataType;
 use databend_common_expression::Scalar;
 use databend_common_expression::TableField;
 use databend_common_expression::TableSchemaRefExt;
+use databend_common_functions::aggregates::AggregateFunctionFactory;
 use itertools::Itertools;
 use log::info;

@@ -306,6 +307,12 @@ impl QueryInfo {
                 }
                 return Ok(None);
             }
+            ScalarExpr::UDAFCall(udaf) => {
+                for arg in &udaf.arguments {
+                    self.check_output_cols(arg, index_output_cols, new_selection_set)?;
+                }
+                return Ok(None);
+            }
             ScalarExpr::UDFCall(udf) => {
                 let mut valid = true;
                 let mut new_args = Vec::with_capacity(udf.arguments.len());
@@ -361,23 +368,38 @@ impl ViewInfo {
         // query can use those columns to compute expressions.
         let mut index_fields = Vec::with_capacity(query_info.output_cols.len());
         let mut index_output_cols = HashMap::with_capacity(query_info.output_cols.len());
+        let factory = AggregateFunctionFactory::instance();
         for (index, item) in query_info.output_cols.iter().enumerate() {
             let display_name = format_scalar(&item.scalar, &query_info.column_map);

-            let mut is_agg = false;
-            if let Some(ref aggregate) = query_info.aggregate {
-                for agg_func in &aggregate.aggregate_functions {
-                    if item.index == agg_func.index {
-                        is_agg = true;
-                        break;
-                    }
+            let aggr_scalar_item = query_info.aggregate.as_ref().and_then(|aggregate| {
+                aggregate
+                    .aggregate_functions
+                    .iter()
+                    .find(|agg_func| agg_func.index == item.index)
+            });
+
+            let (data_type, is_agg) = match aggr_scalar_item {
+                Some(item) => {
+                    let func = match &item.scalar {
+                        ScalarExpr::AggregateFunction(func) => func,
+                        _ => unreachable!(),
+                    };
+                    let func = factory.get(
+                        &func.func_name,
+                        func.params.clone(),
+                        func.args
+                            .iter()
+                            .map(|arg| arg.data_type())
+                            .collect::<Result<_>>()?,
+                        func.sort_descs
+                            .iter()
+                            .map(|desc| desc.try_into())
+                            .collect::<Result<_>>()?,
+                    )?;
+                    (func.serialize_data_type(), true)
                 }
-            }
-            // we store the value of aggregate function as binary data.
-            let data_type = if is_agg {
-                DataType::Binary
-            } else {
-                item.scalar.data_type().unwrap()
+                None => (item.scalar.data_type().unwrap(), false),
             };

             let name = format!("{index}");
@@ -1299,6 +1321,15 @@ fn format_scalar(scalar: &ScalarExpr, column_map: &HashMap<IndexType, ScalarExpr
             }
             scalar
         }
+        ScalarExpr::UDAFCall(udaf) => {
+            let args = udaf
+                .arguments
+                .iter()
+                .map(|arg| format_scalar(arg, column_map))
+                .collect::<Vec<_>>()
+                .join(", ");
+            format!("{}({})", &udaf.name, args)
+        }
         ScalarExpr::UDFCall(udf) => format!(
             "{}({})",
             &udf.handler,

tests/sqllogictests/suites/ee/02_ee_aggregating_index/02_0001_agg_index_projected_scan.test
Lines changed: 36 additions & 1 deletion

@@ -75,7 +75,42 @@ SELECT b, SUM(a) from t WHERE c > 1 GROUP BY b ORDER BY b
 1 1
 2 3

-query IIT
+query IIR
 SELECT MAX(a), MIN(b), AVG(c) from t
 ----
 2 1 3.5
+
+statement ok
+CREATE or REPLACE FUNCTION weighted_avg (INT, INT) STATE {sum INT, weight INT} RETURNS FLOAT
+LANGUAGE javascript AS $$
+export function create_state() {
+    return {sum: 0, weight: 0};
+}
+export function accumulate(state, value, weight) {
+    state.sum += value * weight;
+    state.weight += weight;
+    return state;
+}
+export function retract(state, value, weight) {
+    state.sum -= value * weight;
+    state.weight -= weight;
+    return state;
+}
+export function merge(state1, state2) {
+    state1.sum += state2.sum;
+    state1.weight += state2.weight;
+    return state1;
+}
+export function finish(state) {
+    return state.sum / state.weight;
+}
+$$;
+
+query IIR
+SELECT MAX(a), MIN(b), weighted_avg(a,b) from t group by b;
+----
+2 2 1.3333334
+1 1 1.0
+
+# fix me
+# SELECT MAX(a), MIN(b), weighted_avg(a,b) from t;
