Skip to content

Commit ad2f709

Browse files
committed
chore(query): improve in
1 parent cf0b77c commit ad2f709

File tree

9 files changed

+152
-45
lines changed

9 files changed

+152
-45
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/common/hashtable/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ test = false
1515
common-base = { path = "../base" }
1616

1717
# Crates.io dependencies
18+
ahash = "0.7.6"
1819
ordered-float = "3.0.0"
1920
primitive-types = "0.11.1"
2021

src/common/hashtable/src/hash_set.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,8 @@ impl<Key: HashTableKeyable, Grower: HashTableGrower, Allocator: AllocatorTrait +
3030
self.insert_key(value.get_key(), &mut inserted);
3131
}
3232
}
33+
34+
pub fn contains(&self, key: &Key) -> bool {
35+
self.find_key(key).is_some()
36+
}
3337
}

src/common/hashtable/src/hash_table_key.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ primitive_hasher_impl!(u16);
6262
primitive_hasher_impl!(u32);
6363
primitive_hasher_impl!(u64);
6464

65+
6566
impl HashTableKeyable for u128 {
6667
const BEFORE_EQ_HASH: bool = false;
6768
#[inline(always)]

src/common/hashtable/src/keys_ref.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright 2021 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::hash::Hasher;
16+
17+
use ahash::AHasher;
18+
use super::HashTableKeyable;
19+
20+
#[derive(Clone, Copy)]
21+
pub struct KeysRef {
22+
pub length: usize,
23+
pub address: usize,
24+
}
25+
26+
impl KeysRef {
27+
pub fn create(address: usize, length: usize) -> KeysRef {
28+
KeysRef { length, address }
29+
}
30+
}
31+
32+
impl Eq for KeysRef {}
33+
34+
impl PartialEq for KeysRef {
35+
fn eq(&self, other: &Self) -> bool {
36+
if self.length != other.length {
37+
return false;
38+
}
39+
40+
unsafe {
41+
let self_value = std::slice::from_raw_parts(self.address as *const u8, self.length);
42+
let other_value = std::slice::from_raw_parts(other.address as *const u8, other.length);
43+
self_value == other_value
44+
}
45+
}
46+
}
47+
48+
impl HashTableKeyable for KeysRef {
49+
const BEFORE_EQ_HASH: bool = true;
50+
51+
fn is_zero(&self) -> bool {
52+
self.length == 0
53+
}
54+
55+
fn fast_hash(&self) -> u64 {
56+
unsafe {
57+
// TODO(Winter) We need more efficient hash algorithm
58+
let value = std::slice::from_raw_parts(self.address as *const u8, self.length);
59+
60+
let mut hasher = AHasher::default();
61+
hasher.write(value);
62+
hasher.finish()
63+
}
64+
}
65+
66+
fn set_key(&mut self, new_value: &Self) {
67+
self.length = new_value.length;
68+
self.address = new_value.address;
69+
}
70+
}
71+
72+
impl std::hash::Hash for KeysRef {
73+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
74+
let self_value =
75+
unsafe { std::slice::from_raw_parts(self.address as *const u8, self.length) };
76+
self_value.hash(state);
77+
}
78+
}

src/common/hashtable/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ mod hash_table_grower;
3737
mod hash_table_iter;
3838
mod hash_table_key;
3939
mod two_level_hash_table;
40+
mod keys_ref;
41+
42+
pub use keys_ref::KeysRef;
4043

4144
#[cfg(not(target_os = "linux"))]
4245
type HashTableAllocator = common_base::mem_allocator::JEAllocator;

src/query/functions/src/aggregates/aggregate_combinator_distinct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::sync::Arc;
2020
use common_arrow::arrow::bitmap::Bitmap;
2121
use common_datavalues::prelude::*;
2222
use common_exception::Result;
23+
use common_hashtable::KeysRef;
2324
use common_io::prelude::*;
2425
use ordered_float::OrderedFloat;
2526

@@ -28,7 +29,6 @@ use super::aggregate_distinct_state::AggregateDistinctState;
2829
use super::aggregate_distinct_state::AggregateDistinctStringState;
2930
use super::aggregate_distinct_state::DataGroupValues;
3031
use super::aggregate_distinct_state::DistinctStateFunc;
31-
use super::aggregate_distinct_state::KeysRef;
3232
use super::aggregate_function::AggregateFunction;
3333
use super::aggregate_function_factory::AggregateFunctionCreator;
3434
use super::aggregate_function_factory::AggregateFunctionDescription;

