@@ -169,16 +169,63 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
169
169
e.exists(PythonUDF .isScalarPythonUDF)
170
170
}
171
171
172
+ /**
173
+ * Return true if we should extract the current expression, including all of its current
174
+ * children (including UDF expression, and all others), to a logical node.
175
+ * The children of the expression can be UDF expressions, this would be nested chaining.
176
+ * If child UDF expressions were already extracted before, then this will just extract
177
+ * the current UDF expression, so they will end up in separate logical nodes. The child
178
+ * expressions will have been transformed to Attribute expressions referencing the child plan
179
+ * node's output.
180
+ *
181
+ * Return false if there is no single continuous chain of UDFs that can be extracted:
182
+ * - if there are other expressions in between, return false. In
183
+ * below example, the caller will have to extract bar(baz()) separately first:
184
+ * Query: foo(1 + bar(baz()))
185
+ * Plan:
186
+ * - PythonUDF (foo)
187
+ * - Project
188
+ * - PythonUDF (bar)
189
+ * - PythonUDF (baz)
190
+ * - if the eval types of the UDF expressions in the chain differ, return false.
191
+ * - if a UDF has more than one child, e.g. foo(bar(), baz()), return false
192
+ * If we return false here, the expectation is that the recursive calls of
193
+ * collectEvaluableUDFsFromExpressions will then visit the children and extract them first to
194
+ * separate nodes.
195
+ */
172
196
@ scala.annotation.tailrec
173
- private def canEvaluateInPython (e : PythonUDF ): Boolean = {
197
+ private def shouldExtractUDFExpressionTree (e : PythonUDF ): Boolean = {
174
198
e.children match {
175
- // single PythonUDF child could be chained and evaluated in Python
176
- case Seq ( u : PythonUDF ) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u )
199
+ case Seq ( child : PythonUDF ) => correctEvalType(e) == correctEvalType(child) &&
200
+ shouldExtractUDFExpressionTree(child )
177
201
// Python UDF can't be evaluated directly in JVM
178
202
case children => ! children.exists(hasScalarPythonUDF)
179
203
}
180
204
}
181
205
206
+ /**
207
+ * We use the following terminology:
208
+ * - chaining is the act of combining multiple UDFs into a single logical node. This can be
209
+ * accomplished in different cases, for example:
210
+ * - parallel chaining: if the UDFs are siblings, e.g., foo(x), bar(x),
211
+ * where multiple independent UDFs are evaluated together over the same input
212
+ * - nested chaining: if the UDFs are nested, e.g., foo(bar(...)),
213
+ * where the output of one UDF feeds into the next in a sequential pipeline
214
+ *
215
+ * collectEvaluableUDFsFromExpressions returns a list of UDF expressions that can be planned
216
+ * together into one plan node. collectEvaluableUDFsFromExpressions will be called multiple times
217
+ * by recursive calls of extract(plan), until no more evaluable UDFs are found.
218
+ *
219
+ * As an example, consider the following expression tree:
220
+ * udf1(udf2(udf3(x), udf4(x))), where all UDFs are PythonUDFs of the same eval type.
221
+ * We can only fuse UDFs of the same eval type, and never UDFs of SQL_SCALAR_PANDAS_ITER_UDF or SQL_SCALAR_ARROW_ITER_UDF.
222
+ * The following udf expressions will be returned:
223
+ * - First, we will return Seq(udf3, udf4), as these two UDFs must be evaluated first.
224
+ * We return both in one Seq, as it is possible to do parallel chaining of udf3 and udf4.
225
+ * - As we can only chain UDFs with exactly one child, we will not fuse udf2 with its children.
226
+ * But we can chain udf1 and udf2, so a later call to collectEvaluableUDFsFromExpressions will
227
+ * return Seq(udf1, udf2).
228
+ */
182
229
private def collectEvaluableUDFsFromExpressions (expressions : Seq [Expression ]): Seq [PythonUDF ] = {
183
230
// If first UDF is SQL_SCALAR_PANDAS_ITER_UDF or SQL_SCALAR_ARROW_ITER_UDF,
184
231
// then only return this UDF,
@@ -187,7 +234,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
187
234
188
235
var firstVisitedScalarUDFEvalType : Option [Int ] = None
189
236
190
- def canChainUDF (evalType : Int ): Boolean = {
237
+ def canChainWithParallelUDFs (evalType : Int ): Boolean = {
191
238
if (evalType == PythonEvalType .SQL_SCALAR_PANDAS_ITER_UDF ||
192
239
evalType == PythonEvalType .SQL_SCALAR_ARROW_ITER_UDF ) {
193
240
false
@@ -197,12 +244,14 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
197
244
}
198
245
199
246
def collectEvaluableUDFs (expr : Expression ): Seq [PythonUDF ] = expr match {
200
- case udf : PythonUDF if PythonUDF .isScalarPythonUDF(udf) && canEvaluateInPython(udf)
247
+ case udf : PythonUDF if PythonUDF .isScalarPythonUDF(udf)
248
+ && shouldExtractUDFExpressionTree(udf)
201
249
&& firstVisitedScalarUDFEvalType.isEmpty =>
202
250
firstVisitedScalarUDFEvalType = Some (correctEvalType(udf))
203
251
Seq (udf)
204
- case udf : PythonUDF if PythonUDF .isScalarPythonUDF(udf) && canEvaluateInPython(udf)
205
- && canChainUDF(correctEvalType(udf)) =>
252
+ case udf : PythonUDF if PythonUDF .isScalarPythonUDF(udf)
253
+ && shouldExtractUDFExpressionTree(udf)
254
+ && canChainWithParallelUDFs(correctEvalType(udf)) =>
206
255
Seq (udf)
207
256
case e => e.children.flatMap(collectEvaluableUDFs)
208
257
}
0 commit comments