[SPARK-51508] Support collect(): [[String?]] for DataFrame (#17)
Changes to the `DataFrame` actor:

```diff
@@ -58,7 +58,7 @@ public actor DataFrame: Sendable {
   /// Add `Apache Arrow`'s `RecordBatch`s to the internal array.
   /// - Parameter batches: An array of ``RecordBatch``.
-  private func addBathes(_ batches: [RecordBatch]) {
+  private func addBatches(_ batches: [RecordBatch]) {
     self.batches.append(contentsOf: batches)
   }
@@ -153,16 +153,35 @@ public actor DataFrame: Sendable {
             let arrowResult = ArrowReader.makeArrowReaderResult()
             _ = reader.fromMessage(schema, dataBody: Data(), result: arrowResult)
             _ = reader.fromMessage(dataHeader, dataBody: dataBody, result: arrowResult)
-            await self.addBathes(arrowResult.batches)
+            await self.addBatches(arrowResult.batches)
           }
         }
       }
     }
   }

-  /// This is designed not to support this feature in order to simplify the Swift client.
-  public func collect() async throws {
-    throw SparkConnectError.UnsupportedOperationException
+  /// Execute the plan and return the result as ``[[String?]]``.
+  /// - Returns: ``[[String?]]``
+  public func collect() async throws -> [[String?]] {
+    try await execute()
+
+    var result: [[String?]] = []
+    for batch in self.batches {
+      for i in 0..<batch.length {
+        var values: [String?] = []
+        for column in batch.columns {
+          let str = column.array as! AsString
+          if column.data.isNull(i) {
+            values.append(nil)
+          } else {
+            values.append(str.asString(i))
+          }
+        }
+        result.append(values)
+      }
+    }
+
+    return result
   }

   /// Execute the plan and show the result.
```

Comment on lines +172 to +173:

- For `DataFrame`, I suppose that the return of …
- Yes, correct.
- For …
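For reference, a minimal usage sketch of the new API (assuming the package module is named `SparkConnect` and a Spark Connect server is reachable with the default settings; neither detail is specified in this diff):

```swift
import SparkConnect

// Run a query and read the stringified rows.
let spark = try await SparkSession.builder.getOrCreate()

// collect() returns [[String?]]: every value is rendered as a String,
// and a SQL NULL is surfaced as nil.
let rows = try await spark.sql(
  "SELECT * FROM VALUES (1, 'a'), (null, null)"
).collect()

for row in rows {
  print(row.map { $0 ?? "NULL" }.joined(separator: "\t"))
}

await spark.stop()
```

Note the design choice visible in the diff: each column array is force-cast to `AsString` (`as!`), so every Arrow column type the client can receive is expected to conform to that protocol; an unsupported column type would trap at runtime rather than throw.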
Changes to the `SparkConnectClient` actor:

```diff
@@ -275,9 +275,11 @@ public actor SparkConnectClient {
     let expressions: [Spark_Connect_Expression.SortOrder] = cols.map {
       var expression = Spark_Connect_Expression.SortOrder()
       expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute)
+      expression.direction = .ascending
       return expression
     }
     sort.order = expressions
+    sort.isGlobal = true
     var relation = Relation()
     relation.sort = sort
     var plan = Plan()
```

Comment: I piggy-back this fix while improving the test coverage by checking the result via …
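The two added lines make the requested ordering explicit: each `SortOrder` is ascending, and the sort is global (a total order across partitions) rather than per-partition, which is what makes a sorted `collect()` deterministic. A sketch of the guarantee the updated tests rely on (assuming a reachable server, as the tests do):

```swift
// range(10, 0, -1) produces 10, 9, ..., 1; a global ascending sort
// on "id" must return the rows as 1...10, in order.
let spark = try await SparkSession.builder.getOrCreate()
let rows = try await spark.range(10, 0, -1).sort("id").collect()
assert(rows == (1...10).map { [String($0)] })
await spark.stop()
```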
Changes to the `SparkSession` actor:

```diff
@@ -45,12 +45,10 @@ public actor SparkSession {
   /// - userID: an optional user ID. If absent, `SPARK_USER` environment or ``ProcessInfo.processInfo.userName`` is used.
   init(_ connection: String, _ userID: String? = nil) {
     let processInfo = ProcessInfo.processInfo
-    #if os(iOS) || os(watchOS) || os(tvOS)
-    let userName = processInfo.environment["SPARK_USER"] ?? ""
-    #elseif os(macOS) || os(Linux)
+    #if os(macOS) || os(Linux)
     let userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName
     #else
-    assert(false, "Unsupported platform")
+    let userName = processInfo.environment["SPARK_USER"] ?? ""
     #endif
     self.client = SparkConnectClient(remote: connection, user: userID ?? userName)
     self.conf = RuntimeConf(self.client)
```

Comment: I simplified the implementation in a new way to cover …
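The effect is that unlisted platforms now fall back to `SPARK_USER` (or an empty user name) instead of asserting. Restated as a standalone helper for illustration (`resolveUserName` is a hypothetical name, not part of this change):

```swift
import Foundation

// Mirrors the PR's platform logic: ProcessInfo.userName is only used
// where it is known to be available (macOS, and Linux via
// swift-corelibs-foundation); everywhere else, fall back to the
// SPARK_USER environment variable or an empty string.
func resolveUserName(_ processInfo: ProcessInfo = .processInfo) -> String {
  #if os(macOS) || os(Linux)
    return processInfo.environment["SPARK_USER"] ?? processInfo.userName
  #else
    return processInfo.environment["SPARK_USER"] ?? ""
  #endif
}
```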
Changes to `DataFrameTests`:

```diff
@@ -125,19 +125,23 @@ struct DataFrameTests {
     await spark.stop()
   }

+  #if !os(Linux)
   @Test
   func sort() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).sort("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected)
     await spark.stop()
   }

   @Test
   func orderBy() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
-    #expect(try await spark.range(10).orderBy("id").count() == 10)
+    let expected = (1...10).map{ [String($0)] }
+    #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected)
     await spark.stop()
   }
+  #endif

   @Test
   func table() async throws {
```
```diff
@@ -153,6 +157,17 @@ struct DataFrameTests {
   }

+  #if !os(Linux)
+  @Test
+  func collect() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(0).collect().isEmpty)
+    #expect(
+      try await spark.sql(
+        "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')"
+      ).collect() == [["1", "true", "abc"], [nil, nil, nil], ["3", "false", "def"]])
+    await spark.stop()
+  }

   @Test
   func show() async throws {
     let spark = try await SparkSession.builder.getOrCreate()
```
Comment (on the `addBathes` → `addBatches` rename): This is a typo fix.