src/query/functions/src/aggregates/aggregate_distinct_state.rs

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use common_exception::Result;
2626
use common_hashtable::HashSetWithStackMemory;
2727
use common_hashtable::HashTableEntity;
2828
use common_hashtable::HashTableKeyable;
29+
use common_hashtable::KeysRef;
2930
use common_io::prelude::*;
3031
use serde::Deserialize;
3132
use serde::Serialize;
@@ -358,39 +359,3 @@ where
358359
Ok(vec![result.arc()])
359360
}
360361
}
361-
362-
#[derive(Clone, Copy)]
363-
pub struct KeysRef {
364-
pub length: usize,
365-
pub address: usize,
366-
}
367-
368-
impl KeysRef {
369-
pub fn create(address: usize, length: usize) -> KeysRef {
370-
KeysRef { length, address }
371-
}
372-
}
373-
374-
impl Eq for KeysRef {}
375-
376-
impl PartialEq for KeysRef {
377-
fn eq(&self, other: &Self) -> bool {
378-
if self.length != other.length {
379-
return false;
380-
}
381-
382-
unsafe {
383-
let self_value = std::slice::from_raw_parts(self.address as *const u8, self.length);
384-
let other_value = std::slice::from_raw_parts(other.address as *const u8, other.length);
385-
self_value == other_value
386-
}
387-
}
388-
}
389-
390-
impl Hash for KeysRef {
391-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
392-
let self_value =
393-
unsafe { std::slice::from_raw_parts(self.address as *const u8, self.length) };
394-
self_value.hash(state);
395-
}
396-
}

src/query/functions/src/scalars/conditionals/in_basic.rs

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
use std::collections::HashSet;
1615
use std::fmt;
1716

1817
use common_datavalues::prelude::*;
1918
use common_datavalues::type_coercion::numerical_coercion;
2019
use common_exception::ErrorCode;
2120
use common_exception::Result;
21+
use common_hashtable::KeysRef;
2222
use ordered_float::OrderedFloat;
2323

2424
use crate::scalars::cast_column_field;
@@ -27,6 +27,8 @@ use crate::scalars::FunctionContext;
2727
use crate::scalars::FunctionDescription;
2828
use crate::scalars::FunctionFeatures;
2929

