Skip to content

Commit 692f1b6

Browse files
committed
[SPARK-52278][PYTHON] Scalar Arrow UDF support named arguments
### What changes were proposed in this pull request?
Scalar Arrow UDF support named arguments

### Why are the changes needed?
For feature parity with pandas UDF.

### Does this PR introduce _any_ user-facing change?
No, Arrow UDF is not public now.

### How was this patch tested?
Added tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #50996 from zhengruifeng/py_arrow_udf_test_named_args.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent b8e6dc2 commit 692f1b6

File tree

1 file changed

+108
-2
lines changed

1 file changed

+108
-2
lines changed

python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
MapType,
4545
BinaryType,
4646
)
47-
from pyspark.errors import AnalysisException
47+
from pyspark.errors import AnalysisException, PythonException
4848
from pyspark.testing.sqlutils import (
4949
ReusedSQLTestCase,
5050
have_pyarrow,
@@ -726,7 +726,113 @@ def scalar_f(id):
726726
res = df.select(scalar_g(scalar_f(F.col("id"))).alias("res"))
727727
self.assertEqual(expected, res.collect())
728728

729-
# TODO: add tests for named arguments
729+
def test_arrow_udf_named_arguments(self):
    """Scalar Arrow UDFs accept keyword arguments, both through the
    DataFrame API and through SQL's ``name => value`` syntax."""
    import pyarrow as pa

    @arrow_udf("int")
    def test_udf(a, b):
        # Computes a + b * 10, cast back to the declared int return type.
        return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())

    # Start from a clean slate, then register so SQL can resolve the name.
    self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
    self.spark.udf.register("test_udf", test_udf)

    # The same computation expressed six ways: mixed positional/keyword,
    # all-keyword, keyword out of order — each in DataFrame and SQL form.
    queries = [
        self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
        self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
        self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
        self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
        self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
        self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
    ]

    expected = [Row(0), Row(101)]
    for query_no, result_df in enumerate(queries):
        with self.subTest(query_no=query_no):
            self.assertEqual(expected, result_df.collect())
752+
753+
def test_arrow_udf_named_arguments_negative(self):
    """Invalid keyword-argument usage of a scalar Arrow UDF raises the
    expected analysis-time or Python-worker errors."""
    import pyarrow as pa

    @arrow_udf("int")
    def test_udf(a, b):
        return pa.compute.add(a, b).cast(pa.int32())

    # Register under a fresh name so the SQL queries below resolve it.
    self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
    self.spark.udf.register("test_udf", test_udf)

    # Assigning the same parameter twice is rejected during analysis.
    with self.assertRaisesRegex(
        AnalysisException,
        "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
    ):
        self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

    # A positional argument may not follow a named one.
    with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
        self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

    # An unknown keyword surfaces as a Python-side TypeError from the worker.
    with self.assertRaisesRegex(
        PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
    ):
        self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
776+
777+
def test_arrow_udf_named_arguments_and_defaults(self):
    """A scalar Arrow UDF parameter with a Python default value may be
    omitted or supplied, positionally or by name, in DataFrame and SQL."""
    import pyarrow as pa

    @arrow_udf("int")
    def test_udf(a, b=0):
        # Computes a + b * 10; with the default b=0 this is just `a`.
        return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())

    self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
    self.spark.udf.register("test_udf", test_udf)

    # Case 1: "b" omitted, so the default b=0 applies and the result is `a`.
    queries_without_b = [
        self.spark.range(2).select(test_udf(F.col("id"))),
        self.spark.range(2).select(test_udf(a=F.col("id"))),
        self.spark.sql("SELECT test_udf(id) FROM range(2)"),
        self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
    ]
    expected = [Row(0), Row(1)]
    for query_no, result_df in enumerate(queries_without_b):
        with self.subTest(with_b=False, query_no=query_no):
            self.assertEqual(expected, result_df.collect())

    # Case 2: "b" supplied explicitly, overriding the default.
    queries_with_b = [
        self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
        self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
        self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
        self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
        self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
        self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
    ]
    expected = [Row(0), Row(101)]
    for query_no, result_df in enumerate(queries_with_b):
        with self.subTest(with_b=True, query_no=query_no):
            self.assertEqual(expected, result_df.collect())
814+
815+
def test_arrow_udf_kwargs(self):
    """Named arguments reach a scalar Arrow UDF declared with ``**kwargs``,
    regardless of argument order, via DataFrame API and SQL."""
    import pyarrow as pa

    @arrow_udf("int")
    def test_udf(a, **kwargs):
        # "b" is delivered through the kwargs dict; result is a + b * 10.
        return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 10)).cast(pa.int32())

    self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
    self.spark.udf.register("test_udf", test_udf)

    # Only keyword invocations are meaningful here, since "b" has no
    # positional slot — cover both orderings in both APIs.
    queries = [
        self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
        self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
        self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
        self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
    ]

    expected = [Row(0), Row(101)]
    for query_no, result_df in enumerate(queries):
        with self.subTest(query_no=query_no):
            self.assertEqual(expected, result_df.collect())
730836

731837

732838
class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):

0 commit comments

Comments
 (0)