diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 03a6ffe..c1c9bd1 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -735,6 +735,49 @@ public actor SparkConnectClient {
     return plan
   }
 
+  /// Upload a single local JAR file to the server as an artifact of the current session.
+  ///
+  /// The file is sent as one chunk under the name `jars/<file name>`, together with a
+  /// CRC32 checksum of its contents.
+  /// - Parameter url: Location of the JAR file; its last path component must end in `.jar`.
+  /// - Throws: `SparkConnectError.InvalidArgument` if the file name does not end in `.jar`,
+  ///   or any error raised while reading the file or performing the RPC.
+  func addArtifact(_ url: URL) async throws {
+    let fileName = url.lastPathComponent
+    guard fileName.hasSuffix(".jar") else {
+      throw SparkConnectError.InvalidArgument
+    }
+
+    // Read and checksum the file before opening a gRPC connection, so an unreadable
+    // file fails fast without touching the network.
+    var chunk = Spark_Connect_AddArtifactsRequest.ArtifactChunk()
+    chunk.data = try Data(contentsOf: url)
+    chunk.crc = Int64(CRC32.checksum(data: chunk.data))
+
+    var artifact = Spark_Connect_AddArtifactsRequest.SingleChunkArtifact()
+    artifact.name = "jars/\(fileName)"
+    artifact.data = chunk
+
+    var batch = Spark_Connect_AddArtifactsRequest.Batch()
+    batch.artifacts.append(artifact)
+
+    var addArtifactsRequest = Spark_Connect_AddArtifactsRequest()
+    addArtifactsRequest.sessionID = self.sessionID!
+    addArtifactsRequest.userContext = self.userContext
+    addArtifactsRequest.clientType = self.clientType
+    addArtifactsRequest.batch = batch
+    // Bind an immutable copy for capture by the closures below.
+    let request = addArtifactsRequest
+
+    try await withGPRC { client in
+      let service = SparkConnectService.Client(wrapping: client)
+      _ = try await service.addArtifacts(
+        request: StreamingClientRequest { writer in
+          try await writer.write(contentsOf: [request])
+        })
+    }
+  }
+
   /// Add a tag to be assigned to all the operations started by this thread in this session.
   /// - Parameter tag: The tag to be added. Cannot contain ',' (comma) character or be an empty string.
public func addTag(tag: String) throws {
diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift
index c7a8a27..7e7326c 100644
--- a/Sources/SparkConnect/SparkSession.swift
+++ b/Sources/SparkConnect/SparkSession.swift
@@ -267,6 +267,42 @@ public actor SparkSession {
     return await read.table(tableName)
   }
 
+  /// Add a single artifact to the current session.
+  /// Currently only local files with the `.jar` extension are supported.
+  /// - Parameter url: A URL to the artifact.
+  /// - Throws: `SparkConnectError.InvalidArgument` if the file name does not end in `.jar`.
+  public func addArtifact(_ url: URL) async throws {
+    try await self.client.addArtifact(url)
+  }
+
+  /// Add a single artifact to the current session.
+  /// Currently only local files with the `.jar` extension are supported.
+  /// - Parameter path: A file system path to the artifact.
+  /// - Throws: `SparkConnectError.InvalidArgument` if the file name does not end in `.jar`.
+  public func addArtifact(_ path: String) async throws {
+    try await self.client.addArtifact(URL(fileURLWithPath: path))
+  }
+
+  /// Add one or more artifacts to the current session.
+  /// Currently only local files with the `.jar` extension are supported.
+  /// Artifacts are uploaded one at a time, in the given order; the first failure stops
+  /// the remaining uploads.
+  /// - Parameter urls: One or more URLs to the artifacts.
+  /// - Throws: `SparkConnectError.InvalidArgument` if a file name does not end in `.jar`.
+  public func addArtifacts(_ urls: URL...) async throws {
+    for url in urls {
+      try await self.client.addArtifact(url)
+    }
+  }
+
+  /// Execute an arbitrary string command inside an external execution engine rather than Spark.
+  /// This could be useful when user wants to execute some commands out of Spark. For example,
+  /// executing custom DDL/DML command for JDBC, creating index for ElasticSearch, creating cores
+  /// for Solr and so on.
+  /// - Parameters:
+  ///   - runner: The class name of the runner that implements `ExternalCommandRunner`.
+  ///   - command: The target command to be executed
+  ///   - options: The options for the runner.
public func executeCommand(_ runner: String, _ command: String, _ options: [String: String]) async throws -> DataFrame {
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index deece09..1b4a658 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -142,6 +142,53 @@ struct SparkSessionTests {
     await spark.stop()
   }
 
+  @Test
+  func addInvalidArtifact() async throws {
+    await SparkSession.builder.clear()
+    let spark = try await SparkSession.builder.getOrCreate()
+    await #expect(throws: SparkConnectError.InvalidArgument) {
+      try await spark.addArtifact("x.txt")
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func addArtifact() async throws {
+    // A uniquely named file in the temporary directory keeps this test independent of
+    // the working directory and of any other test that creates a JAR fixture.
+    let fm = FileManager()
+    let url = fm.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).jar")
+    let path = url.path
+    #expect(fm.createFile(atPath: path, contents: Data("abc".utf8)))
+    // Remove the fixture even when an upload below throws.
+    defer { try? fm.removeItem(at: url) }
+
+    await SparkSession.builder.clear()
+    let spark = try await SparkSession.builder.getOrCreate()
+    if await spark.version.starts(with: "4.") {
+      try await spark.addArtifact(path)
+      try await spark.addArtifact(url)
+    }
+    await spark.stop()
+  }
+
+  @Test
+  func addArtifacts() async throws {
+    // Use a uniquely named fixture so this test cannot collide with other tests.
+    let fm = FileManager()
+    let url = fm.temporaryDirectory.appendingPathComponent("\(UUID().uuidString).jar")
+    #expect(fm.createFile(atPath: url.path, contents: Data("abc".utf8)))
+    // Remove the fixture even when the upload below throws.
+    defer { try? fm.removeItem(at: url) }
+
+    await SparkSession.builder.clear()
+    let spark = try await SparkSession.builder.getOrCreate()
+    if await spark.version.starts(with: "4.") {
+      try await spark.addArtifacts(url, url)
+    }
+    await spark.stop()
+  }
+
   @Test
   func executeCommand() async throws {
     await SparkSession.builder.clear()