Skip to content

Commit 4b86268

Browse files
benhurdelheyHyukjinKwon
authored andcommitted
[SPARK-52082][PYTHON][DOCS] Improve ExtractPythonUDF docs
### What changes were proposed in this pull request? - renames two methods in ExtractPythonUDFs, and adds docstrings explaining the parallel fusing and chaining concepts ### Why are the changes needed? - in my experience, new developers find the planning code hard to understand without sufficient explanations. The current method naming is confusing, as the `canChainUDF` is actually used select eligibility to fuse parallel udf invocations like `udf1(), udf2()`. ### Does this PR introduce _any_ user-facing change? - No ### How was this patch tested? - Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #50867 from benrobby/SPARK-52082. Authored-by: Ben Hurdelhey <ben.hurdelhey@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 5a4932b commit 4b86268

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,63 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
169169
e.exists(PythonUDF.isScalarPythonUDF)
170170
}
171171

172+
/**
173+
* Return true if we should extract the current expression, including all of its current
174+
* children (including UDF expression, and all others), to a logical node.
175+
* The children of the expression can be UDF expressions, this would be nested chaining.
176+
* If child UDF expressions were already extracted before, then this will just extract
177+
* the current UDF expression, so they will end up in separate logical nodes. The child
178+
* expressions will have been transformed to Attribute expressions referencing the child plan
179+
* node's output.
180+
*
181+
* Return false if there is no single continuous chain of UDFs that can be extracted:
182+
* - if there are other expression in-between, return false. In
183+
* below example, the caller will have to extract bar(baz()) separately first:
184+
* Query: foo(1 + bar(baz()))
185+
* Plan:
186+
* - PythonUDF (foo)
187+
* - Project
188+
* - PythonUDF (bar)
189+
* - PythonUDF (baz)
190+
* - if the eval types of the UDF expressions in the chain differ, return false.
191+
* - if a UDF has more than one child, e.g. foo(bar(), baz()), return false
192+
* If we return false here, the expectation is that the recursive calls of
193+
* collectEvaluableUDFsFromExpressions will then visit the children and extract them first to
194+
* separate nodes.
195+
*/
172196
@scala.annotation.tailrec
173-
private def canEvaluateInPython(e: PythonUDF): Boolean = {
197+
private def shouldExtractUDFExpressionTree(e: PythonUDF): Boolean = {
174198
e.children match {
175-
// single PythonUDF child could be chained and evaluated in Python
176-
case Seq(u: PythonUDF) => correctEvalType(e) == correctEvalType(u) && canEvaluateInPython(u)
199+
case Seq(child: PythonUDF) => correctEvalType(e) == correctEvalType(child) &&
200+
shouldExtractUDFExpressionTree(child)
177201
// Python UDF can't be evaluated directly in JVM
178202
case children => !children.exists(hasScalarPythonUDF)
179203
}
180204
}
181205

206+
/**
207+
* We use the following terminology:
208+
* - chaining is the act of combining multiple UDFs into a single logical node. This can be
209+
* accomplished in different cases, for example:
210+
* - parallel chaining: if the UDFs are siblings, e.g., foo(x), bar(x),
211+
* where multiple independent UDFs are evaluated together over the same input
212+
* - nested chaining: if the UDFs are nested, e.g., foo(bar(...)),
213+
* where the output of one UDF feeds into the next in a sequential pipeline
214+
*
215+
* collectEvaluableUDFsFromExpressions returns a list of UDF expressions that can be planned
216+
* together into one plan node. collectEvaluableUDFsFromExpressions will be called multiple times
217+
* by recursive calls of extract(plan), until no more evaluable UDFs are found.
218+
*
219+
* As an example, consider the following expression tree:
220+
* udf1(udf2(udf3(x)), udf4(x))), where all UDFs are PythonUDFs of the same evaltype.
221+
* We can only fuse UDFs of the same eval type, and never UDFs of SQL_SCALAR_PANDAS_ITER_UDF.
222+
* The following udf expressions will be returned:
223+
* - First, we will return Seq(udf3, udf4), as these two UDFs must be evaluated first.
224+
* We return both in one Seq, as it is possible to do parallel fusing for udf3 an udf4.
225+
* - As we can only chain UDFs with exactly one child, we will not fuse udf2 with its children.
226+
* But we can chain udf1 and udf2, so a later call to collectEvaluableUDFsFromExpressions will
227+
* return Seq(udf1, udf2).
228+
*/
182229
private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
183230
// If first UDF is SQL_SCALAR_PANDAS_ITER_UDF or SQL_SCALAR_ARROW_ITER_UDF,
184231
// then only return this UDF,
@@ -187,7 +234,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
187234

188235
var firstVisitedScalarUDFEvalType: Option[Int] = None
189236

190-
def canChainUDF(evalType: Int): Boolean = {
237+
def canChainWithParallelUDFs(evalType: Int): Boolean = {
191238
if (evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF ||
192239
evalType == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF) {
193240
false
@@ -197,12 +244,14 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with Logging {
197244
}
198245

199246
def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
200-
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
247+
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
248+
&& shouldExtractUDFExpressionTree(udf)
201249
&& firstVisitedScalarUDFEvalType.isEmpty =>
202250
firstVisitedScalarUDFEvalType = Some(correctEvalType(udf))
203251
Seq(udf)
204-
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
205-
&& canChainUDF(correctEvalType(udf)) =>
252+
case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf)
253+
&& shouldExtractUDFExpressionTree(udf)
254+
&& canChainWithParallelUDFs(correctEvalType(udf)) =>
206255
Seq(udf)
207256
case e => e.children.flatMap(collectEvaluableUDFs)
208257
}

0 commit comments

Comments
 (0)