diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
new file mode 100644
index 0000000..4e3c627
--- /dev/null
+++ b/Sources/SparkConnect/Extension.swift
@@ -0,0 +1,74 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+
+extension String {
+  /// Wrap this string, used as the SQL query text, into a `Plan` whose root is a SQL relation.
+  var toSparkConnectPlan: Plan {
+    var sql = Spark_Connect_SQL()
+    sql.query = self
+    var relation = Relation()
+    relation.sql = sql
+    var plan = Plan()
+    plan.opType = Plan.OneOf_OpType.root(relation)
+    return plan
+  }
+
+  /// Build a `UserContext` whose user ID and user name are both this string.
+  var toUserContext: UserContext {
+    var context = UserContext()
+    context.userID = self
+    context.userName = self
+    return context
+  }
+
+  /// Build a `KeyValue` that uses this string as its key; the value is not assigned.
+  var toKeyValue: KeyValue {
+    var keyValue = KeyValue()
+    keyValue.key = self
+    return keyValue
+  }
+}
+
+extension [String: String] {
+  /// Convert every entry of this dictionary into a `KeyValue`.
+  ///
+  /// The result follows the dictionary's own iteration order.
+  var toSparkConnectKeyValue: [KeyValue] {
+    map { entry in
+      var kv = KeyValue()
+      kv.key = entry.key
+      kv.value = entry.value
+      return kv
+    }
+  }
+}
+
+extension Data {
+  /// Read the first four bytes as an `Int32` in the platform's native byte order.
+  ///
+  /// `loadUnaligned` is used because an aligned `load` requires a properly aligned
+  /// address, which neither `Data` nor its slices guarantee.
+  /// - Precondition: The data holds at least four bytes.
+  var int32: Int32 {
+    precondition(count >= MemoryLayout<Int32>.size, "Data is too short to hold an Int32")
+    return withUnsafeBytes { $0.loadUnaligned(as: Int32.self) }
+  }
+}
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
new file mode 100644
index 0000000..0fa3d15
--- /dev/null
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -0,0 +1,251 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import Foundation
+import GRPCCore
+import GRPCNIOTransportHTTP2
+import GRPCProtobuf
+import Synchronization
+
+/// Conceptually the remote spark session that communicates with the server
+public actor SparkConnectClient {
+  let clientType: String = "swift"
+  let url: URL
+  let host: String
+  let port: Int
+  let userContext: UserContext
+  var sessionID: String? = nil
+
+  /// Create a client to use GRPCClient.
+  ///
+  /// Traps with a descriptive message when `remote` cannot be parsed as a URL.
+  /// The host falls back to `localhost` and the port to `15002` when the URL omits them.
+  /// - Parameters:
+  ///   - remote: A string to connect `Spark Connect` server.
+  ///   - user: A string for the user ID of this connection.
+  init(remote: String, user: String) {
+    guard let url = URL(string: remote) else {
+      preconditionFailure("Invalid Spark Connect remote URL: \(remote)")
+    }
+    self.url = url
+    self.host = url.host() ?? "localhost"
+    self.port = url.port ?? 15002
+    self.userContext = user.toUserContext
+  }
+
+  /// Stop the connection. Currently, this API is no-op because we don't reuse the connection yet.
+  func stop() {
+  }
+
+  /// Return the session ID recorded by `connect(_:)`.
+  /// - Throws: `SparkConnectError.InvalidSessionIDException` if no session ID has been recorded.
+  private func requireSessionID() throws -> String {
+    guard let sessionID = self.sessionID else {
+      throw SparkConnectError.InvalidSessionIDException
+    }
+    return sessionID
+  }
+
+  /// Connect to the `Spark Connect` server with the given session ID string.
+  /// As a test connection, this sends the server `SparkVersion` request.
+  /// - Parameter sessionID: A string for the session ID.
+  /// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
+  /// - Throws: `SparkConnectError.InvalidSessionIDException` if `sessionID` is not a UUID string.
+  func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
+    // Validate before opening a transport. This also prevents the server-side
+    // `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
+    guard UUID(uuidString: sessionID) != nil else {
+      throw SparkConnectError.InvalidSessionIDException
+    }
+
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      self.sessionID = sessionID
+      let service = SparkConnectService.Client(wrapping: client)
+      var request = AnalyzePlanRequest()
+      request.clientType = clientType
+      request.userContext = userContext
+      request.sessionID = sessionID
+      request.analyze = .sparkVersion(AnalyzePlanRequest.SparkVersion())
+      return try await service.analyzePlan(request)
+    }
+  }
+
+  /// Create a ``ConfigRequest`` instance for `Set` operation.
+  /// - Parameter map: A map of key-value string pairs.
+  /// - Returns: A ``ConfigRequest`` instance.
+  func getConfigRequestSet(map: [String: String]) -> ConfigRequest {
+    var request = ConfigRequest()
+    request.operation = ConfigRequest.Operation()
+    var set = ConfigRequest.Set()
+    set.pairs = map.toSparkConnectKeyValue
+    request.operation.opType = .set(set)
+    return request
+  }
+
+  /// Request the server to set a map of configurations for this session.
+  /// - Parameter map: A map of key-value pairs to set.
+  /// - Returns: Always return true.
+  /// - Throws: `SparkConnectError.InvalidSessionIDException` if no session ID has been recorded.
+  func setConf(map: [String: String]) async throws -> Bool {
+    let sessionID = try requireSessionID()
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      var request = getConfigRequestSet(map: map)
+      request.clientType = clientType
+      request.userContext = userContext
+      request.sessionID = sessionID
+      _ = try await service.config(request)
+      return true
+    }
+  }
+
+  /// Create a ``ConfigRequest`` instance for `Get` operation.
+  /// - Parameter keys: An array of keys to get.
+  /// - Returns: A `ConfigRequest` instance.
+  func getConfigRequestGet(keys: [String]) -> ConfigRequest {
+    var request = ConfigRequest()
+    request.operation = ConfigRequest.Operation()
+    var get = ConfigRequest.Get()
+    get.keys = keys
+    request.operation.opType = .get(get)
+    return request
+  }
+
+  /// Request the server to get a value of the given key.
+  /// - Parameter key: A string for key to look up.
+  /// - Returns: A string for the value of the key.
+  /// - Throws: `SparkConnectError.InvalidSessionIDException` if no session ID has been recorded.
+  func getConf(_ key: String) async throws -> String {
+    let sessionID = try requireSessionID()
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      var request = getConfigRequestGet(keys: [key])
+      request.clientType = clientType
+      request.userContext = userContext
+      request.sessionID = sessionID
+      let response = try await service.config(request)
+      // NOTE(review): assumes the response carries at least one pair for the requested key — confirm.
+      return response.pairs[0].value
+    }
+  }
+
+  /// Create a ``ConfigRequest`` for `GetAll` operation.
+  /// - Returns: A `ConfigRequest` instance.
+  func getConfigRequestGetAll() -> ConfigRequest {
+    var request = ConfigRequest()
+    request.operation = ConfigRequest.Operation()
+    let getAll = ConfigRequest.GetAll()
+    request.operation.opType = .getAll(getAll)
+    return request
+  }
+
+  /// Request the server to get all configurations.
+  /// - Returns: A map of key-value pairs.
+  /// - Throws: `SparkConnectError.InvalidSessionIDException` if no session ID has been recorded.
+  func getConfAll() async throws -> [String: String] {
+    let sessionID = try requireSessionID()
+    return try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      var request = getConfigRequestGetAll()
+      request.clientType = clientType
+      request.userContext = userContext
+      request.sessionID = sessionID
+      let response = try await service.config(request)
+      // A later pair with a duplicate key overwrites an earlier one.
+      return response.pairs.reduce(into: [String: String]()) { result, pair in
+        result[pair.key] = pair.value
+      }
+    }
+  }
+
+  /// Create a `Plan` instance for `Range` relation.
+  /// - Parameters:
+  ///   - start: A start of the range.
+  ///   - end: An end (exclusive) of the range.
+  ///   - step: A step value for the range from `start` to `end`.
+  /// - Returns: A `Plan` instance.
+  func getPlanRange(_ start: Int64, _ end: Int64, _ step: Int64) -> Plan {
+    var range = Range()
+    range.start = start
+    range.end = end
+    range.step = step
+    var relation = Relation()
+    relation.range = range
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
+  /// Create a ``ExecutePlanRequest`` instance with the given plan.
+  /// The operation ID is created by UUID.
+  /// - Parameters:
+  ///   - sessionID: A session ID used only when this client has not recorded one yet.
+  ///   - plan: A plan to execute.
+  /// - Returns: An ``ExecutePlanRequest`` instance.
+  func getExecutePlanRequest(_ sessionID: String, _ plan: Plan) async
+    -> ExecutePlanRequest
+  {
+    var request = ExecutePlanRequest()
+    request.clientType = clientType
+    request.userContext = userContext
+    // Prefer the recorded session; fall back to the argument instead of trapping on `nil`.
+    request.sessionID = self.sessionID ?? sessionID
+    request.operationID = UUID().uuidString
+    request.plan = plan
+    return request
+  }
+
+  /// Create a ``AnalyzePlanRequest`` instance with the given plan.
+  /// - Parameters:
+  ///   - sessionID: A session ID used only when this client has not recorded one yet.
+  ///   - plan: A plan to analyze.
+  /// - Returns: An ``AnalyzePlanRequest`` instance
+  func getAnalyzePlanRequest(_ sessionID: String, _ plan: Plan) async
+    -> AnalyzePlanRequest
+  {
+    var request = AnalyzePlanRequest()
+    request.clientType = clientType
+    request.userContext = userContext
+    // Prefer the recorded session; fall back to the argument instead of trapping on `nil`.
+    request.sessionID = self.sessionID ?? sessionID
+    var schema = AnalyzePlanRequest.Schema()
+    schema.plan = plan
+    request.analyze = .schema(schema)
+    return request
+  }
+}
diff --git a/Sources/SparkConnect/SparkConnectError.swift b/Sources/SparkConnect/SparkConnectError.swift
index 5dda8cf..e88c061 100644
--- a/Sources/SparkConnect/SparkConnectError.swift
+++ b/Sources/SparkConnect/SparkConnectError.swift
@@ -20,4 +20,5 @@
 /// A enum for ``SparkConnect`` package errors
 enum SparkConnectError: Error {
   case UnsupportedOperationException
+  case InvalidSessionIDException
 }
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
new file mode 100644
index 0000000..ad0e898
--- /dev/null
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -0,0 +1,30 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Shorter names for the `Spark_Connect_`-prefixed types, kept in alphabetical order.
+typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest
+typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse
+typealias ConfigRequest = Spark_Connect_ConfigRequest
+typealias DataType = Spark_Connect_DataType
+typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
+typealias KeyValue = Spark_Connect_KeyValue
+typealias Plan = Spark_Connect_Plan
+typealias Range = Spark_Connect_Range
+typealias Relation = Spark_Connect_Relation
+typealias SparkConnectService = Spark_Connect_SparkConnectService
+typealias UserContext = Spark_Connect_UserContext
diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift
new file mode 100644
index 0000000..f50ae5d
--- /dev/null
+++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift
@@ -0,0 +1,49 @@
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+
+import Foundation
+import Testing
+
+@testable import SparkConnect
+
+/// Serialized suite covering `SparkConnectClient` creation and `connect(_:)`.
+@Suite(.serialized) struct SparkConnectClientTests {
+  /// Build a client for the `sc://localhost` remote with the `test` user.
+  private func makeClient() -> SparkConnectClient {
+    SparkConnectClient(remote: "sc://localhost", user: "test")
+  }
+
+  @Test func createAndStop() async throws {
+    await makeClient().stop()
+  }
+
+  @Test func connectWithInvalidUUID() async throws {
+    let sparkClient = makeClient()
+    try await #require(throws: SparkConnectError.InvalidSessionIDException) {
+      _ = try await sparkClient.connect("not-a-uuid-format")
+    }
+    await sparkClient.stop()
+  }
+
+  @Test func connect() async throws {
+    let sparkClient = makeClient()
+    _ = try await sparkClient.connect(UUID().uuidString)
+    await sparkClient.stop()
+  }
+}