Skip to content

Commit 4c4a839

Browse files
authored
Merge pull request #7759 from sundy-li/agg-distinct
feat(query): fix distinct aggregate function
2 parents 42c0883 + 2c8010a commit 4c4a839

File tree

8 files changed

+229
-104
lines changed

8 files changed

+229
-104
lines changed

src/common/hashtable/src/hash_table_entity.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ where
4141
Key: HashTableKeyable,
4242
Value: Sized + Clone,
4343
{
44+
#[inline(always)]
45+
pub fn set_key(self: *mut Self, key: Key) {
46+
unsafe { std::ptr::write(&mut (*self).key as *mut Key, key) }
47+
}
48+
4449
#[inline(always)]
4550
pub fn set_value(self: *mut Self, value: Value) {
4651
unsafe { std::ptr::write(&mut (*self).value as *mut Value, value) }

src/common/hashtable/src/keys_ref.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ impl KeysRef {
2828
pub fn create(address: usize, length: usize) -> KeysRef {
2929
KeysRef { length, address }
3030
}
31+
32+
#[allow(clippy::missing_safety_doc)]
33+
#[inline]
34+
pub unsafe fn as_slice(&self) -> &[u8] {
35+
std::slice::from_raw_parts(self.address as *const u8, self.length)
36+
}
3137
}
3238

3339
impl Eq for KeysRef {}
@@ -38,11 +44,7 @@ impl PartialEq for KeysRef {
3844
return false;
3945
}
4046

41-
unsafe {
42-
let self_value = std::slice::from_raw_parts(self.address as *const u8, self.length);
43-
let other_value = std::slice::from_raw_parts(other.address as *const u8, other.length);
44-
self_value == other_value
45-
}
47+
unsafe { self.as_slice() == other.as_slice() }
4648
}
4749
}
4850

src/query/datavalues/src/columns/string/mutable.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@ use std::sync::Arc;
1717
use common_arrow::arrow::bitmap::MutableBitmap;
1818
use common_exception::ErrorCode;
1919
use common_exception::Result;
20+
use serde::Deserialize;
21+
use serde::Serialize;
2022

2123
use crate::prelude::*;
2224

25+
#[derive(Serialize, Deserialize)]
2326
pub struct MutableStringColumn {
2427
last_size: usize,
2528
offsets: Vec<i64>,
@@ -61,6 +64,11 @@ impl MutableStringColumn {
6164
}
6265
}
6366

67+
#[inline]
68+
pub fn may_resize(&self, add_size: usize) -> bool {
69+
self.values.len() + add_size > self.values.capacity()
70+
}
71+
6472
/// Returns a mutable reference to the `values` byte buffer.
pub fn values_mut(&mut self) -> &mut Vec<u8> {
6573
&mut self.values
6674
}

src/query/expression/src/types/string.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ impl StringColumnBuilder {
287287
self.data[(self.offsets[0] as usize)..(self.offsets[1] as usize)].to_vec()
288288
}
289289

290+
#[inline]
291+
pub fn may_resize(&self, add_size: usize) -> bool {
292+
self.data.len() + add_size > self.data.capacity()
293+
}
294+
290295
/// # Safety
291296
pub unsafe fn index_unchecked(&self, row: usize) -> &[u8] {
292297
let start = *self.offsets.get_unchecked(row) as usize;

src/query/functions-v2/src/aggregates/aggregate_distinct_state.rs

Lines changed: 91 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use common_expression::types::ValueType;
3030
use common_expression::Column;
3131
use common_expression::ColumnBuilder;
3232
use common_expression::Scalar;
33+
use common_hashtable::HashSet as CommonHashSet;
3334
use common_hashtable::HashSetWithStackMemory;
3435
use common_hashtable::HashTableEntity;
3536
use common_hashtable::HashTableKeyable;
@@ -64,9 +65,13 @@ pub struct AggregateDistinctNumberState<T: Number + HashTableKeyable> {
6465
inserted: bool,
6566
}
6667

68+
// Capacity arguments for every `StringColumnBuilder` holder created in this file:
// `HOLDER_CAPACITY` is passed first and `HOLDER_BYTES_CAPACITY` second to
// `StringColumnBuilder::with_capacity`.
// NOTE(review): presumably a row count and a byte count, i.e. 8 bytes budgeted per row — confirm.
const HOLDER_CAPACITY: usize = 256;
69+
const HOLDER_BYTES_CAPACITY: usize = HOLDER_CAPACITY * 8;
70+
6771
/// Distinct-aggregate state for string values: a set of `KeysRef` keys plus
/// the builders that receive a copy of each value the set has not seen before.
pub struct AggregateDistinctStringState {
68-
set: HashSet<KeysRef, RandomState>,
69-
holder: StringColumnBuilder,
72+
// Set of `KeysRef` (address + length) entries.
set: CommonHashSet<KeysRef>,
73+
// Out-parameter handed to `insert_key`; read afterwards to tell whether the key was new.
inserted: bool,
74+
// Builders into which the bytes of newly inserted keys are copied.
holders: Vec<StringColumnBuilder>,
7075
}
7176

7277
pub struct DataGroupValue;
@@ -148,26 +153,61 @@ impl DistinctStateFunc<DataGroupValue> for AggregateDistinctState {
148153
}
149154
}
150155

