[SPARK-51508] Support collect(): [[String?]] for DataFrame #17

Closed · 2 commits
Sources/SparkConnect/DataFrame.swift (24 additions, 5 deletions)
@@ -58,7 +58,7 @@ public actor DataFrame: Sendable {

/// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
/// - Parameter batches: An array of ``RecordBatch``.
-  private func addBathes(_ batches: [RecordBatch]) {
+  private func addBatches(_ batches: [RecordBatch]) {
Member Author (dongjoon-hyun): This is a typo fix.

self.batches.append(contentsOf: batches)
}

@@ -153,16 +153,35 @@ public actor DataFrame: Sendable {
let arrowResult = ArrowReader.makeArrowReaderResult()
_ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult)
_ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult)
-          await self.addBathes(arrowResult.batches)
+          await self.addBatches(arrowResult.batches)
}
}
}
}
}

-  /// This is designed not to support this feature in order to simplify the Swift client.
-  public func collect() async throws {
-    throw SparkConnectError.UnsupportedOperationException
+  /// Execute the plan and return the result as ``[[String?]]``.
+  /// - Returns: ``[[String?]]``
+  public func collect() async throws -> [[String?]] {
+    try await execute()

+    var result: [[String?]] = []
+    for batch in self.batches {
+      for i in 0..<batch.length {
+        var values: [String?] = []
+        for column in batch.columns {
+          let str = column.array as! AsString
Comment on lines +172 to +173

Member: For a DataFrame, I would expect collect() to return an array of Row, but this collect() returns strings. Is this just the initial implementation, with Row to come later?

Member Author (dongjoon-hyun): Yes, correct. The Row implementation is on the way~

Member Author (dongjoon-hyun): For the Scala client, we also support Array[java.lang.Long], like the following:

$ bin/spark-shell --remote sc://localhost:15002
scala> spark.range(1).collect()
res0: Array[java.lang.Long] = Array(0L)
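For comparison, the same call through this Swift client currently yields strings rather than longs. A hypothetical session against the patched client (server address assumed):

    let rows: [[String?]] = try await spark.range(1).collect()
    // rows == [["0"]]  -- values stay String? until Row lands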

+          if column.data.isNull(i) {
+            values.append(nil)
+          } else {
+            values.append(str.asString(i))
+          }
+        }
+        result.append(values)
+      }
+    }
+
+    return result
+  }
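A minimal end-to-end sketch of the new API, assuming a reachable Spark Connect server (hypothetical session; the expected values mirror the collect() test below):

    let spark = try await SparkSession.builder.getOrCreate()
    // SQL NULLs surface as nil in the inner array.
    let rows: [[String?]] = try await spark.sql("SELECT 1, NULL, 'abc'").collect()
    // rows == [["1", nil, "abc"]]
    await spark.stop()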

/// Execute the plan and show the result.
Sources/SparkConnect/SparkConnectClient.swift (2 additions)
@@ -275,9 +275,11 @@ public actor SparkConnectClient {
let expressions: [Spark_Connect_Expression.SortOrder] = cols.map {
var expression = Spark_Connect_Expression.SortOrder()
expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
+      expression.direction = .ascending
return expression
}
sort.order = expressions
+    sort.isGlobal = true
Member Author (dongjoon-hyun, Mar 14, 2025): I piggyback this fix while improving test coverage by checking the result via the collect() API.

var relation = Relation()
relation.sort = sort
var plan = Plan()
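A sketch of what these two additions change in observable behavior, using the public API (hedged; the expected ordering matches the updated tests below):

    // With direction = .ascending and isGlobal = true, sort("id") performs an
    // ascending sort across all partitions instead of leaving both fields at
    // their protobuf defaults.
    let rows = try await spark.range(10, 0, -1).sort("id").collect()
    // rows == [["1"], ["2"], ..., ["10"]]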
Sources/SparkConnect/SparkSession.swift (2 additions, 4 deletions)
@@ -45,12 +45,10 @@ public actor SparkSession {
/// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
init(_ connection: String, _ userID: String? = nil) {
let processInfo = ProcessInfo.processInfo
-    #if os(iOS) || os(watchOS) || os(tvOS)
-      let userName = processInfo.environment["SPARK_USER"] ?? ""
-    #elseif os(macOS) || os(Linux)
+    #if os(macOS) || os(Linux)
Member Author (dongjoon-hyun): I simplified the implementation in a new way to cover os(visionOS) too.

let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
#else
-      assert(false, "Unsupported platform")
+      let userName = processInfo.environment["SPARK_USER"] ?? ""
#endif
self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
self.conf = RuntimeConf(self.client)
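A standalone sketch of the user-name resolution order after this change (illustrative only; resolveUserName is a hypothetical helper, not part of the client):

    import Foundation

    func resolveUserName(_ processInfo: ProcessInfo = .processInfo) -> String {
      #if os(macOS) || os(Linux)
        // SPARK_USER, when set, takes precedence over the OS user name.
        return processInfo.environment["SPARK_USER"] ?? processInfo.userName
      #else
        // iOS/watchOS/tvOS/visionOS: ProcessInfo has no userName, so an empty
        // string is the fallback instead of asserting.
        return processInfo.environment["SPARK_USER"] ?? ""
      #endif
    }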
Tests/SparkConnectTests/DataFrameTests.swift (17 additions, 2 deletions)
@@ -125,19 +125,23 @@ struct DataFrameTests {
await spark.stop()
}

+  #if !os(Linux)

Member Author (dongjoon-hyun):
  • Like show(), collect() currently has a binary compatibility issue on os(Linux).
  • On macOS, all tests pass.

@Test
func sort() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).sort("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
await spark.stop()
}

@Test
func orderBy() async throws {
let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).orderBy("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
await spark.stop()
}
+  #endif

@Test
func table() async throws {
@@ -153,6 +157,17 @@
}

#if !os(Linux)
+  @Test
+  func collect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(0).collect().isEmpty)
+    #expect(
+      try await spark.sql(
+        "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')"
+      ).collect() == [["1", "true", "abc"], [nil, nil, nil], ["3", "false", "def"]])
+    await spark.stop()
+  }

@Test
func show() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Tests/SparkConnectTests/SparkSessionTests.swift (4 additions)
@@ -41,7 +41,11 @@ struct SparkSessionTests {

@Test func userContext() async throws {
let spark = try await SparkSession.builder.getOrCreate()
+    #if os(macOS) || os(Linux)
let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
+    #else
+      let defaultUserContext = "".toUserContext
+    #endif
#expect(await spark.client.userContext == defaultUserContext)
await spark.stop()
}