diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index bff7f337314b6..323aea3a59ce2 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1143,6 +1143,49 @@ def __repr__(self):
         return "GroupArrowUDFSerializer"
 
 
+class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_cast,
+    ):
+        super().__init__(
+            timezone=timezone,
+            safecheck=safecheck,
+            assign_cols_by_name=False,
+            arrow_cast=True,
+        )
+        self._timezone = timezone
+        self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Read the record batches of each group and yield one concatenated batch per group.
+        """
+        import pyarrow as pa
+
+        dataframes_in_group = None
+
+        while dataframes_in_group is None or dataframes_in_group > 0:
+            dataframes_in_group = read_int(stream)
+
+            if dataframes_in_group == 1:
+                yield pa.concat_batches(ArrowStreamSerializer.load_stream(self, stream))
+
+            elif dataframes_in_group != 0:
+                raise PySparkValueError(
+                    errorClass="INVALID_NUMBER_OF_DATAFRAMES_IN_GROUP",
+                    messageParameters={"dataframes_in_group": str(dataframes_in_group)},
+                )
+
+    def __repr__(self):
+        return "AggArrowUDFSerializer"
+
+
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
     def __init__(
         self,
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 765bc7ba6fe13..8d3d929096b18 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -360,6 +360,7 @@ def test_arrow_batch_slicing(self):
             df = df.withColumns(cols)
 
         def min_max_v(table):
+            assert len(table) == 10000000 / 2, len(table)
             return pa.Table.from_pydict(
                 {
                     "key": [table.column("key")[0].as_py()],
@@ -372,8 +373,7 @@ def min_max_v(table):
             df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
         ).collect()
 
-        int_max = 2147483647
-        for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
             with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                 with self.sql_conf(
                     {
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 3fe6d28c66a66..fae9650b2864a 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -986,6 +986,34 @@ def arrow_lit_1() -> int:
         )
         self.assertEqual(expected2.collect(), result2.collect())
 
+    def test_arrow_batch_slicing(self):
+        import pyarrow as pa
+
+        df = self.spark.range(10000000).select(
+            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
+        )
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_max(v):
+            assert len(v) == 10000000 / 2, len(v)
+            return pa.compute.max(v)
+
+        expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()
+
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result = (
+                        df.groupBy("key").agg(arrow_max("v").alias("res")).sort("key")
+                    ).collect()
+
+                    self.assertEqual(expected, result)
+
 
 class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index f81c774c0e915..fb81cd7727773 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -954,6 +954,7 @@ def test_arrow_batch_slicing(self):
             df = df.withColumns(cols)
 
         def min_max_v(pdf):
+            assert len(pdf) == 10000000 / 2, len(pdf)
             return pd.DataFrame(
                 {
                     "key": [pdf.key.iloc[0]],
@@ -966,8 +967,7 @@ def min_max_v(pdf):
             df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
         ).collect()
 
-        int_max = 2147483647
-        for maxRecords, maxBytes in [(1000, int_max), (0, 1048576), (1000, 1048576)]:
+        for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
             with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
                 with self.sql_conf(
                     {
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index a08401b087da7..cfcbb96fcc36f 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -801,6 +801,7 @@ def test_arrow_batch_slicing(self):
 
         @pandas_udf("long", PandasUDFType.GROUPED_AGG)
         def pandas_max(v):
+            assert len(v) == 10000000 / 2, len(v)
             return v.max()
 
         expected = (df.groupby("key").agg(sf.max("v").alias("res")).sort("key")).collect()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 56e03f5959ede..c3ba8bc7063cb 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,6 +52,7 @@
 from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
+    AggArrowUDFSerializer,
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupPandasUDFSerializer,
@@ -2611,6 +2612,8 @@ def read_udfs(pickleSer, infile, eval_type):
             or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
         ):
             ser = GroupArrowUDFSerializer(_assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
+            ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
@@ -2700,7 +2703,6 @@ def read_udfs(pickleSer, infile, eval_type):
         elif eval_type in (
             PythonEvalType.SQL_SCALAR_ARROW_UDF,
             PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
-            PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
             PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
         ):
             # Arrow cast and safe check are always enabled
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
index 8b5f65a8a5aad..f4e8831f23b85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala
@@ -180,31 +180,17 @@ case class ArrowAggregatePythonExec(
       rows
     }
 
-    val runner = if (evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
-      new ArrowPythonWithNamedArgumentRunner(
-        pyFuncs,
-        evalType,
-        argMetas,
-        aggInputSchema,
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID,
-        conf.pythonUDFProfiler) with GroupedPythonArrowInput
-    } else {
-      new ArrowPythonWithNamedArgumentRunner(
-        pyFuncs,
-        evalType,
-        argMetas,
-        aggInputSchema,
-        sessionLocalTimeZone,
-        largeVarTypes,
-        pythonRunnerConf,
-        pythonMetrics,
-        jobArtifactUUID,
-        conf.pythonUDFProfiler)
-    }
+    val runner = new ArrowPythonWithNamedArgumentRunner(
+      pyFuncs,
+      evalType,
+      argMetas,
+      aggInputSchema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID,
+      conf.pythonUDFProfiler) with GroupedPythonArrowInput
 
     val columnarBatchIter = runner.compute(projectedRowIter, context.partitionId(), context)
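
For context, a minimal usage sketch of the grouped-aggregate Arrow UDF path exercised by this patch. It mirrors the new test above; the data, names, and the import location of arrow_udf/ArrowUDFType are assumptions (use the same imports as the test module), not part of the patch itself.

import pyarrow as pa
from pyspark.sql import SparkSession, functions as sf
# Import location assumed; the test module above provides arrow_udf and ArrowUDFType.
from pyspark.sql.pandas.functions import arrow_udf, ArrowUDFType

spark = SparkSession.builder.getOrCreate()

# Two groups (key 0 and key 1); each group's "v" values reach the UDF as Arrow data.
df = spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))

@arrow_udf("long", ArrowUDFType.GROUPED_AGG)
def arrow_max(v):
    # With the AggArrowUDFSerializer change above, the sliced record batches of a
    # group are concatenated before the call, so v covers the whole group at once.
    return pa.compute.max(v)

df.groupBy("key").agg(arrow_max("v").alias("res")).sort("key").show()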