156+
impl AggregateDistinctStringState {
157+
#[inline]
158+
fn insert_and_materialize(&mut self, key: &KeysRef) {
159+
let entity = self.set.insert_key(key, &mut self.inserted);
160+
if self.inserted {
161+
let data = unsafe { key.as_slice() };
162+
163+
let holder = self.holders.last_mut().unwrap();
164+
// TODO(sundy): may cause memory fragmentation, refactor this using arena
165+
if holder.may_resize(data.len()) {
166+
let mut holder = StringColumnBuilder::with_capacity(
167+
HOLDER_CAPACITY,
168+
HOLDER_BYTES_CAPACITY.max(data.len()),
169+
);
170+
holder.put_slice(data);
171+
holder.commit_row();
172+
let value = unsafe { holder.index_unchecked(holder.len() - 1) };
173+
entity.set_key(KeysRef::create(value.as_ptr() as usize, value.len()));
174+
self.holders.push(holder);
175+
} else {
176+
holder.put_slice(data);
177+
holder.commit_row();
178+
let value = unsafe { holder.index_unchecked(holder.len() - 1) };
179+
entity.set_key(KeysRef::create(value.as_ptr() as usize, value.len()));
180+
}
181+
}
182+
}
183+
}
184+
151185
impl DistinctStateFunc<KeysRef> for AggregateDistinctStringState {
152186
/// Creates an empty state: an empty set, `inserted` cleared, and a single
/// holder built with `HOLDER_CAPACITY` and `HOLDER_BYTES_CAPACITY`.
fn new() -> Self {
153187
AggregateDistinctStringState {
154-
set: HashSet::new(),
155-
holder: StringColumnBuilder::with_capacity(0, 0),
188+
set: CommonHashSet::create(),
189+
inserted: false,
190+
holders: vec![StringColumnBuilder::with_capacity(
191+
HOLDER_CAPACITY,
192+
HOLDER_BYTES_CAPACITY,
193+
)],
156194
}
157195
}
158196

159197
/// Writes the holders to `writer` through `serialize_into_buf`. The set is not
/// written; `deserialize` rebuilds it from the holders.
fn serialize(&self, writer: &mut BytesMut) -> Result<()> {
160-
serialize_into_buf(writer, &self.holder)
198+
serialize_into_buf(writer, &self.holders)
161199
}
162200

163201
/// Replaces the holders with the ones decoded from `reader`, then rebuilds the
/// set by inserting one key for every row of every holder.
fn deserialize(&mut self, reader: &mut &[u8]) -> Result<()> {
164-
self.holder = deserialize_from_slice(reader)?;
165-
self.set = HashSet::with_capacity(self.holder.len());
166-
167-
for index in 0..self.holder.len() {
168-
let data = unsafe { self.holder.index_unchecked(index) };
169-
let key = KeysRef::create(data.as_ptr() as usize, data.len());
170-
self.set.insert(key);
202+
self.holders = deserialize_from_slice(reader)?;
203+
// Size the set for the total number of rows across all holders.
self.set = CommonHashSet::with_capacity(self.holders.iter().map(|h| h.len()).sum());
204+
205+
for holder in self.holders.iter() {
206+
for index in 0..holder.len() {
207+
let data = unsafe { holder.index_unchecked(index) };
208+
// The key points straight into the decoded holder's bytes; nothing is copied.
let key = KeysRef::create(data.as_ptr() as usize, data.len());
209+
self.set.insert_key(&key, &mut self.inserted);
210+
}
171211
}
172212
Ok(())
173213
}
@@ -183,16 +223,8 @@ impl DistinctStateFunc<KeysRef> for AggregateDistinctStringState {
183223
/// Adds the value at `row` of `columns[0]` to the state.
///
/// Panics if `columns[0]` cannot be downcast to a string column (`unwrap`).
fn add(&mut self, columns: &[Column], row: usize) -> Result<()> {
184224
let column = StringType::try_downcast_column(&columns[0]).unwrap();
185225
let data = unsafe { column.index_unchecked(row) };
186-
187-
let mut key = KeysRef::create(data.as_ptr() as usize, data.len());
188-
189-
if !self.set.contains(&key) {
190-
self.holder.put_slice(data);
191-
self.holder.commit_row();
192-
let data = unsafe { self.holder.index_unchecked(self.holder.len() - 1) };
193-
key = KeysRef::create(data.as_ptr() as usize, data.len());
194-
self.set.insert(key);
195-
}
226+
// The key points into the input column here; `insert_and_materialize` copies
// the bytes into a holder when the value is new.
let key = KeysRef::create(data.as_ptr() as usize, data.len());
227+
self.insert_and_materialize(&key);
196228
Ok(())
197229
}
198230

@@ -204,47 +236,59 @@ impl DistinctStateFunc<KeysRef> for AggregateDistinctStringState {
204236
) -> Result<()> {
205237
let column = StringType::try_downcast_column(&columns[0]).unwrap();
206238

207-
for row in 0..input_rows {
208-
match validity {
209-
Some(v) => {
239+
match validity {
240+
Some(v) => {
241+
for row in 0..input_rows {
210242
if v.get_bit(row) {
211243
let data = unsafe { column.index_unchecked(row) };
212-
let mut key = KeysRef::create(data.as_ptr() as usize, data.len());
213-
if !self.set.contains(&key) {
214-
self.holder.put_slice(data);
215-
self.holder.commit_row();
216-
217-
let data =
218-
unsafe { self.holder.index_unchecked(self.holder.len() - 1) };
219-
key = KeysRef::create(data.as_ptr() as usize, data.len());
220-
self.set.insert(key);
221-
}
244+
let key = KeysRef::create(data.as_ptr() as usize, data.len());
245+
self.insert_and_materialize(&key);
222246
}
223247
}
224-
None => {
248+
}
249+
None => {
250+
for row in 0..input_rows {
225251
let data = unsafe { column.index_unchecked(row) };
226-
let mut key = KeysRef::create(data.as_ptr() as usize, data.len());
227-
if !self.set.contains(&key) {
228-
self.holder.put_slice(data);
229-
self.holder.commit_row();
230-
231-
let data = unsafe { self.holder.index_unchecked(self.holder.len() - 1) };
232-
key = KeysRef::create(data.as_ptr() as usize, data.len());
233-
self.set.insert(key);
234-
}
252+
let key = KeysRef::create(data.as_ptr() as usize, data.len());
253+
self.insert_and_materialize(&key);
235254
}
236255
}
237256
}
238257
Ok(())
239258
}
240259

241260
/// Inserts every key found in `rhs`'s set into this state.
fn merge(&mut self, rhs: &Self) -> Result<()> {
242-
self.set.extend(rhs.set.clone());
261+
for value in rhs.set.iter() {
262+
// New keys are copied into this state's own holders by
// `insert_and_materialize`.
self.insert_and_materialize(value.get_key());
263+
}
243264
Ok(())
244265
}
245266

246267
fn build_columns(&mut self, _types: &[DataType]) -> Result<Vec<Column>> {
247-
let c = std::mem::replace(&mut self.holder, StringColumnBuilder::with_capacity(0, 0));
268+
if self.holders.len() == 1 {
269+
let c = std::mem::replace(
270+
&mut self.holders[0],
271+
StringColumnBuilder::with_capacity(0, 0),
272+
);
273+
return Ok(vec![Column::String(c.build())]);
274+
}
275+
276+
let mut values = Vec::with_capacity(self.holders.iter().map(|h| h.data.len()).sum());
277+
let mut offsets = Vec::with_capacity(self.holders.iter().map(|h| h.len()).sum());
278+
279+
let mut last_offset = 0;
280+
offsets.push(0);
281+
for holder in self.holders.iter_mut() {
282+
for offset in holder.offsets.iter() {
283+
last_offset += *offset;
284+
offsets.push(last_offset);
285+
}
286+
values.append(&mut holder.data);
287+
}
288+
let c = StringColumnBuilder {
289+
data: values,
290+
offsets,
291+
};
248292
Ok(vec![Column::String(c.build())])
249293
}
250294
}

0 commit comments

Comments
 (0)