Skip to content

Commit 1ac6fd1

Browse files
committed
trait
1 parent d6ddbae commit 1ac6fd1

30 files changed

+132
-102
lines changed

src/query/expression/src/aggregate/aggregate_function.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ use super::AggrState;
2222
use super::AggrStateLoc;
2323
use super::AggrStateRegistry;
2424
use super::StateAddr;
25-
use crate::types::BinaryType;
2625
use crate::types::DataType;
26+
use crate::AggrStateSerdeType;
2727
use crate::BlockEntry;
2828
use crate::ColumnBuilder;
29-
use crate::ColumnView;
3029
use crate::ProjectedBlock;
3130
use crate::Scalar;
31+
use crate::ScalarRef;
3232

3333
pub type AggregateFunctionRef = Arc<dyn AggregateFunction>;
3434

@@ -69,32 +69,49 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
6969
// Used in aggregate_null_adaptor
7070
fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()>;
7171

72-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()>;
72+
fn serialize_type(&self) -> Vec<AggrStateSerdeType> {
73+
vec![AggrStateSerdeType::Binary(self.serialize_size_per_row())]
74+
}
75+
76+
fn serialize(&self, place: AggrState, builder: &mut ColumnBuilder) -> Result<()> {
77+
let binary_builder = builder.as_tuple_mut().unwrap()[0].as_binary_mut().unwrap();
78+
self.serialize_binary(place, &mut binary_builder.data)?;
79+
binary_builder.commit_row();
80+
Ok(())
81+
}
82+
83+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()>;
7384

7485
fn serialize_size_per_row(&self) -> Option<usize> {
7586
None
7687
}
7788

78-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>;
89+
fn merge(&self, place: AggrState, data: ScalarRef) -> Result<()> {
90+
let mut binary = *data.as_tuple().unwrap()[0].as_binary().unwrap();
91+
self.merge_binary(place, &mut binary)
92+
}
93+
94+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()>;
7995