30+
use common_hashtable::HashSetWithStackMemory;
31+
3032
#[derive(Clone)]
3133
pub struct InFunction<const NEGATED: bool> {
3234
is_null: bool,
@@ -60,13 +62,14 @@ impl<const NEGATED: bool> InFunction<NEGATED> {
6062
macro_rules! scalar_contains {
6163
($T: ident, $INPUT_COL: expr, $ROWS: expr, $COLUMNS: expr, $CAST_TYPE: ident, $FUNC_CTX: expr) => {{
6264
let mut builder: ColumnBuilder<bool> = ColumnBuilder::with_capacity($ROWS);
63-
let mut vals_set = HashSet::with_capacity($ROWS);
65+
let mut vals_set: HashSetWithStackMemory<64, $T> = HashSetWithStackMemory::create();
66+
let mut inserted = false;
6467
for col in &$COLUMNS[1..] {
6568
let col = cast_column_field(col, col.data_type(), &$CAST_TYPE, &$FUNC_CTX)?;
6669
let col_viewer = $T::try_create_viewer(&col)?;
6770
if col_viewer.valid_at(0) {
68-
let val = col_viewer.value_at(0).to_owned_scalar();
69-
vals_set.insert(val);
71+
let val = col_viewer.value_at(0);
72+
vals_set.insert_key(&val, &mut inserted);
7073
}
7174
}
7275
let input_viewer = $T::try_create_viewer(&$INPUT_COL)?;
@@ -79,16 +82,67 @@ macro_rules! scalar_contains {
7982
}};
8083
}
8184

85+
86+
macro_rules! bool_contains {
87+
($T: ident, $INPUT_COL: expr, $ROWS: expr, $COLUMNS: expr, $CAST_TYPE: ident, $FUNC_CTX: expr) => {{
88+
let mut builder: ColumnBuilder<bool> = ColumnBuilder::with_capacity($ROWS);
89+
let mut vals = 0;
90+
for col in &$COLUMNS[1..] {
91+
let col = cast_column_field(col, col.data_type(), &$CAST_TYPE, &$FUNC_CTX)?;
92+
let col_viewer = $T::try_create_viewer(&col)?;
93+
if col_viewer.valid_at(0) {
94+
let val = col_viewer.value_at(0);
95+
vals |= 1 << (val as u8 + 1);
96+
}
97+
}
98+
let input_viewer = $T::try_create_viewer(&$INPUT_COL)?;
99+
for (row, val) in input_viewer.iter().enumerate() {
100+
let contains = ((vals >> (val as u8 + 1)) & 1) > 0;
101+
let valid = input_viewer.valid_at(row);
102+
builder.append(valid && ((contains && !NEGATED) || (!contains && NEGATED)));
103+
}
104+
return Ok(builder.build($ROWS));
105+
}};
106+
}
107+
108+
109+
macro_rules! string_contains {
110+
($T: ident, $INPUT_COL: expr, $ROWS: expr, $COLUMNS: expr, $CAST_TYPE: ident, $FUNC_CTX: expr) => {{
111+
let mut builder: ColumnBuilder<bool> = ColumnBuilder::with_capacity($ROWS);
112+
let mut vals_set: HashSetWithStackMemory<64, KeysRef> = HashSetWithStackMemory::create();
113+
let mut inserted = false;
114+
for col in &$COLUMNS[1..] {
115+
let col = cast_column_field(col, col.data_type(), &$CAST_TYPE, &$FUNC_CTX)?;
116+
let col_viewer = $T::try_create_viewer(&col)?;
117+
if col_viewer.valid_at(0) {
118+
let val = col_viewer.value_at(0);
119+
let key = KeysRef::create(val.as_ptr() as usize, val.len());
120+
vals_set.insert_key(&key, &mut inserted);
121+
}
122+
}
123+
let input_viewer = $T::try_create_viewer(&$INPUT_COL)?;
124+
for (row, val) in input_viewer.iter().enumerate() {
125+
let key = KeysRef::create(val.as_ptr() as usize, val.len());
126+
let contains = vals_set.contains(&key);
127+
let valid = input_viewer.valid_at(row);
128+
builder.append(valid && ((contains && !NEGATED) || (!contains && NEGATED)));
129+
}
130+
return Ok(builder.build($ROWS));
131+
}};
132+
}
133+
82134
macro_rules! float_contains {
83135
($T: ident, $INPUT_COL: expr, $ROWS: expr, $COLUMNS: expr, $CAST_TYPE: ident, $FUNC_CTX: expr) => {{
84136
let mut builder: ColumnBuilder<bool> = ColumnBuilder::with_capacity($ROWS);
85-
let mut vals_set = HashSet::with_capacity($ROWS);
137+
let mut vals_set: HashSetWithStackMemory<64, OrderedFloat<$T>> = HashSetWithStackMemory::create();
138+
let mut inserted = false;
139+
86140
for col in &$COLUMNS[1..] {
87141
let col = cast_column_field(col, col.data_type(), &$CAST_TYPE, &$FUNC_CTX)?;
88142
let col_viewer = $T::try_create_viewer(&col)?;
89143
if col_viewer.valid_at(0) {
90144
let val = col_viewer.value_at(0);
91-
vals_set.insert(OrderedFloat::from(val));
145+
vals_set.insert_key(&OrderedFloat::from(val), &mut inserted);
92146
}
93147
}
94148
let input_viewer = $T::try_create_viewer(&$INPUT_COL)?;
@@ -158,7 +212,7 @@ impl<const NEGATED: bool> Function for InFunction<NEGATED> {
158212

159213
match least_super_type_id {
160214
TypeID::Boolean => {
161-
scalar_contains!(
215+
bool_contains!(
162216
bool,
163217
input_col,
164218
input_rows,
@@ -234,7 +288,7 @@ impl<const NEGATED: bool> Function for InFunction<NEGATED> {
234288
)
235289
}
236290
TypeID::String => {
237-
scalar_contains!(
291+
string_contains!(
238292
Vu8,
239293
input_col,
240294
input_rows,

0 commit comments

Comments
 (0)