Skip to content

Commit f0075de

Browse files
authored
refactor: use scalarvisitor replace predicate recursive (#15449)
1 parent ef64350 commit f0075de

File tree

1 file changed

+20
-164
lines changed

1 file changed

+20
-164
lines changed

src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs

Lines changed: 20 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,13 @@ use crate::optimizer::rule::RuleID;
2222
use crate::optimizer::rule::TransformResult;
2323
use crate::optimizer::RelExpr;
2424
use crate::optimizer::SExpr;
25-
use crate::plans::AggregateFunction;
26-
use crate::plans::CastExpr;
25+
use crate::plans::walk_expr_mut;
2726
use crate::plans::EvalScalar;
2827
use crate::plans::Filter;
29-
use crate::plans::FunctionCall;
30-
use crate::plans::LagLeadFunction;
31-
use crate::plans::LambdaFunc;
32-
use crate::plans::NthValueFunction;
3328
use crate::plans::RelOp;
3429
use crate::plans::ScalarExpr;
3530
use crate::plans::ScalarItem;
36-
use crate::plans::UDFCall;
37-
use crate::plans::WindowFunc;
38-
use crate::plans::WindowFuncType;
39-
use crate::plans::WindowOrderBy;
31+
use crate::plans::VisitorMut;
4032

4133
pub struct RulePushDownFilterEvalScalar {
4234
id: RuleID,
@@ -64,164 +56,28 @@ impl RulePushDownFilterEvalScalar {
6456

6557
// Replace predicate with children scalar items
6658
fn replace_predicate(predicate: &ScalarExpr, items: &[ScalarItem]) -> Result<ScalarExpr> {
67-
match predicate {
68-
ScalarExpr::BoundColumnRef(column) => {
69-
for item in items {
70-
if item.index == column.column.index {
71-
return Ok(item.scalar.clone());
72-
}
73-
}
74-
Ok(predicate.clone())
75-
}
76-
ScalarExpr::WindowFunction(window) => {
77-
let func = match &window.func {
78-
WindowFuncType::Aggregate(agg) => {
79-
let args = agg
80-
.args
81-
.iter()
82-
.map(|arg| Self::replace_predicate(arg, items))
83-
.collect::<Result<Vec<ScalarExpr>>>()?;
84-
85-
WindowFuncType::Aggregate(AggregateFunction {
86-
func_name: agg.func_name.clone(),
87-
distinct: agg.distinct,
88-
params: agg.params.clone(),
89-
args,
90-
return_type: agg.return_type.clone(),
91-
display_name: agg.display_name.clone(),
92-
})
93-
}
94-
WindowFuncType::LagLead(ll) => {
95-
let new_arg = Self::replace_predicate(&ll.arg, items)?;
96-
let new_default = match ll
97-
.default
98-
.clone()
99-
.map(|d| Self::replace_predicate(&d, items))
100-
{
101-
None => None,
102-
Some(d) => Some(Box::new(d?)),
103-
};
104-
WindowFuncType::LagLead(LagLeadFunction {
105-
is_lag: ll.is_lag,
106-
arg: Box::new(new_arg),
107-
offset: ll.offset,
108-
default: new_default,
109-
return_type: ll.return_type.clone(),
110-
})
111-
}
112-
WindowFuncType::NthValue(func) => {
113-
let new_arg = Self::replace_predicate(&func.arg, items)?;
114-
WindowFuncType::NthValue(NthValueFunction {
115-
n: func.n,
116-
arg: Box::new(new_arg),
117-
return_type: func.return_type.clone(),
118-
})
59+
struct PredicateVisitor<'a> {
60+
items: &'a [ScalarItem],
61+
}
62+
impl<'a> VisitorMut<'a> for PredicateVisitor<'a> {
63+
fn visit(&mut self, expr: &'a mut ScalarExpr) -> Result<()> {
64+
if let ScalarExpr::BoundColumnRef(column) = expr {
65+
for item in self.items {
66+
if item.index == column.column.index {
67+
*expr = item.scalar.clone();
68+
return Ok(());
69+
}
11970
}
120-
func => func.clone(),
71+
return Ok(());
12172
};
122-
123-
let partition_by = window
124-
.partition_by
125-
.iter()
126-
.map(|arg| Self::replace_predicate(arg, items))
127-
.collect::<Result<Vec<ScalarExpr>>>()?;
128-
129-
let order_by = window
130-
.order_by
131-
.iter()
132-
.map(|arg| {
133-
Ok(WindowOrderBy {
134-
asc: arg.asc,
135-
nulls_first: arg.nulls_first,
136-
expr: Self::replace_predicate(&arg.expr, items)?,
137-
})
138-
})
139-
.collect::<Result<Vec<_>>>()?;
140-
141-
Ok(ScalarExpr::WindowFunction(WindowFunc {
142-
span: window.span,
143-
display_name: window.display_name.clone(),
144-
func,
145-
partition_by,
146-
order_by,
147-
frame: window.frame.clone(),
148-
}))
73+
walk_expr_mut(self, expr)
14974
}
150-
ScalarExpr::AggregateFunction(agg_func) => {
151-
let args = agg_func
152-
.args
153-
.iter()
154-
.map(|arg| Self::replace_predicate(arg, items))
155-
.collect::<Result<Vec<ScalarExpr>>>()?;
156-
157-
Ok(ScalarExpr::AggregateFunction(AggregateFunction {
158-
func_name: agg_func.func_name.clone(),
159-
distinct: agg_func.distinct,
160-
params: agg_func.params.clone(),
161-
args,
162-
return_type: agg_func.return_type.clone(),
163-
display_name: agg_func.display_name.clone(),
164-
}))
165-
}
166-
ScalarExpr::FunctionCall(func) => {
167-
let arguments = func
168-
.arguments
169-
.iter()
170-
.map(|arg| Self::replace_predicate(arg, items))
171-
.collect::<Result<Vec<ScalarExpr>>>()?;
172-
173-
Ok(ScalarExpr::FunctionCall(FunctionCall {
174-
span: func.span,
175-
params: func.params.clone(),
176-
arguments,
177-
func_name: func.func_name.clone(),
178-
}))
179-
}
180-
ScalarExpr::LambdaFunction(lambda_func) => {
181-
let args = lambda_func
182-
.args
183-
.iter()
184-
.map(|arg| Self::replace_predicate(arg, items))
185-
.collect::<Result<Vec<ScalarExpr>>>()?;
186-
187-
Ok(ScalarExpr::LambdaFunction(LambdaFunc {
188-
span: lambda_func.span,
189-
func_name: lambda_func.func_name.clone(),
190-
args,
191-
lambda_expr: lambda_func.lambda_expr.clone(),
192-
lambda_display: lambda_func.lambda_display.clone(),
193-
return_type: lambda_func.return_type.clone(),
194-
}))
195-
}
196-
ScalarExpr::CastExpr(cast) => {
197-
let arg = Self::replace_predicate(&cast.argument, items)?;
198-
Ok(ScalarExpr::CastExpr(CastExpr {
199-
span: cast.span,
200-
is_try: cast.is_try,
201-
argument: Box::new(arg),
202-
target_type: cast.target_type.clone(),
203-
}))
204-
}
205-
ScalarExpr::UDFCall(udf) => {
206-
let arguments = udf
207-
.arguments
208-
.iter()
209-
.map(|arg| Self::replace_predicate(arg, items))
210-
.collect::<Result<Vec<ScalarExpr>>>()?;
211-
212-
Ok(ScalarExpr::UDFCall(UDFCall {
213-
span: udf.span,
214-
name: udf.name.clone(),
215-
func_name: udf.func_name.clone(),
216-
display_name: udf.display_name.clone(),
217-
udf_type: udf.udf_type.clone(),
218-
arg_types: udf.arg_types.clone(),
219-
return_type: udf.return_type.clone(),
220-
arguments,
221-
}))
222-
}
223-
_ => Ok(predicate.clone()),
22475
}
76+
77+
let mut visitor = PredicateVisitor { items };
78+
let mut predicate = predicate.clone();
79+
visitor.visit(&mut predicate)?;
80+
Ok(predicate.clone())
22581
}
22682
}
22783

0 commit comments

Comments
 (0)