8096
/// Batch merge and deserialize the state from binary array
8197
fn batch_merge(
8298
&self,
8399
places: &[StateAddr],
84100
loc: &[AggrStateLoc],
85-
state: &ColumnView<BinaryType>,
101+
state: &BlockEntry,
86102
) -> Result<()> {
87-
for (place, mut data) in places.iter().zip(state.iter()) {
88-
self.merge(AggrState::new(*place, loc), &mut data)?;
103+
let column = state.to_column();
104+
for (place, data) in places.iter().zip(column.iter()) {
105+
self.merge(AggrState::new(*place, loc), data)?;
89106
}
90107

91108
Ok(())
92109
}
93110

94111
fn batch_merge_single(&self, place: AggrState, state: &BlockEntry) -> Result<()> {
95-
let view = state.downcast::<BinaryType>().unwrap();
96-
for mut data in view.iter() {
97-
self.merge(place, &mut data)?;
112+
let column = state.to_column();
113+
for data in column.iter() {
114+
self.merge(place, data)?;
98115
}
99116
Ok(())
100117
}

src/query/expression/src/aggregate/aggregate_function_state.rs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ use enum_as_inner::EnumAsInner;
2020

2121
use super::AggregateFunctionRef;
2222
use crate::types::binary::BinaryColumnBuilder;
23+
use crate::types::DataType;
24+
use crate::ColumnBuilder;
2325

2426
#[derive(Clone, Copy, Debug)]
2527
pub struct StateAddr {
@@ -113,11 +115,11 @@ impl From<StateAddr> for usize {
113115

114116
pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout> {
115117
let mut registry = AggrStateRegistry::default();
116-
let mut serialize_size = Vec::with_capacity(funcs.len());
118+
let mut serialize_type = Vec::with_capacity(funcs.len());
117119
for func in funcs {
118120
func.register_state(&mut registry);
119121
registry.commit();
120-
serialize_size.push(func.serialize_size_per_row());
122+
serialize_type.push(func.serialize_type().into_boxed_slice());
121123
}
122124

123125
let AggrStateRegistry { states, offsets } = registry;
@@ -132,7 +134,7 @@ pub fn get_states_layout(funcs: &[AggregateFunctionRef]) -> Result<StatesLayout>
132134
Ok(StatesLayout {
133135
layout,
134136
states_loc,
135-
serialize_size,
137+
serialize_type,
136138
})
137139
}
138140

@@ -195,14 +197,30 @@ impl AggrStateLoc {
195197
pub struct StatesLayout {
196198
pub layout: Layout,
197199
pub states_loc: Vec<Box<[AggrStateLoc]>>,
198-
serialize_size: Vec<Option<usize>>,
200+
serialize_type: Vec<Box<[AggrStateSerdeType]>>,
199201
}
200202

201203
impl StatesLayout {
202-
pub fn serialize_builders(&self, num_rows: usize) -> Vec<BinaryColumnBuilder> {
203-
self.serialize_size
204+
pub fn serialize_builders(&self, num_rows: usize) -> Vec<ColumnBuilder> {
205+
self.serialize_type
204206
.iter()
205-
.map(|size| BinaryColumnBuilder::with_capacity(num_rows, num_rows * size.unwrap_or(0)))
207+
.map(|item| {
208+
let builder = item
209+
.iter()
210+
.map(|serde_type| match serde_type {
211+
AggrStateSerdeType::Bool => {
212+
ColumnBuilder::with_capacity(&DataType::Boolean, num_rows)
213+
}
214+
AggrStateSerdeType::Binary(size) => {
215+
ColumnBuilder::Binary(BinaryColumnBuilder::with_capacity(
216+
num_rows,
217+
num_rows * size.unwrap_or(0),
218+
))
219+
}
220+
})
221+
.collect();
222+
ColumnBuilder::Tuple(builder)
223+
})
206224
.collect()
207225
}
208226
}
@@ -288,6 +306,12 @@ pub enum AggrStateType {
288306
Custom(Layout),
289307
}
290308

309+
#[derive(Debug, Clone, Copy)]
310+
pub enum AggrStateSerdeType {
311+
Bool,
312+
Binary(Option<usize>),
313+
}
314+
291315
#[cfg(test)]
292316
mod tests {
293317
use proptest::prelude::*;

src/query/expression/src/aggregate/aggregate_hashtable.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use crate::aggregate::payload_row::row_match_columns;
2727
use crate::group_hash_columns;
2828
use crate::new_sel;
2929
use crate::read;
30-
use crate::types::BinaryType;
3130
use crate::types::DataType;
3231
use crate::AggregateFunctionRef;
3332
use crate::BlockEntry;
@@ -219,7 +218,7 @@ impl AggregateHashTable {
219218
.zip(agg_states.iter())
220219
.zip(states_layout.states_loc.iter())
221220
{
222-
func.batch_merge(state_places, loc, &state.downcast::<BinaryType>().unwrap())?;
221+
func.batch_merge(state_places, loc, state)?;
223222
}
224223
}
225224
}

src/query/expression/src/aggregate/payload_flush.rs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,17 +150,12 @@ impl Payload {
150150
{
151151
{
152152
let builder = &mut builders[idx];
153-
func.serialize(AggrState::new(*place, loc), &mut builder.data)?;
154-
builder.commit_row();
153+
func.serialize(AggrState::new(*place, loc), builder)?;
155154
}
156155
}
157156
}
158157

159-
entries.extend(
160-
builders
161-
.into_iter()
162-
.map(|builder| Column::Binary(builder.build()).into()),
163-
);
158+
entries.extend(builders.into_iter().map(|builder| builder.build().into()));
164159
}
165160

166161
entries.extend_from_slice(&state.take_group_columns());

src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction for AggregateNullUnaryAdapto
183183
.accumulate_row(place, not_null_columns, validity, row)
184184
}
185185

186-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
186+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
187187
self.0.serialize(place, writer)
188188
}
189189

190-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
190+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
191191
self.0.merge(place, reader)
192192
}
193193

@@ -308,11 +308,11 @@ impl<const NULLABLE_RESULT: bool> AggregateFunction
308308
.accumulate_row(place, not_null_columns, validity, row)
309309
}
310310

311-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
311+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
312312
self.0.serialize(place, writer)
313313
}
314314

315-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
315+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
316316
self.0.merge(place, reader)
317317
}
318318

@@ -498,17 +498,17 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
498498

