diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index a5306dc..5c6b4cc 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -82,8 +82,17 @@ public actor SparkConnectClient { self.userContext = userName.toUserContext } - /// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet. - func stop() { + /// Stop the connection. + func stop() async { + guard self.sessionID != nil else { return } + try? await withGPRC { client in + let service = SparkConnectService.Client(wrapping: client) + var request = Spark_Connect_ReleaseSessionRequest() + request.sessionID = self.sessionID! + request.userContext = self.userContext + request.clientType = self.clientType + _ = try await service.releaseSession(request) + } } /// Connect to the `Spark Connect` server with the given session ID string. diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index e8213cc..7404395 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -27,6 +27,7 @@ import Testing struct SparkSessionTests { @Test func sparkContext() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() await #expect(throws: SparkConnectError.UnsupportedOperationException) { try await spark.sparkContext @@ -36,12 +37,14 @@ struct SparkSessionTests { @Test func stop() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() await spark.stop() } @Test func newSession() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() await spark.stop() let newSpark = try await spark.newSession() @@ -52,8 +55,8 @@ struct SparkSessionTests { @Test func sessionID() async throws { + await SparkSession.builder.clear() let spark1 = try await SparkSession.builder.getOrCreate() - await spark1.stop() let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost" let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate() await spark2.stop() @@ -61,7 +64,22 @@ struct SparkSessionTests { #expect(spark1 == spark2) } + @Test + func closedSessionID() async throws { + await SparkSession.builder.clear() + let spark1 = try await SparkSession.builder.getOrCreate() + if await spark1.version >= "4.0.0" { + let sessionID = spark1.sessionID + await spark1.stop() + let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost" + try await #require(throws: Error.self) { + try await SparkSession.builder.remote("\(remote)/;session_id=\(sessionID)").getOrCreate() + } + } + } + @Test func userContext() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #if os(macOS) || os(Linux) let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext @@ -74,6 +92,7 @@ struct SparkSessionTests { @Test func version() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() let version = await spark.version #expect(version.starts(with: "4.0.0") || version.starts(with: "3.5.")) @@ -82,6 +101,7 @@ struct SparkSessionTests { @Test func conf() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() try await spark.conf.set("spark.x", "y") #expect(try await spark.conf.get("spark.x") == "y") @@ -91,6 +111,7 @@ struct SparkSessionTests { @Test func emptyDataFrame() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.emptyDataFrame.count() == 0) #expect(try await spark.emptyDataFrame.dtypes.isEmpty) @@ -100,6 +121,7 @@ struct SparkSessionTests { @Test func range() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.range(10).count() == 10) #expect(try await spark.range(0, 100).count() == 100) @@ -110,6 +132,7 @@ struct SparkSessionTests { #if !os(Linux) @Test func sql() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() let expected = [Row(true, 1, "a")] if await spark.version.starts(with: "4.") { @@ -122,6 +145,7 @@ struct SparkSessionTests { @Test func table() async throws { + await SparkSession.builder.clear() let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") let spark = try await SparkSession.builder.getOrCreate() try await SQLHelper.withTable(spark, tableName)({ @@ -133,6 +157,7 @@ struct SparkSessionTests { @Test func time() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.time(spark.range(1000).count) == 1000) #if !os(Linux) @@ -144,6 +169,7 @@ struct SparkSessionTests { @Test func tag() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() try await spark.addTag("tag1") #expect(await spark.getTags() == Set(["tag1"])) @@ -158,6 +184,7 @@ struct SparkSessionTests { @Test func invalidTags() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() await #expect(throws: SparkConnectError.InvalidArgumentException) { try await spark.addTag("") @@ -170,6 +197,7 @@ struct SparkSessionTests { @Test func interruptAll() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.interruptAll() == []) await spark.stop() @@ -177,6 +205,7 @@ struct SparkSessionTests { @Test func interruptTag() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.interruptTag("etl") == []) await spark.stop() @@ -184,6 +213,7 @@ struct SparkSessionTests { @Test func interruptOperation() async throws { + await SparkSession.builder.clear() let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.interruptOperation("id") == []) await spark.stop()