43 changes: 43 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -1143,6 +1143,49 @@ def __repr__(self):
return "GroupArrowUDFSerializer"


class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
    def __init__(
        self,
        timezone,
        safecheck,
        assign_cols_by_name,
        arrow_cast,
    ):
        super().__init__(
            timezone=timezone,
            safecheck=safecheck,
            assign_cols_by_name=False,
            arrow_cast=True,
        )
        self._timezone = timezone
        self._safecheck = safecheck
        self._assign_cols_by_name = assign_cols_by_name
        self._arrow_cast = arrow_cast

    def load_stream(self, stream):
        """
        Read grouped Arrow record batches and yield one concatenated batch per group.
        """
        import pyarrow as pa

        dataframes_in_group = None

        while dataframes_in_group is None or dataframes_in_group > 0:
            # Each group is preceded by the number of dataframes it contains; 0 ends the input.
            dataframes_in_group = read_int(stream)

            if dataframes_in_group == 1:
                yield pa.concat_batches(ArrowStreamSerializer.load_stream(self, stream))

            elif dataframes_in_group != 0:
                raise PySparkValueError(
                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
                )

    def __repr__(self):
        return "AggArrowUDFSerializer"

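For context, a minimal standalone sketch of the behavior the new serializer provides: even when the JVM slices a group into several record batches (for example under spark.sql.execution.arrow.maxRecordsPerBatch), load_stream hands the grouped-agg UDF one concatenated batch per group. The sketch uses only public PyArrow APIs and does not reproduce Spark's worker protocol; it assumes a PyArrow version that provides pa.concat_batches, the same call the serializer relies on.

import pyarrow as pa

# One logical group of 10 rows.
group = pa.table({"key": [0] * 10, "v": list(range(10))})

# Batch slicing on the JVM side may deliver the group as several small batches.
sliced = group.to_batches(max_chunksize=3)
assert len(sliced) == 4

# AggArrowUDFSerializer.load_stream concatenates them back into a single batch,
# so a GROUPED_AGG arrow UDF sees the whole group at once.
merged = pa.concat_batches(sliced)
assert merged.num_rows == 10
assert merged.to_pydict() == group.to_pydict()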

class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
    def __init__(
        self,
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -360,6 +360,7 @@ def test_arrow_batch_slicing(self):
        df = df.withColumns(cols)

        def min_max_v(table):
            assert len(table) == 10000000 / 2, len(table)
            return pa.Table.from_pydict(
                {
                    "key": [table.column("key")[0].as_py()],
@@ -372,8 +373,7 @@ def min_max_v(table):
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()

int_max = 2147483647
for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -986,6 +986,34 @@ def arrow_lit_1() -> int:
        )
        self.assertEqual(expected2.collect(), result2.collect())

    def test_arrow_batch_slicing(self):
        import pyarrow as pa

        df = self.spark.range(10000000).select(
            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
        )

        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
        def arrow_max(v):
            assert len(v) == 10000000 / 2, len(v)
            return pa.compute.max(v)

        expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()

        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                with self.sql_conf(
                    {
                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
                    }
                ):
                    result = (
                        df.groupBy("key").agg(arrow_max("v").alias("res")).sort("key")
                    ).collect()

                    self.assertEqual(expected, result)

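As a companion to the test above, a self-contained PyArrow snippet of the per-group reduction arrow_max performs. No Spark session is needed; the sample values (0, 2, ..., 18) are illustrative and only the PyArrow calls are taken from the test, which does not spell out the exact Arrow container Spark passes to the UDF.

import pyarrow as pa

# Stand-in for one group's values, received in full thanks to AggArrowUDFSerializer.
v = pa.chunked_array([pa.array(range(0, 20, 2), type=pa.int64())])

# pa.compute.max returns a PyArrow scalar; .as_py() converts it to a plain Python int,
# which is the kind of value the test's UDF declares as its "long" result.
result = pa.compute.max(v)
assert result.as_py() == 18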

class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
    pass
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -954,6 +954,7 @@ def test_arrow_batch_slicing(self):
        df = df.withColumns(cols)

        def min_max_v(pdf):
            assert len(pdf) == 10000000 / 2, len(pdf)
            return pd.DataFrame(
                {
                    "key": [pdf.key.iloc[0]],
@@ -966,8 +967,7 @@ def min_max_v(pdf):
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()

int_max = 2147483647
for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -801,6 +801,7 @@ def test_arrow_batch_slicing(self):

@pandas_udf("long", PandasUDFType.GROUPED_AGG)
def pandas_max(v):
assert len(v) == 10000000 / 2, len(v)
return v.max()

expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()
4 changes: 3 additions & 1 deletion python/pyspark/worker.py
@@ -52,6 +52,7 @@
from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
    AggArrowUDFSerializer,
    ArrowStreamPandasUDFSerializer,
    ArrowStreamPandasUDTFSerializer,
    GroupPandasUDFSerializer,
@@ -2611,6 +2612,8 @@ def read_udfs(pickleSer, infile, eval_type):
        or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
    ):
        ser = GroupArrowUDFSerializer(_assign_cols_by_name)
    elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
        # Arrow cast and safe check are always enabled
        ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
    elif eval_type in (
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
@@ -2700,7 +2703,6 @@ def read_udfs(pickleSer, infile, eval_type):
    elif eval_type in (
        PythonEvalType.SQL_SCALAR_ARROW_UDF,
        PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
        PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
    ):
        # Arrow cast and safe check are always enabled
@@ -180,31 +180,17 @@ case class ArrowAggregatePythonExec(
        rows
      }

      val runner = if (evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
        new ArrowPythonWithNamedArgumentRunner(
          pyFuncs,
          evalType,
          argMetas,
          aggInputSchema,
          sessionLocalTimeZone,
          largeVarTypes,
          pythonRunnerConf,
          pythonMetrics,
          jobArtifactUUID,
          conf.pythonUDFProfiler) with GroupedPythonArrowInput
      } else {
        new ArrowPythonWithNamedArgumentRunner(
          pyFuncs,
          evalType,
          argMetas,
          aggInputSchema,
          sessionLocalTimeZone,
          largeVarTypes,
          pythonRunnerConf,
          pythonMetrics,
          jobArtifactUUID,
          conf.pythonUDFProfiler)
      }
      val runner = new ArrowPythonWithNamedArgumentRunner(
        pyFuncs,
        evalType,
        argMetas,
        aggInputSchema,
        sessionLocalTimeZone,
        largeVarTypes,
        pythonRunnerConf,
        pythonMetrics,
        jobArtifactUUID,
        conf.pythonUDFProfiler) with GroupedPythonArrowInput

      val columnarBatchIter = runner.compute(projectedRowIter, context.partitionId(), context)