499499
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
500500
if !NULLABLE_RESULT {
501-
return self.nested.serialize(place, writer);
501+
return self.nested.serialize_binary(place, writer);
502502
}
503503

504-
self.nested.serialize(place.remove_last_loc(), writer)?;
504+
self.nested.serialize_binary(place.remove_last_loc(), writer)?;
505505
let flag = get_flag(place);
506506
writer.write_scalar(&flag)
507507
}
508508

509509
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
510510
if !NULLABLE_RESULT {
511-
return self.nested.merge(place, reader);
511+
return self.nested.merge_binary(place, reader);
512512
}
513513

514514
let flag = reader[reader.len() - 1];
@@ -522,7 +522,7 @@ impl<const NULLABLE_RESULT: bool> CommonNullAdaptor<NULLABLE_RESULT> {
522522
}
523523
set_flag(place, true);
524524
self.nested
525-
.merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])
525+
.merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1])
526526
}
527527

528528
fn merge_states(&self, place: AggrState, rhs: AggrState) -> Result<()> {

src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,17 @@ impl AggregateFunction for AggregateFunctionOrNullAdaptor {
178178
}
179179

180180
#[inline]
181-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
182-
self.inner.serialize(place.remove_last_loc(), writer)?;
181+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
182+
self.inner.serialize_binary(place.remove_last_loc(), writer)?;
183183
let flag = get_flag(place) as u8;
184184
writer.write_scalar(&flag)
185185
}
186186

187187
#[inline]
188-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
188+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
189189
let flag = get_flag(place) || reader[reader.len() - 1] > 0;
190190
self.inner
191-
.merge(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?;
191+
.merge_binary(place.remove_last_loc(), &mut &reader[..reader.len() - 1])?;
192192
set_flag(place, flag);
193193
Ok(())
194194
}

src/query/functions/src/aggregates/adaptors/aggregate_sort_adaptor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ impl AggregateFunction for AggregateFunctionSortAdaptor {
121121
Ok(())
122122
}
123123

124-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
124+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
125125
let state = Self::get_state(place);
126126
Ok(state.serialize(writer)?)
127127
}
128128

129-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
129+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
130130
let state = Self::get_state(place);
131131
let rhs = SortAggState::deserialize(reader)?;
132132

src/query/functions/src/aggregates/aggregate_arg_min_max.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,12 +270,12 @@ where
270270
Ok(())
271271
}
272272

273-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
273+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
274274
let state = place.get::<State>();
275275
Ok(state.serialize(writer)?)
276276
}
277277

278-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
278+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
279279
let state = place.get::<State>();
280280
let rhs: State = borsh_partial_deserialize(reader)?;
281281
state.merge_from(rhs)

src/query/functions/src/aggregates/aggregate_array_agg.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,12 +548,12 @@ where
548548
Ok(())
549549
}
550550

551-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
551+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
552552
let state = place.get::<State>();
553553
Ok(state.serialize(writer)?)
554554
}
555555

556-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
556+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
557557
let state = place.get::<State>();
558558
let rhs = State::deserialize_reader(reader)?;
559559

src/query/functions/src/aggregates/aggregate_array_moving.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -447,12 +447,12 @@ where State: SumState
447447
state.accumulate_row(&columns[0], row)
448448
}
449449

450-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
450+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
451451
let state = place.get::<State>();
452452
Ok(state.serialize(writer)?)
453453
}
454454

455-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
455+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
456456
let state = place.get::<State>();
457457
let rhs = State::deserialize_reader(reader)?;
458458

@@ -616,12 +616,12 @@ where State: SumState
616616
state.accumulate_row(&columns[0], row)
617617
}
618618

619-
fn serialize(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
619+
fn serialize_binary(&self, place: AggrState, writer: &mut Vec<u8>) -> Result<()> {
620620
let state = place.get::<State>();
621621
Ok(state.serialize(writer)?)
622622
}
623623

624-
fn merge(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
624+
fn merge_binary(&self, place: AggrState, reader: &mut &[u8]) -> Result<()> {
625625
let state = place.get::<State>();
626626
let rhs = State::deserialize_reader(reader)?;
627627

0 commit comments

Comments
 (0)