Skip to content

Commit b232ff1

Browse files
authored
refactor: ScalarVisitor refactor replace_predicate_column (#15439)
refactor: ScalarVisitor refactor replace_predicate_x
1 parent 3b1b239 commit b232ff1

File tree

1 file changed

+27
-233
lines changed

1 file changed

+27
-233
lines changed

src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs

Lines changed: 27 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,12 @@ use crate::optimizer::rule::Rule;
2222
use crate::optimizer::rule::TransformResult;
2323
use crate::optimizer::RuleID;
2424
use crate::optimizer::SExpr;
25-
use crate::plans::AggregateFunction;
2625
use crate::plans::BoundColumnRef;
27-
use crate::plans::CastExpr;
2826
use crate::plans::Filter;
29-
use crate::plans::FunctionCall;
30-
use crate::plans::LagLeadFunction;
31-
use crate::plans::LambdaFunc;
32-
use crate::plans::NthValueFunction;
3327
use crate::plans::RelOp;
3428
use crate::plans::Scan;
35-
use crate::plans::UDFCall;
36-
use crate::plans::WindowFunc;
37-
use crate::plans::WindowFuncType;
38-
use crate::plans::WindowOrderBy;
29+
use crate::plans::SubqueryExpr;
30+
use crate::plans::VisitorMut;
3931
use crate::ColumnEntry;
4032
use crate::MetadataRef;
4133
use crate::ScalarExpr;
@@ -75,10 +67,16 @@ impl RulePushDownFilterScan {
7567
column_entries: &[&ColumnEntry],
7668
replace_view: bool,
7769
) -> Result<ScalarExpr> {
78-
match predicate {
79-
ScalarExpr::BoundColumnRef(column) => {
70+
struct ReplacePredicateColumnVisitor<'a> {
71+
table_entries: &'a [TableEntry],
72+
column_entries: &'a [&'a ColumnEntry],
73+
replace_view: bool,
74+
}
75+
76+
impl<'a> VisitorMut<'a> for ReplacePredicateColumnVisitor<'a> {
77+
fn visit_bound_column_ref(&mut self, column: &mut BoundColumnRef) -> Result<()> {
8078
if let Some(base_column) =
81-
column_entries
79+
self.column_entries
8280
.iter()
8381
.find_map(|column_entry| match column_entry {
8482
ColumnEntry::BaseTableColumn(base_column)
@@ -89,7 +87,8 @@ impl RulePushDownFilterScan {
8987
_ => None,
9088
})
9189
{
92-
if let Some(table_entry) = table_entries
90+
if let Some(table_entry) = self
91+
.table_entries
9392
.iter()
9493
.find(|table_entry| table_entry.index() == base_column.table_index)
9594
{
@@ -103,236 +102,31 @@ impl RulePushDownFilterScan {
103102
.database_name(Some(table_entry.database().to_string()))
104103
.table_index(Some(table_entry.index()));
105104

106-
if replace_view {
105+
if self.replace_view {
107106
column_binding_builder = column_binding_builder
108107
.virtual_computed_expr(column.column.virtual_computed_expr.clone());
109108
}
110109

111-
let bound_column_ref = BoundColumnRef {
112-
span: column.span,
113-
column: column_binding_builder.build(),
114-
};
115-
return Ok(ScalarExpr::BoundColumnRef(bound_column_ref));
110+
column.column = column_binding_builder.build();
116111
}
117112
}
118-
Ok(predicate.clone())
113+
Ok(())
119114
}
120-
ScalarExpr::WindowFunction(window) => {
121-
let func = match &window.func {
122-
WindowFuncType::Aggregate(agg) => {
123-
let args = agg
124-
.args
125-
.iter()
126-
.map(|arg| {
127-
Self::replace_predicate_column(
128-
arg,
129-
table_entries,
130-
column_entries,
131-
replace_view,
132-
)
133-
})
134-
.collect::<Result<Vec<ScalarExpr>>>()?;
135115

136-
WindowFuncType::Aggregate(AggregateFunction {
137-
func_name: agg.func_name.clone(),
138-
distinct: agg.distinct,
139-
params: agg.params.clone(),
140-
args,
141-
return_type: agg.return_type.clone(),
142-
display_name: agg.display_name.clone(),
143-
})
144-
}
145-
WindowFuncType::LagLead(ll) => {
146-
let new_arg = Self::replace_predicate_column(
147-
&ll.arg,
148-
table_entries,
149-
column_entries,
150-
replace_view,
151-
)?;
152-
let new_default = match ll.default.clone().map(|d| {
153-
Self::replace_predicate_column(
154-
&d,
155-
table_entries,
156-
column_entries,
157-
replace_view,
158-
)
159-
}) {
160-
None => None,
161-
Some(d) => Some(Box::new(d?)),
162-
};
163-
WindowFuncType::LagLead(LagLeadFunction {
164-
is_lag: ll.is_lag,
165-
arg: Box::new(new_arg),
166-
offset: ll.offset,
167-
default: new_default,
168-
return_type: ll.return_type.clone(),
169-
})
170-
}
171-
WindowFuncType::NthValue(func) => {
172-
let new_arg = Self::replace_predicate_column(
173-
&func.arg,
174-
table_entries,
175-
column_entries,
176-
replace_view,
177-
)?;
178-
WindowFuncType::NthValue(NthValueFunction {
179-
n: func.n,
180-
arg: Box::new(new_arg),
181-
return_type: func.return_type.clone(),
182-
})
183-
}
184-
func => func.clone(),
185-
};
186-
187-
let partition_by = window
188-
.partition_by
189-
.iter()
190-
.map(|arg| {
191-
Self::replace_predicate_column(
192-
arg,
193-
table_entries,
194-
column_entries,
195-
replace_view,
196-
)
197-
})
198-
.collect::<Result<Vec<ScalarExpr>>>()?;
199-
200-
let order_by = window
201-
.order_by
202-
.iter()
203-
.map(|item| {
204-
let replaced_scalar = Self::replace_predicate_column(
205-
&item.expr,
206-
table_entries,
207-
column_entries,
208-
replace_view,
209-
)?;
210-
Ok(WindowOrderBy {
211-
expr: replaced_scalar,
212-
asc: item.asc,
213-
nulls_first: item.nulls_first,
214-
})
215-
})
216-
.collect::<Result<Vec<WindowOrderBy>>>()?;
217-
218-
Ok(ScalarExpr::WindowFunction(WindowFunc {
219-
span: window.span,
220-
display_name: window.display_name.clone(),
221-
func,
222-
partition_by,
223-
order_by,
224-
frame: window.frame.clone(),
225-
}))
116+
fn visit_subquery_expr(&mut self, _subquery: &'a mut SubqueryExpr) -> Result<()> {
117+
Ok(())
226118
}
227-
ScalarExpr::AggregateFunction(agg_func) => {
228-
let args = agg_func
229-
.args
230-
.iter()
231-
.map(|arg| {
232-
Self::replace_predicate_column(
233-
arg,
234-
table_entries,
235-
column_entries,
236-
replace_view,
237-
)
238-
})
239-
.collect::<Result<Vec<ScalarExpr>>>()?;
240-
241-
Ok(ScalarExpr::AggregateFunction(AggregateFunction {
242-
func_name: agg_func.func_name.clone(),
243-
distinct: agg_func.distinct,
244-
params: agg_func.params.clone(),
245-
args,
246-
return_type: agg_func.return_type.clone(),
247-
display_name: agg_func.display_name.clone(),
248-
}))
249-
}
250-
ScalarExpr::LambdaFunction(lambda_func) => {
251-
let args = lambda_func
252-
.args
253-
.iter()
254-
.map(|arg| {
255-
Self::replace_predicate_column(
256-
arg,
257-
table_entries,
258-
column_entries,
259-
replace_view,
260-
)
261-
})
262-
.collect::<Result<Vec<ScalarExpr>>>()?;
263-
264-
Ok(ScalarExpr::LambdaFunction(LambdaFunc {
265-
span: lambda_func.span,
266-
func_name: lambda_func.func_name.clone(),
267-
args,
268-
lambda_expr: lambda_func.lambda_expr.clone(),
269-
lambda_display: lambda_func.lambda_display.clone(),
270-
return_type: lambda_func.return_type.clone(),
271-
}))
272-
}
273-
ScalarExpr::FunctionCall(func) => {
274-
let arguments = func
275-
.arguments
276-
.iter()
277-
.map(|arg| {
278-
Self::replace_predicate_column(
279-
arg,
280-
table_entries,
281-
column_entries,
282-
replace_view,
283-
)
284-
})
285-
.collect::<Result<Vec<ScalarExpr>>>()?;
286-
287-
Ok(ScalarExpr::FunctionCall(FunctionCall {
288-
span: func.span,
289-
params: func.params.clone(),
290-
arguments,
291-
func_name: func.func_name.clone(),
292-
}))
293-
}
294-
ScalarExpr::CastExpr(cast) => {
295-
let arg = Self::replace_predicate_column(
296-
&cast.argument,
297-
table_entries,
298-
column_entries,
299-
replace_view,
300-
)?;
301-
Ok(ScalarExpr::CastExpr(CastExpr {
302-
span: cast.span,
303-
is_try: cast.is_try,
304-
argument: Box::new(arg),
305-
target_type: cast.target_type.clone(),
306-
}))
307-
}
308-
ScalarExpr::UDFCall(udf) => {
309-
let arguments = udf
310-
.arguments
311-
.iter()
312-
.map(|arg| {
313-
Self::replace_predicate_column(
314-
arg,
315-
table_entries,
316-
column_entries,
317-
replace_view,
318-
)
319-
})
320-
.collect::<Result<Vec<ScalarExpr>>>()?;
119+
}
321120

322-
Ok(ScalarExpr::UDFCall(UDFCall {
323-
span: udf.span,
324-
name: udf.name.clone(),
325-
func_name: udf.func_name.clone(),
326-
display_name: udf.display_name.clone(),
327-
udf_type: udf.udf_type.clone(),
328-
arg_types: udf.arg_types.clone(),
329-
return_type: udf.return_type.clone(),
330-
arguments,
331-
}))
332-
}
121+
let mut visitor = ReplacePredicateColumnVisitor {
122+
table_entries,
123+
column_entries,
124+
replace_view,
125+
};
126+
let mut predicate = predicate.clone();
127+
visitor.visit(&mut predicate)?;
333128

334-
_ => Ok(predicate.clone()),
335-
}
129+
Ok(predicate.clone())
336130
}
337131

338132
fn find_push_down_predicates(

0 commit comments

Comments
 (0)