Commit ebf39c4

nit
1 parent: 87b5d29

11 files changed: +280 additions, −34 deletions

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,7 @@ private[spark] object PythonEvalType {
 
   // Arrow UDFs
   val SQL_SCALAR_ARROW_UDF = 250
+  val SQL_SCALAR_ARROW_ITER_UDF = 251
 
   val SQL_TABLE_UDF = 300
   val SQL_ARROW_TABLE_UDF = 301
@@ -96,7 +97,10 @@ private[spark] object PythonEvalType {
     case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF => "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF"
     case SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
      "SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF"
+
+    // Arrow UDFs
     case SQL_SCALAR_ARROW_UDF => "SQL_SCALAR_ARROW_UDF"
+    case SQL_SCALAR_ARROW_ITER_UDF => "SQL_SCALAR_ARROW_ITER_UDF"
   }
 }
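
The new constant gives the iterator variant of the Arrow scalar UDF its own eval type (251) next to the existing one-batch variant (250). PySpark refers to the same values through PythonEvalType in the Python files below; a minimal sketch of that mirror, where only the values 250/251 are taken from this diff and the class layout is illustrative:

# Hypothetical Python-side mirror of the JVM constants above.
class PythonEvalType:
    SQL_SCALAR_ARROW_UDF = 250       # one Arrow batch in -> one Arrow batch out
    SQL_SCALAR_ARROW_ITER_UDF = 251  # iterator of Arrow batches in -> iterator out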

python/pyspark/sql/connect/udf.py

Lines changed: 1 addition & 0 deletions
@@ -278,6 +278,7 @@ def register(
                 PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                 PythonEvalType.SQL_SCALAR_ARROW_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+                PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             ]:
                 raise PySparkTypeError(
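
With SQL_SCALAR_ARROW_ITER_UDF on the allow-list, such a UDF can be registered for SQL use over Spark Connect. A usage sketch, assuming an active SparkSession named spark and that arrow_udf is imported from the module this commit touches (the public import path may differ):

from typing import Iterator

import pyarrow as pa
import pyarrow.compute as pc

from pyspark.sql.pandas.functions import arrow_udf

@arrow_udf("long")
def plus_one(batches: Iterator[pa.Array]) -> Iterator[pa.Array]:
    # Each element of `batches` is one Arrow batch of the input column.
    for batch in batches:
        yield pc.add(batch, 1)

spark.udf.register("plus_one", plus_one)  # accepted thanks to the new eval type
spark.sql("SELECT plus_one(id) FROM range(10)").show()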

python/pyspark/sql/pandas/_typing/__init__.pyi

Lines changed: 6 additions & 0 deletions
@@ -62,6 +62,7 @@ GroupedMapUDFTransformWithStateInitStateType = Literal[214]
 
 # Arrow UDFs
 ArrowScalarUDFType = Literal[250]
+ArrowScalarIterUDFType = Literal[251]
 
 class ArrowVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: pyarrow.Array) -> pyarrow.Array: ...
@@ -135,6 +136,11 @@ ArrowScalarToScalarFunction = Union[
     ],
 ]
 
+ArrowScalarIterFunction = Union[
+    Callable[[Iterable[pyarrow.Array]], Iterable[pyarrow.Array]],
+    Callable[[Tuple[pyarrow.Array, ...]], Iterable[pyarrow.Array]],
+]
+
 class PandasVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
python/pyspark/sql/pandas/functions.py

Lines changed: 3 additions & 0 deletions
@@ -46,6 +46,8 @@ class ArrowUDFType:
 
     SCALAR = PythonEvalType.SQL_SCALAR_ARROW_UDF
 
+    SCALAR_ITER = PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF
+
 
 def arrow_udf(f=None, returnType=None, functionType=None):
     return vectorized_udf(f, returnType, functionType, "arrow")
@@ -451,6 +453,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
         )
     if kind == "arrow" and eval_type not in [
         PythonEvalType.SQL_SCALAR_ARROW_UDF,
+        PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
         raise PySparkTypeError(
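
ArrowUDFType.SCALAR_ITER provides the explicit functionType spelling, and the kind check now accepts it, so the iterator variant can be declared without relying on type-hint inference. A short sketch of that explicit form, assuming arrow_udf and ArrowUDFType are importable from this module:

import pyarrow.compute as pc

from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType

@arrow_udf("long", ArrowUDFType.SCALAR_ITER)
def times_two(batches):
    # No type hints needed: the eval type is given explicitly.
    for batch in batches:
        yield pc.multiply(batch, 2)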

python/pyspark/sql/pandas/functions.pyi

Lines changed: 21 additions & 0 deletions
@@ -37,6 +37,8 @@ from pyspark.sql.pandas._typing import (
     PandasScalarUDFType,
     ArrowScalarToScalarFunction,
     ArrowScalarUDFType,
+    ArrowScalarIterFunction,
+    ArrowScalarIterUDFType,
 )
 
 from pyspark import since as since  # noqa: F401
@@ -51,6 +53,7 @@ class PandasUDFType:
 
 class ArrowUDFType:
     SCALAR: ArrowScalarUDFType
+    SCALAR_ITER: ArrowScalarIterUDFType
 
 @overload
 def arrow_udf(
@@ -71,6 +74,24 @@
     *, returnType: DataTypeOrString, functionType: ArrowScalarUDFType
 ) -> Callable[[ArrowScalarToScalarFunction], UserDefinedFunctionLike]: ...
 @overload
+def arrow_udf(
+    f: ArrowScalarIterFunction,
+    returnType: Union[AtomicDataTypeOrString, ArrayType],
+    functionType: ArrowScalarIterUDFType,
+) -> UserDefinedFunctionLike: ...
+@overload
+def arrow_udf(
+    f: Union[AtomicDataTypeOrString, ArrayType], returnType: ArrowScalarIterUDFType
+) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
+@overload
+def arrow_udf(
+    *, returnType: Union[AtomicDataTypeOrString, ArrayType], functionType: ArrowScalarIterUDFType
+) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
+@overload
+def arrow_udf(
+    f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: ArrowScalarIterUDFType
+) -> Callable[[ArrowScalarIterFunction], UserDefinedFunctionLike]: ...
+@overload
 def pandas_udf(
     f: PandasScalarToScalarFunction,
     returnType: Union[AtomicDataTypeOrString, ArrayType],
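
The four new overloads describe the usual ways of invoking arrow_udf with the iterator eval type: direct call with a function, decorator factory with a positional return type, keyword-only, and mixed. A sketch of those call shapes; func and the return types here are illustrative, not part of the stub:

from typing import Iterator

import pyarrow as pa

from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType

def func(batches: Iterator[pa.Array]) -> Iterator[pa.Array]:
    for batch in batches:
        yield batch

u1 = arrow_udf(func, "long", ArrowUDFType.SCALAR_ITER)                          # f, returnType, functionType
u2 = arrow_udf("long", ArrowUDFType.SCALAR_ITER)(func)                          # decorator factory, positional
u3 = arrow_udf(returnType="long", functionType=ArrowUDFType.SCALAR_ITER)(func)  # keyword-only
u4 = arrow_udf("long", functionType=ArrowUDFType.SCALAR_ITER)(func)             # positional returnType, keyword functionType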

python/pyspark/sql/pandas/typehints.py

Lines changed: 31 additions & 0 deletions
@@ -26,6 +26,7 @@
     PandasScalarIterUDFType,
     PandasGroupedAggUDFType,
     ArrowScalarUDFType,
+    ArrowScalarIterUDFType,
 )
 
 
@@ -36,6 +37,7 @@ def infer_eval_type(
     "PandasScalarIterUDFType",
     "PandasGroupedAggUDFType",
     "ArrowScalarUDFType",
+    "ArrowScalarIterUDFType",
 ]:
     """
     Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from
@@ -110,6 +112,21 @@ def infer_eval_type(
         )
     )
 
+    # Iterator[Tuple[pa.Array, ...] -> Iterator[pa.Array]
+    is_iterator_tuple_array = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(  # Iterator
+            parameters_sig[0],
+            parameter_check_func=lambda a: check_tuple_annotation(  # Tuple
+                a,
+                parameter_check_func=lambda ta: (ta == Ellipsis or ta == pa.Array),
+            ),
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda a: a == pa.Array
+        )
+    )
+
     # Iterator[Series, Frame or Union[DataFrame, Series]] -> Iterator[Series or Frame]
     is_iterator_series_or_frame = (
         len(parameters_sig) == 1
@@ -128,6 +145,18 @@ def infer_eval_type(
         )
     )
 
+    # Iterator[pa.Array] -> Iterator[pa.Array]
+    is_iterator_array = (
+        len(parameters_sig) == 1
+        and check_iterator_annotation(
+            parameters_sig[0],
+            parameter_check_func=lambda a: (a == pd.Series or a == pa.Array),
+        )
+        and check_iterator_annotation(
+            return_annotation, parameter_check_func=lambda a: a == pa.Array
+        )
+    )
+
     # Series, Frame or Union[DataFrame, Series], ... -> Any
     is_series_or_frame_agg = all(
         a == pd.Series
@@ -152,6 +181,8 @@ def infer_eval_type(
         return ArrowUDFType.SCALAR
     elif is_iterator_tuple_series_or_frame or is_iterator_series_or_frame:
         return PandasUDFType.SCALAR_ITER
+    elif is_iterator_tuple_array or is_iterator_array:
+        return ArrowUDFType.SCALAR_ITER
     elif is_series_or_frame_agg:
         return PandasUDFType.GROUPED_AGG
     else:
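
With these checks, infer_eval_type recognizes two annotation shapes for Arrow iterator UDFs. A sketch of the two shapes and the eval type they now map to; function names are illustrative:

from typing import Iterator, Tuple

import pyarrow as pa

def single_column(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
    ...  # matches is_iterator_array -> ArrowUDFType.SCALAR_ITER

def multi_column(it: Iterator[Tuple[pa.Array, ...]]) -> Iterator[pa.Array]:
    ...  # matches is_iterator_tuple_array -> ArrowUDFType.SCALAR_ITER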
