
[SPARK-51560] Support cache/persist/unpersist for DataFrame #22

Closed · wants to merge 2 commits
39 changes: 39 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
@@ -245,4 +245,43 @@ public actor DataFrame: Sendable {
  public func isEmpty() async throws -> Bool {
    return try await select().limit(1).count() == 0
  }

  /// Persists the `DataFrame` with the default storage level (`MEMORY_AND_DISK`).
  public func cache() async throws -> DataFrame {
    return try await persist()
  }

  /// Persists the `DataFrame` with the given storage-level flags. The defaults
  /// correspond to Spark's `MEMORY_AND_DISK` (deserialized, replication 1).
  public func persist(
    useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
    deserialized: Bool = true, replication: Int32 = 1
  )
    async throws -> DataFrame
  {
    try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: spark.client.host, port: spark.client.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
      _ = try await service.analyzePlan(
        spark.client.getPersist(
          spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
    }

    return self
  }

  /// Marks the `DataFrame` as non-persistent, removing its blocks from memory and disk.
  public func unpersist(blocking: Bool = false) async throws -> DataFrame {
    try await withGRPCClient(
      transport: .http2NIOPosix(
        target: .dns(host: spark.client.host, port: spark.client.port),
        transportSecurity: .plaintext
      )
    ) { client in
      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
      _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
    }

    return self
  }
}
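
A quick usage sketch of the new API, based on the tests below (it assumes a reachable Spark Connect server, as in this repository's test setup):

  let spark = try await SparkSession.builder.getOrCreate()

  // cache() is persist() with all defaults, i.e. Spark's MEMORY_AND_DISK level.
  let df = try await spark.range(100).cache()
  _ = try await df.count()  // the first action materializes the cache

  // Memory-only persistence, then an explicit unpersist.
  let cached = try await spark.range(100).persist(useDisk: false)
  _ = try await cached.unpersist()

  await spark.stop()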
35 changes: 35 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
@@ -256,6 +256,41 @@ public actor SparkConnectClient {
    return request
  }

  /// Builds an `AnalyzePlanRequest` that persists `plan` with the given
  /// storage-level flags. The defaults match Spark's `MEMORY_AND_DISK`.
  func getPersist(
    _ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
    _ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
  ) async
    -> AnalyzePlanRequest
  {
    return analyze(
      sessionID,
      {
        var persist = AnalyzePlanRequest.Persist()
        var level = StorageLevel()
        level.useDisk = useDisk
        level.useMemory = useMemory
        level.useOffHeap = useOffHeap
        level.deserialized = deserialized
        level.replication = replication
        persist.storageLevel = level
        persist.relation = plan.root
        return OneOf_Analyze.persist(persist)
      })
  }

  /// Builds an `AnalyzePlanRequest` that unpersists `plan`.
  func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
    -> AnalyzePlanRequest
  {
    return analyze(
      sessionID,
      {
        var unpersist = AnalyzePlanRequest.Unpersist()
        unpersist.relation = plan.root
        unpersist.blocking = blocking
        return OneOf_Analyze.unpersist(unpersist)
      })
  }

  static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
    var project = Project()
    project.input = child
1 change: 1 addition & 0 deletions Sources/SparkConnect/TypeAliases.swift
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
typealias Relation = Spark_Connect_Relation
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias Sort = Spark_Connect_Sort
typealias StorageLevel = Spark_Connect_StorageLevel
typealias UserContext = Spark_Connect_UserContext
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
33 changes: 33 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
@@ -193,5 +193,38 @@ struct DataFrameTests {
    try await spark.sql("DROP TABLE IF EXISTS t").show()
    await spark.stop()
  }

  @Test
  func cache() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.range(10).cache().count() == 10)
    await spark.stop()
  }

  @Test
  func persist() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    #expect(try await spark.range(20).persist().count() == 20)
    #expect(try await spark.range(21).persist(useDisk: false).count() == 21)
    await spark.stop()
  }
Comment on lines +204 to +210

Member: This checks the result correctness. If it is not actually cached, we can still get the same result. I'm not sure if we have a way to check whether it is actually cached or not?

Member (Author): Yes, correct! This is only an API test, @viirya.

I checked the cached and uncached results manually via the Connect Server UI because we don't have the catalog part yet~
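
A possible follow-up, not part of this PR: the Spark Connect protocol also defines `AnalyzePlanRequest.GetStorageLevel`, so a `SparkConnectClient` helper mirroring `getPersist` could let a test assert the effective storage level rather than only the row count. A sketch, assuming the generated Swift protobuf types expose that message:

  // Sketch only: assumes the generated protobuf types expose
  // AnalyzePlanRequest.GetStorageLevel alongside Persist/Unpersist.
  func getStorageLevel(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest {
    return analyze(
      sessionID,
      {
        var getStorageLevel = AnalyzePlanRequest.GetStorageLevel()
        getStorageLevel.relation = plan.root
        return OneOf_Analyze.getStorageLevel(getStorageLevel)
      })
  }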


  @Test
  func persistInvalidStorageLevel() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    try await #require(throws: Error.self) {
      let _ = try await spark.range(9999).persist(replication: 0).count()
    }
    await spark.stop()
  }

Member (Author): BTW, this is a valid error-result check~ (`replication: 0` is rejected as an invalid storage level.)

  @Test
  func unpersist() async throws {
    let spark = try await SparkSession.builder.getOrCreate()
    let df = try await spark.range(30)
    #expect(try await df.persist().count() == 30)
    #expect(try await df.unpersist().count() == 30)
    await spark.stop()
  }
#endif
}