Skip to content

Commit 355a082

Browse files
authored
fix(query): add find_leveled_eq_filters function (#18264)
* fix(query): add find_leveled_eq_filters function * fix(query): add find_leveled_eq_filters function * fix * fix * fix * fix
1 parent 14b24bc commit 355a082

File tree

8 files changed

+462
-231
lines changed

8 files changed

+462
-231
lines changed

src/query/expression/src/expression.rs

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ use crate::values::Scalar;
3232

3333
pub trait ColumnIndex: Debug + Clone + Serialize + Hash + Eq + 'static {
3434
fn unique_name<W: Write>(&self, f: &mut W) -> std::fmt::Result;
35+
fn name(&self) -> String {
36+
let mut buf = String::new();
37+
let _ = self.unique_name(&mut buf);
38+
buf
39+
}
3540
}
3641

3742
impl ColumnIndex for usize {
@@ -759,6 +764,129 @@ impl<Index: ColumnIndex> Expr<Index> {
759764
}
760765
}
761766

767+
pub fn visit_func(&self, func_name: &str, visitor: &mut impl FnMut(&FunctionCall<Index>)) {
768+
struct Visitor<'a, Index: ColumnIndex, F: FnMut(&FunctionCall<Index>)> {
769+
name: String,
770+
visitor: &'a mut F,
771+
_marker: std::marker::PhantomData<Index>,
772+
}
773+
774+
impl<'a, Index: ColumnIndex, F> ExprVisitor<Index> for Visitor<'a, Index, F>
775+
where F: FnMut(&FunctionCall<Index>)
776+
{
777+
fn enter_function_call(
778+
&mut self,
779+
call: &FunctionCall<Index>,
780+
) -> Result<Option<Expr<Index>>, Self::Error> {
781+
if call.function.signature.name == self.name {
782+
(self.visitor)(call);
783+
}
784+
Self::visit_function_call(call, self)
785+
}
786+
}
787+
788+
visit_expr(self, &mut Visitor {
789+
name: func_name.to_string(),
790+
visitor,
791+
_marker: std::marker::PhantomData,
792+
})
793+
.unwrap();
794+
}
795+
796+
pub fn find_function_literals(
797+
&self,
798+
func_name: &str,
799+
// column name, constant scalar, column is left
800+
visitor: &mut impl FnMut(&Index, &Scalar, bool),
801+
) {
802+
struct Visitor<'a, Index: ColumnIndex, F: FnMut(&Index, &Scalar, bool)> {
803+
name: String,
804+
visitor: &'a mut F,
805+
_marker: std::marker::PhantomData<Index>,
806+
}
807+
808+
impl<'a, Index: ColumnIndex, F> ExprVisitor<Index> for Visitor<'a, Index, F>
809+
where F: FnMut(&Index, &Scalar, bool)
810+
{
811+
fn enter_function_call(
812+
&mut self,
813+
call: &FunctionCall<Index>,
814+
) -> Result<Option<Expr<Index>>, Self::Error> {
815+
if call.function.signature.name == self.name {
816+
match call.args.as_slice() {
817+
[Expr::ColumnRef(ColumnRef { id, .. }), Expr::Constant(Constant { scalar, .. })] =>
818+
{
819+
(self.visitor)(id, scalar, true);
820+
}
821+
[Expr::Constant(Constant { scalar, .. }), Expr::ColumnRef(ColumnRef { id, .. })] =>
822+
{
823+
(self.visitor)(id, scalar, false);
824+
}
825+
_ => {}
826+
}
827+
}
828+
Self::visit_function_call(call, self)
829+
}
830+
}
831+
832+
visit_expr(self, &mut Visitor {
833+
name: func_name.to_string(),
834+
visitor,
835+
_marker: std::marker::PhantomData,
836+
})
837+
.unwrap();
838+
}
839+
840+
// replace function call with constant scalar
841+
pub fn replace_function_literals(
842+
&self,
843+
func_name: &str,
844+
visitor: &mut impl FnMut(&Index, &Scalar, &FunctionCall<Index>) -> Option<Expr<Index>>,
845+
) -> Expr<Index> {
846+
struct Visitor<
847+
'a,
848+
Index: ColumnIndex,
849+
F: FnMut(&Index, &Scalar, &FunctionCall<Index>) -> Option<Expr<Index>>,
850+
> {
851+
name: String,
852+
visitor: &'a mut F,
853+
_marker: std::marker::PhantomData<Index>,
854+
}
855+
856+
impl<'a, Index: ColumnIndex, F> ExprVisitor<Index> for Visitor<'a, Index, F>
857+
where F: FnMut(&Index, &Scalar, &FunctionCall<Index>) -> Option<Expr<Index>>
858+
{
859+
fn enter_function_call(
860+
&mut self,
861+
call: &FunctionCall<Index>,
862+
) -> Result<Option<Expr<Index>>, Self::Error> {
863+
if call.function.signature.name == self.name {
864+
match call.args.as_slice() {
865+
[Expr::ColumnRef(ColumnRef { id, .. }), Expr::Constant(Constant { scalar, .. })] =>
866+
{
867+
return Ok((self.visitor)(id, scalar, call));
868+
}
869+
[Expr::Constant(Constant { scalar, .. }), Expr::ColumnRef(ColumnRef { id, .. })] =>
870+
{
871+
return Ok((self.visitor)(id, scalar, call));
872+
}
873+
_ => {}
874+
}
875+
}
876+
Self::visit_function_call(call, self)
877+
}
878+
}
879+
880+
let res = visit_expr(self, &mut Visitor {
881+
name: func_name.to_string(),
882+
visitor,
883+
_marker: std::marker::PhantomData,
884+
})
885+
.unwrap();
886+
887+
res.unwrap_or_else(|| self.clone())
888+
}
889+
762890
pub fn as_remote_expr(&self) -> RemoteExpr<Index> {
763891
match self {
764892
Expr::Constant(Constant {

src/query/expression/src/type_check.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,40 @@ pub fn wrap_nullable_for_try_cast(span: Span, ty: &DataType) -> Result<DataType>
213213
}
214214
}
215215

216+
pub fn check_string<Index: ColumnIndex>(
217+
span: Span,
218+
func_ctx: &FunctionContext,
219+
expr: &Expr<Index>,
220+
fn_registry: &FunctionRegistry,
221+
) -> Result<String> {
222+
let origin_ty = expr.data_type();
223+
let (expr, _) = if origin_ty != &DataType::String {
224+
ConstantFolder::fold(
225+
&Expr::Cast(Cast {
226+
span,
227+
is_try: false,
228+
expr: Box::new(expr.clone()),
229+
dest_type: DataType::String,
230+
}),
231+
func_ctx,
232+
fn_registry,
233+
)
234+
} else {
235+
ConstantFolder::fold(expr, func_ctx, fn_registry)
236+
};
237+
238+
match expr {
239+
Expr::Constant(Constant {
240+
scalar: Scalar::String(string),
241+
..
242+
}) => Ok(string.clone()),
243+
_ => Err(
244+
ErrorCode::from_string_no_backtrace("expected string literal".to_string())
245+
.set_span(span),
246+
),
247+
}
248+
}
249+
216250
pub fn check_number<T: Number, Index: ColumnIndex>(
217251
span: Span,
218252
func_ctx: &FunctionContext,

src/query/expression/src/utils/filter_helper.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,19 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashSet;
16+
1517
use databend_common_column::bitmap::MutableBitmap;
1618

1719
use crate::arrow::bitmap_into_mut;
1820
use crate::types::BooleanType;
21+
use crate::ColumnIndex;
22+
use crate::Constant;
23+
use crate::ConstantFolder;
24+
use crate::Expr;
25+
use crate::FunctionContext;
26+
use crate::FunctionRegistry;
27+
use crate::Scalar;
1928
use crate::Value;
2029

2130
pub struct FilterHelpers;
@@ -36,4 +45,87 @@ impl FilterHelpers {
3645
Value::Column(bitmap) => bitmap_into_mut(bitmap),
3746
}
3847
}
48+
49+
pub fn find_leveled_eq_filters<I: ColumnIndex>(
50+
expr: &Expr<I>,
51+
level_names: &[&str],
52+
func_ctx: &FunctionContext,
53+
fn_registry: &FunctionRegistry,
54+
) -> databend_common_exception::Result<Vec<Vec<Scalar>>> {
55+
let mut scalars = vec![];
56+
for name in level_names {
57+
let mut values = Vec::new();
58+
expr.find_function_literals("eq", &mut |col_name, scalar, _| {
59+
if col_name.name() == *name {
60+
values.push(scalar.clone());
61+
}
62+
});
63+
values.dedup();
64+
65+
let mut results = Vec::with_capacity(values.len());
66+
67+
if !values.is_empty() {
68+
for value in values.iter() {
69+
// replace eq with false
70+
let expr =
71+
expr.replace_function_literals("eq", &mut |col_name, scalar, func| {
72+
if col_name.name() == *name {
73+
if scalar == value {
74+
let data_type = func.function.signature.return_type.clone();
75+
Some(Expr::Constant(Constant {
76+
span: None,
77+
scalar: Scalar::Boolean(false),
78+
data_type,
79+
}))
80+
} else {
81+
// for other values, we just ignore it
82+
None
83+
}
84+
} else {
85+
// for other columns, we just ignore it
86+
None
87+
}
88+
});
89+
90+
let (folded_expr, _) = ConstantFolder::fold(&expr, func_ctx, fn_registry);
91+
92+
if let Expr::Constant(Constant {
93+
scalar: Scalar::Boolean(false),
94+
..
95+
}) = folded_expr
96+
{
97+
results.push(value.clone());
98+
}
99+
}
100+
101+
// let's check if it contains or for other columns
102+
if results.is_empty() && !values.is_empty() {
103+
let mut results_all_used = true;
104+
// let's check or function that
105+
// for the equality columns set,let's call `ecs`
106+
// if any side of `or` of the name columns is not in `ecs`, it's valid
107+
// otherwise, it's invalid
108+
expr.visit_func("or", &mut |call| {
109+
for arg in call.args.iter() {
110+
let mut ecs = HashSet::new();
111+
arg.find_function_literals("eq", &mut |col_name, _scalar, _| {
112+
ecs.insert(col_name.name());
113+
});
114+
115+
if !ecs.contains(*name) {
116+
results_all_used = false;
117+
}
118+
}
119+
});
120+
if results_all_used {
121+
results = values;
122+
}
123+
}
124+
scalars.push(results);
125+
} else {
126+
scalars.push(vec![]);
127+
}
128+
}
129+
Ok(scalars)
130+
}
39131
}

0 commit comments

Comments
 (0)