From 70ef41c610dc71a4e7702fc9ada92774ab55f35b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 19:04:16 -0700 Subject: [PATCH 1/2] [SPARK-51508] Support `collect(): [[String?]]` for `DataFrame` --- Sources/SparkConnect/DataFrame.swift | 29 +++++++++++++++---- Sources/SparkConnect/SparkConnectClient.swift | 2 ++ Sources/SparkConnect/SparkSession.swift | 6 ++-- Tests/SparkConnectTests/DataFrameTests.swift | 21 ++++++++++++-- .../SparkConnectTests/SparkSessionTests.swift | 4 +++ 5 files changed, 51 insertions(+), 11 deletions(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 7ab4fc6..81b74b1 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -58,7 +58,7 @@ public actor DataFrame: Sendable { /// Add `Apache Arrow`'s `RecordBatch`s to the internal array. /// - Parameter batches: An array of ``RecordBatch``. - private func addBathes(_ batches: [RecordBatch]) { + private func addBatches(_ batches: [RecordBatch]) { self.batches.append(contentsOf: batches) } @@ -153,16 +153,35 @@ public actor DataFrame: Sendable { let arrowResult = ArrowReader.makeArrowReaderResult() _ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult) _ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult) - await self.addBathes(arrowResult.batches) + await self.addBatches(arrowResult.batches) } } } } } - /// This is designed not to support this feature in order to simplify the Swift client. - public func collect() async throws { - throw SparkConnectError.UnsupportedOperationException + /// Execute the plan and return the result as ``[[String?]]``. + /// - Returns: ``[[String?]]`` + public func collect() async throws -> [[String?]] { + try await execute() + + var result: [[String?]] = [] + for batch in self.batches { + for i in 0.. Date: Thu, 13 Mar 2025 19:45:31 -0700 Subject: [PATCH 2/2] clean up print statement --- Tests/SparkConnectTests/DataFrameTests.swift | 2 -- 1 file changed, 2 deletions(-) diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index c44254b..a1c7e7e 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -138,8 +138,6 @@ struct DataFrameTests { func orderBy() async throws { let spark = try await SparkSession.builder.getOrCreate() let expected = (1...10).map{ [String($0)] } - print(expected) - print(try await spark.range(10, 0, -1).orderBy("id").collect()) #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected) await spark.stop() }