@@ -22,21 +22,13 @@ use crate::optimizer::rule::RuleID;
22
22
use crate :: optimizer:: rule:: TransformResult ;
23
23
use crate :: optimizer:: RelExpr ;
24
24
use crate :: optimizer:: SExpr ;
25
- use crate :: plans:: AggregateFunction ;
26
- use crate :: plans:: CastExpr ;
25
+ use crate :: plans:: walk_expr_mut;
27
26
use crate :: plans:: EvalScalar ;
28
27
use crate :: plans:: Filter ;
29
- use crate :: plans:: FunctionCall ;
30
- use crate :: plans:: LagLeadFunction ;
31
- use crate :: plans:: LambdaFunc ;
32
- use crate :: plans:: NthValueFunction ;
33
28
use crate :: plans:: RelOp ;
34
29
use crate :: plans:: ScalarExpr ;
35
30
use crate :: plans:: ScalarItem ;
36
- use crate :: plans:: UDFCall ;
37
- use crate :: plans:: WindowFunc ;
38
- use crate :: plans:: WindowFuncType ;
39
- use crate :: plans:: WindowOrderBy ;
31
+ use crate :: plans:: VisitorMut ;
40
32
41
33
pub struct RulePushDownFilterEvalScalar {
42
34
id : RuleID ,
@@ -64,164 +56,28 @@ impl RulePushDownFilterEvalScalar {
64
56
65
57
// Replace predicate with children scalar items
66
58
fn replace_predicate ( predicate : & ScalarExpr , items : & [ ScalarItem ] ) -> Result < ScalarExpr > {
67
- match predicate {
68
- ScalarExpr :: BoundColumnRef ( column) => {
69
- for item in items {
70
- if item. index == column. column . index {
71
- return Ok ( item. scalar . clone ( ) ) ;
72
- }
73
- }
74
- Ok ( predicate. clone ( ) )
75
- }
76
- ScalarExpr :: WindowFunction ( window) => {
77
- let func = match & window. func {
78
- WindowFuncType :: Aggregate ( agg) => {
79
- let args = agg
80
- . args
81
- . iter ( )
82
- . map ( |arg| Self :: replace_predicate ( arg, items) )
83
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
84
-
85
- WindowFuncType :: Aggregate ( AggregateFunction {
86
- func_name : agg. func_name . clone ( ) ,
87
- distinct : agg. distinct ,
88
- params : agg. params . clone ( ) ,
89
- args,
90
- return_type : agg. return_type . clone ( ) ,
91
- display_name : agg. display_name . clone ( ) ,
92
- } )
93
- }
94
- WindowFuncType :: LagLead ( ll) => {
95
- let new_arg = Self :: replace_predicate ( & ll. arg , items) ?;
96
- let new_default = match ll
97
- . default
98
- . clone ( )
99
- . map ( |d| Self :: replace_predicate ( & d, items) )
100
- {
101
- None => None ,
102
- Some ( d) => Some ( Box :: new ( d?) ) ,
103
- } ;
104
- WindowFuncType :: LagLead ( LagLeadFunction {
105
- is_lag : ll. is_lag ,
106
- arg : Box :: new ( new_arg) ,
107
- offset : ll. offset ,
108
- default : new_default,
109
- return_type : ll. return_type . clone ( ) ,
110
- } )
111
- }
112
- WindowFuncType :: NthValue ( func) => {
113
- let new_arg = Self :: replace_predicate ( & func. arg , items) ?;
114
- WindowFuncType :: NthValue ( NthValueFunction {
115
- n : func. n ,
116
- arg : Box :: new ( new_arg) ,
117
- return_type : func. return_type . clone ( ) ,
118
- } )
59
+ struct PredicateVisitor < ' a > {
60
+ items : & ' a [ ScalarItem ] ,
61
+ }
62
+ impl < ' a > VisitorMut < ' a > for PredicateVisitor < ' a > {
63
+ fn visit ( & mut self , expr : & ' a mut ScalarExpr ) -> Result < ( ) > {
64
+ if let ScalarExpr :: BoundColumnRef ( column) = expr {
65
+ for item in self . items {
66
+ if item. index == column. column . index {
67
+ * expr = item. scalar . clone ( ) ;
68
+ return Ok ( ( ) ) ;
69
+ }
119
70
}
120
- func => func . clone ( ) ,
71
+ return Ok ( ( ) ) ;
121
72
} ;
122
-
123
- let partition_by = window
124
- . partition_by
125
- . iter ( )
126
- . map ( |arg| Self :: replace_predicate ( arg, items) )
127
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
128
-
129
- let order_by = window
130
- . order_by
131
- . iter ( )
132
- . map ( |arg| {
133
- Ok ( WindowOrderBy {
134
- asc : arg. asc ,
135
- nulls_first : arg. nulls_first ,
136
- expr : Self :: replace_predicate ( & arg. expr , items) ?,
137
- } )
138
- } )
139
- . collect :: < Result < Vec < _ > > > ( ) ?;
140
-
141
- Ok ( ScalarExpr :: WindowFunction ( WindowFunc {
142
- span : window. span ,
143
- display_name : window. display_name . clone ( ) ,
144
- func,
145
- partition_by,
146
- order_by,
147
- frame : window. frame . clone ( ) ,
148
- } ) )
73
+ walk_expr_mut ( self , expr)
149
74
}
150
- ScalarExpr :: AggregateFunction ( agg_func) => {
151
- let args = agg_func
152
- . args
153
- . iter ( )
154
- . map ( |arg| Self :: replace_predicate ( arg, items) )
155
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
156
-
157
- Ok ( ScalarExpr :: AggregateFunction ( AggregateFunction {
158
- func_name : agg_func. func_name . clone ( ) ,
159
- distinct : agg_func. distinct ,
160
- params : agg_func. params . clone ( ) ,
161
- args,
162
- return_type : agg_func. return_type . clone ( ) ,
163
- display_name : agg_func. display_name . clone ( ) ,
164
- } ) )
165
- }
166
- ScalarExpr :: FunctionCall ( func) => {
167
- let arguments = func
168
- . arguments
169
- . iter ( )
170
- . map ( |arg| Self :: replace_predicate ( arg, items) )
171
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
172
-
173
- Ok ( ScalarExpr :: FunctionCall ( FunctionCall {
174
- span : func. span ,
175
- params : func. params . clone ( ) ,
176
- arguments,
177
- func_name : func. func_name . clone ( ) ,
178
- } ) )
179
- }
180
- ScalarExpr :: LambdaFunction ( lambda_func) => {
181
- let args = lambda_func
182
- . args
183
- . iter ( )
184
- . map ( |arg| Self :: replace_predicate ( arg, items) )
185
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
186
-
187
- Ok ( ScalarExpr :: LambdaFunction ( LambdaFunc {
188
- span : lambda_func. span ,
189
- func_name : lambda_func. func_name . clone ( ) ,
190
- args,
191
- lambda_expr : lambda_func. lambda_expr . clone ( ) ,
192
- lambda_display : lambda_func. lambda_display . clone ( ) ,
193
- return_type : lambda_func. return_type . clone ( ) ,
194
- } ) )
195
- }
196
- ScalarExpr :: CastExpr ( cast) => {
197
- let arg = Self :: replace_predicate ( & cast. argument , items) ?;
198
- Ok ( ScalarExpr :: CastExpr ( CastExpr {
199
- span : cast. span ,
200
- is_try : cast. is_try ,
201
- argument : Box :: new ( arg) ,
202
- target_type : cast. target_type . clone ( ) ,
203
- } ) )
204
- }
205
- ScalarExpr :: UDFCall ( udf) => {
206
- let arguments = udf
207
- . arguments
208
- . iter ( )
209
- . map ( |arg| Self :: replace_predicate ( arg, items) )
210
- . collect :: < Result < Vec < ScalarExpr > > > ( ) ?;
211
-
212
- Ok ( ScalarExpr :: UDFCall ( UDFCall {
213
- span : udf. span ,
214
- name : udf. name . clone ( ) ,
215
- func_name : udf. func_name . clone ( ) ,
216
- display_name : udf. display_name . clone ( ) ,
217
- udf_type : udf. udf_type . clone ( ) ,
218
- arg_types : udf. arg_types . clone ( ) ,
219
- return_type : udf. return_type . clone ( ) ,
220
- arguments,
221
- } ) )
222
- }
223
- _ => Ok ( predicate. clone ( ) ) ,
224
75
}
76
+
77
+ let mut visitor = PredicateVisitor { items } ;
78
+ let mut predicate = predicate. clone ( ) ;
79
+ visitor. visit ( & mut predicate) ?;
80
+ Ok ( predicate. clone ( ) )
225
81
}
226
82
}
227
83
0 commit comments