diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 81b74b1..81c92e4 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -245,4 +245,66 @@ public actor DataFrame: Sendable {
   public func isEmpty() async throws -> Bool {
     return try await select().limit(1).count() == 0
   }
+
+  /// Persists this `DataFrame` by calling `persist()` with its default arguments.
+  /// - Returns: This `DataFrame`, so that calls can be chained.
+  /// - Throws: Any error thrown by `persist()`.
+  @discardableResult
+  public func cache() async throws -> DataFrame {
+    return try await persist()
+  }
+
+  /// Sends a persist request for this `DataFrame`'s plan to the Spark Connect service.
+  ///
+  /// The arguments are forwarded unchanged into the storage level of the request.
+  /// - Parameters:
+  ///   - useDisk: Value for the storage level's `useDisk` field.
+  ///   - useMemory: Value for the storage level's `useMemory` field.
+  ///   - useOffHeap: Value for the storage level's `useOffHeap` field.
+  ///   - deserialized: Value for the storage level's `deserialized` field.
+  ///   - replication: Value for the storage level's `replication` field.
+  /// - Returns: This `DataFrame`, so that calls can be chained.
+  /// - Throws: Any error thrown while connecting to the service or performing the request.
+  @discardableResult
+  public func persist(
+    useDisk: Bool = true, useMemory: Bool = true, useOffHeap: Bool = false,
+    deserialized: Bool = true, replication: Int32 = 1
+  )
+    async throws -> DataFrame
+  {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      // The response is intentionally discarded.
+      _ = try await service.analyzePlan(
+        spark.client.getPersist(
+          spark.sessionID, plan, useDisk, useMemory, useOffHeap, deserialized, replication))
+    }
+
+    return self
+  }
+
+  /// Sends an unpersist request for this `DataFrame`'s plan to the Spark Connect service.
+  /// - Parameter blocking: Value for the request's `blocking` field.
+  /// - Returns: This `DataFrame`, so that calls can be chained.
+  /// - Throws: Any error thrown while connecting to the service or performing the request.
+  @discardableResult
+  public func unpersist(blocking: Bool = false) async throws -> DataFrame {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      // The response is intentionally discarded.
+      _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
+    }
+
+    return self
+  }
 }
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index c0c0828..aefd844 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -256,6 +256,57 @@ public actor SparkConnectClient {
     return request
   }
 
+  /// Builds an `AnalyzePlanRequest` asking to persist the root relation of `plan`.
+  /// - Parameters:
+  ///   - sessionID: Session identifier passed through to `analyze`.
+  ///   - plan: Plan whose `root` relation becomes the request's relation.
+  ///   - useDisk: Copied into the storage level's `useDisk` field.
+  ///   - useMemory: Copied into the storage level's `useMemory` field.
+  ///   - useOffHeap: Copied into the storage level's `useOffHeap` field.
+  ///   - deserialized: Copied into the storage level's `deserialized` field.
+  ///   - replication: Copied into the storage level's `replication` field.
+  /// - Returns: The request produced by `analyze` for the persist payload.
+  func getPersist(
+    _ sessionID: String, _ plan: Plan, _ useDisk: Bool = true, _ useMemory: Bool = true,
+    _ useOffHeap: Bool = false, _ deserialized: Bool = true, _ replication: Int32 = 1
+  ) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var persist = AnalyzePlanRequest.Persist()
+        var level = StorageLevel()
+        level.useDisk = useDisk
+        level.useMemory = useMemory
+        level.useOffHeap = useOffHeap
+        level.deserialized = deserialized
+        level.replication = replication
+        persist.storageLevel = level
+        persist.relation = plan.root
+        return OneOf_Analyze.persist(persist)
+      })
+  }
+
+  /// Builds an `AnalyzePlanRequest` asking to unpersist the root relation of `plan`.
+  /// - Parameters:
+  ///   - sessionID: Session identifier passed through to `analyze`.
+  ///   - plan: Plan whose `root` relation becomes the request's relation.
+  ///   - blocking: Copied into the request's `blocking` field.
+  /// - Returns: The request produced by `analyze` for the unpersist payload.
+  func getUnpersist(_ sessionID: String, _ plan: Plan, _ blocking: Bool = false) async
+    -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var unpersist = AnalyzePlanRequest.Unpersist()
+        unpersist.relation = plan.root
+        unpersist.blocking = blocking
+        return OneOf_Analyze.unpersist(unpersist)
+      })
+  }
+
   static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
     var project = Project()
     project.input = child
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 8154662..92aa78e 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -30,5 +30,6 @@ typealias Range = Spark_Connect_Range
 typealias Relation = Spark_Connect_Relation
 typealias SparkConnectService = Spark_Connect_SparkConnectService
 typealias Sort = Spark_Connect_Sort
+typealias StorageLevel = Spark_Connect_StorageLevel
 typealias UserContext = Spark_Connect_UserContext
 typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index a1c7e7e..552374d 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -193,5 +193,38 @@ struct DataFrameTests {
     try await spark.sql("DROP TABLE IF EXISTS t").show()
     await spark.stop()
   }
+
+  @Test
+  func cache() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(10).cache().count() == 10)
+    await spark.stop()
+  }
+
+  @Test
+  func persist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    #expect(try await spark.range(20).persist().count() == 20)
+    #expect(try await spark.range(21).persist(useDisk: false).count() == 21)
+    await spark.stop()
+  }
+
+  @Test
+  func persistInvalidStorageLevel() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await #require(throws: Error.self) {
+      _ = try await spark.range(9999).persist(replication: 0).count()
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func unpersist() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let df = try await spark.range(30)
+    #expect(try await df.persist().count() == 30)
+    #expect(try await df.unpersist().count() == 30)
+    await spark.stop()
+  }
 #endif
 }