Skip to content

[SPARK-52302] Improve stop to use ReleaseSessionRequest #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ public actor SparkConnectClient {
self.userContext = userName.toUserContext
}

/// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
func stop() {
/// Stop the connection.
func stop() async {
guard self.sessionID != nil else { return }
try? await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var request = Spark_Connect_ReleaseSessionRequest()
request.sessionID = self.sessionID!
request.userContext = self.userContext
request.clientType = self.clientType
_ = try await service.releaseSession(request)
}
}

/// Connect to the `Spark Connect` server with the given session ID string.
Expand Down
32 changes: 31 additions & 1 deletion Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Testing
struct SparkSessionTests {
@Test
func sparkContext() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
await #expect(throws: SparkConnectError.UnsupportedOperationException) {
try await spark.sparkContext
Expand All @@ -36,12 +37,14 @@ struct SparkSessionTests {

@Test
func stop() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
await spark.stop()
}

@Test
func newSession() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
await spark.stop()
let newSpark = try await spark.newSession()
Expand All @@ -52,16 +55,31 @@ struct SparkSessionTests {

@Test
func sessionID() async throws {
await SparkSession.builder.clear()
let spark1 = try await SparkSession.builder.getOrCreate()
await spark1.stop()
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate()
await spark2.stop()
#expect(spark1.sessionID == spark2.sessionID)
#expect(spark1 == spark2)
}

@Test
func closedSessionID() async throws {
await SparkSession.builder.clear()
let spark1 = try await SparkSession.builder.getOrCreate()
if await spark1.version >= "4.0.0" {
let sessionID = spark1.sessionID
await spark1.stop()
let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
try await #require(throws: Error.self) {
try await SparkSession.builder.remote("\(remote)/;session_id=\(sessionID)").getOrCreate()
}
}
}

@Test func userContext() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#if os(macOS) || os(Linux)
let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
Expand All @@ -74,6 +92,7 @@ struct SparkSessionTests {

@Test
func version() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
let version = await spark.version
#expect(version.starts(with: "4.0.0") || version.starts(with: "3.5."))
Expand All @@ -82,6 +101,7 @@ struct SparkSessionTests {

@Test
func conf() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
try await spark.conf.set("spark.x", "y")
#expect(try await spark.conf.get("spark.x") == "y")
Expand All @@ -91,6 +111,7 @@ struct SparkSessionTests {

@Test
func emptyDataFrame() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.emptyDataFrame.count() == 0)
#expect(try await spark.emptyDataFrame.dtypes.isEmpty)
Expand All @@ -100,6 +121,7 @@ struct SparkSessionTests {

@Test
func range() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.range(10).count() == 10)
#expect(try await spark.range(0, 100).count() == 100)
Expand All @@ -110,6 +132,7 @@ struct SparkSessionTests {
#if !os(Linux)
@Test
func sql() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
let expected = [Row(true, 1, "a")]
if await spark.version.starts(with: "4.") {
Expand All @@ -122,6 +145,7 @@ struct SparkSessionTests {

@Test
func table() async throws {
await SparkSession.builder.clear()
let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "")
let spark = try await SparkSession.builder.getOrCreate()
try await SQLHelper.withTable(spark, tableName)({
Expand All @@ -133,6 +157,7 @@ struct SparkSessionTests {

@Test
func time() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.time(spark.range(1000).count) == 1000)
#if !os(Linux)
Expand All @@ -144,6 +169,7 @@ struct SparkSessionTests {

@Test
func tag() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
try await spark.addTag("tag1")
#expect(await spark.getTags() == Set(["tag1"]))
Expand All @@ -158,6 +184,7 @@ struct SparkSessionTests {

@Test
func invalidTags() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
await #expect(throws: SparkConnectError.InvalidArgumentException) {
try await spark.addTag("")
Expand All @@ -170,20 +197,23 @@ struct SparkSessionTests {

@Test
func interruptAll() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.interruptAll() == [])
await spark.stop()
}

@Test
func interruptTag() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.interruptTag("etl") == [])
await spark.stop()
}

@Test
func interruptOperation() async throws {
await SparkSession.builder.clear()
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.interruptOperation("id") == [])
await spark.stop()
Expand Down
Loading