From b5dcc6e58f0be506ca948fc70b8dd9f0db526d76 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 May 2025 10:48:11 -0700 Subject: [PATCH 1/3] [SPARK-52164] Support `MergeIntoWriter` --- Sources/SparkConnect/DataFrame.swift | 9 + Sources/SparkConnect/MergeIntoWriter.swift | 254 ++++++++++++++++++ Sources/SparkConnect/TypeAliases.swift | 3 + Tests/SparkConnectTests/IcebergTests.swift | 25 +- .../MergeIntoWriterTests.swift | 75 ++++++ 5 files changed, 365 insertions(+), 1 deletion(-) create mode 100644 Sources/SparkConnect/MergeIntoWriter.swift create mode 100644 Tests/SparkConnectTests/MergeIntoWriterTests.swift diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 1c7b08d..34cceb7 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -1395,6 +1395,15 @@ public actor DataFrame: Sendable { return DataFrameWriterV2(table, self) } + /// Merges a set of updates, insertions, and deletions based on a source table into a target table. + /// - Parameters: + /// - table: A target table name. + /// - condition: A condition expression. + /// - Returns: A ``MergeIntoWriter`` instance. + public func mergeInto(_ table: String, _ condition: String) async -> MergeIntoWriter { + return await MergeIntoWriter(table, self, condition) + } + /// Returns a ``DataStreamWriter`` that can be used to write streaming data. public var writeStream: DataStreamWriter { get { diff --git a/Sources/SparkConnect/MergeIntoWriter.swift b/Sources/SparkConnect/MergeIntoWriter.swift new file mode 100644 index 0000000..4d825cd --- /dev/null +++ b/Sources/SparkConnect/MergeIntoWriter.swift @@ -0,0 +1,254 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +/// A struct for defining actions to be taken when matching rows in a ``DataFrame`` +/// during a merge operation. +public struct WhenMatched: Sendable { + let mergeIntoWriter: MergeIntoWriter + let condition: String? + + init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) { + self.mergeIntoWriter = mergeIntoWriter + self.condition = condition + } + + /// Specifies an action to update all matched rows in the ``DataFrame``. + /// - Returns: The ``MergeIntoWriter`` instance with the update all action configured. + public func updateAll() async -> MergeIntoWriter { + await mergeIntoWriter.updateAll(condition, false) + } + + /// Specifies an action to update matched rows in the ``DataFrame`` with the provided column + /// assignments. + /// - Parameter map: A dictionary from column names to expressions representing the updates to be applied. + /// - Returns: The ``MergeIntoWriter`` instance with the update action configured. + public func update(map: [String: String]) async -> MergeIntoWriter { + await mergeIntoWriter.update(condition, map, false) + } + + /// Specifies an action to delete matched rows from the DataFrame. + /// - Returns: The ``MergeIntoWriter`` instance with the delete action configured. + public func delete() async -> MergeIntoWriter { + await mergeIntoWriter.delete(condition, false) + } +} + +/// A struct for defining actions to be taken when no matching rows are found in a ``DataFrame`` +/// during a merge operation. +public struct WhenNotMatched: Sendable { + let mergeIntoWriter: MergeIntoWriter + let condition: String? + + init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) { + self.mergeIntoWriter = mergeIntoWriter + self.condition = condition + } + + /// Specifies an action to insert all non-matched rows into the ``DataFrame``. + /// - Returns: The`` MergeIntoWriter`` instance with the insert all action configured. + public func insertAll() async -> MergeIntoWriter { + await mergeIntoWriter.insertAll(condition) + } + + /// Specifies an action to insert non-matched rows into the ``DataFrame`` + /// with the provided column assignments. + /// - Parameter map: A dictionary of column names to expressions representing the values to be inserted. + /// - Returns: The ``MergeIntoWriter`` instance with the insert action configured. + public func insert(_ map: [String: String]) async -> MergeIntoWriter { + await mergeIntoWriter.insert(condition, map) + } +} + +public struct WhenNotMatchedBySource: Sendable { + let mergeIntoWriter: MergeIntoWriter + let condition: String? + + init(_ mergeIntoWriter: MergeIntoWriter, _ condition: String? = nil) { + self.mergeIntoWriter = mergeIntoWriter + self.condition = condition + } + + public func updateAll() async -> MergeIntoWriter { + await mergeIntoWriter.updateAll(condition, true) + } + + public func update(map: [String: String]) async -> MergeIntoWriter { + await mergeIntoWriter.update(condition, map, true) + } + + public func delete() async -> MergeIntoWriter { + await mergeIntoWriter.delete(condition, true) + } +} + +/// `MergeIntoWriter` provides methods to define and execute merge actions based on specified +/// conditions. +public actor MergeIntoWriter { + var schemaEvolution: Bool = false + + let table: String + + let df: DataFrame + + let condition: String + + var mergeIntoTableCommand = MergeIntoTableCommand() + + init(_ table: String, _ df: DataFrame, _ condition: String) async { + self.table = table + self.df = df + self.condition = condition + + mergeIntoTableCommand.targetTableName = table + mergeIntoTableCommand.sourceTablePlan = await (df.getPlan() as! Plan).root + mergeIntoTableCommand.mergeCondition.expressionString = condition.toExpressionString + } + + public var schemaEvolutionEnabled: Bool { + schemaEvolution + } + + /// Enable automatic schema evolution for this merge operation. + /// - Returns: ``MergeIntoWriter`` instance + public func withSchemaEvolution() -> MergeIntoWriter { + self.schemaEvolution = true + return self + } + + /// Initialize a `WhenMatched` action without any condition. + /// - Returns: A `WhenMatched` instance. + public func whenMatched() -> WhenMatched { + WhenMatched(self) + } + + /// Initialize a `WhenMatched` action with a condition. + /// - Parameter condition: <#condition description#> + /// - Returns: A `WhenMatched` instance configured with the specified condition. + public func whenMatched(_ condition: String) -> WhenMatched { + WhenMatched(self, condition) + } + + /// Initialize a `WhenNotMatched` action without any condition. + /// - Returns: A `WhenNotMatched` instance. + public func whenNotMatched() -> WhenNotMatched { + WhenNotMatched(self) + } + + /// Initialize a `WhenNotMatched` action with a condition. + /// - Parameter condition: The condition to be evaluated for the action. + /// - Returns: A `WhenNotMatched` instance configured with the specified condition. + public func whenNotMatched(_ condition: String) -> WhenNotMatched { + WhenNotMatched(self, condition) + } + + /// Initialize a `WhenNotMatchedBySource` action without any condition. + /// - Returns: A `WhenNotMatchedBySource` instance. + public func whenNotMatchedBySource() -> WhenNotMatchedBySource { + WhenNotMatchedBySource(self) + } + + /// Initialize a `WhenNotMatchedBySource` action with a condition + /// - Parameter condition: The condition to be evaluated for the action. + /// - Returns: A `WhenNotMatchedBySource` instance configured with the specified condition. + public func whenNotMatchedBySource(_ condition: String) -> WhenNotMatchedBySource { + WhenNotMatchedBySource(self, condition) + } + + /// Executes the merge operation. + public func merge() async throws { + if self.mergeIntoTableCommand.matchActions.count == 0 + && self.mergeIntoTableCommand.notMatchedActions.count == 0 + && self.mergeIntoTableCommand.notMatchedBySourceActions.count == 0 + { + throw SparkConnectError.InvalidArgumentException + } + self.mergeIntoTableCommand.withSchemaEvolution = self.schemaEvolution + + var command = Spark_Connect_Command() + command.mergeIntoTableCommand = self.mergeIntoTableCommand + _ = try await df.spark.client.execute(df.spark.sessionID, command) + } + + public func insertAll(_ condition: String?) -> MergeIntoWriter { + let expression = buildMergeAction(ActionType.insertStar, condition) + self.mergeIntoTableCommand.notMatchedActions.append(expression) + return self + } + + public func insert(_ condition: String?, _ map: [String: String]) -> MergeIntoWriter { + let expression = buildMergeAction(ActionType.insert, condition, map) + self.mergeIntoTableCommand.notMatchedActions.append(expression) + return self + } + + public func updateAll(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter { + appendUpdateDeleteAction(buildMergeAction(ActionType.updateStar, condition), notMatchedBySource) + } + + public func update(_ condition: String?, _ map: [String: String], _ notMatchedBySource: Bool) + -> MergeIntoWriter + { + appendUpdateDeleteAction(buildMergeAction(ActionType.update, condition), notMatchedBySource) + } + + public func delete(_ condition: String?, _ notMatchedBySource: Bool) -> MergeIntoWriter { + appendUpdateDeleteAction(buildMergeAction(ActionType.delete, condition), notMatchedBySource) + } + + private func appendUpdateDeleteAction( + _ action: Spark_Connect_Expression, + _ notMatchedBySource: Bool + ) -> MergeIntoWriter { + if notMatchedBySource { + self.mergeIntoTableCommand.notMatchedBySourceActions.append(action) + } else { + self.mergeIntoTableCommand.matchActions.append(action) + } + return self + } + + private func buildMergeAction( + _ actionType: ActionType, + _ condition: String?, + _ assignments: [String: String] = [:] + ) -> Spark_Connect_Expression { + var mergeAction = Spark_Connect_MergeAction() + mergeAction.actionType = actionType + if let condition { + var expression = Spark_Connect_Expression() + expression.expressionString = condition.toExpressionString + mergeAction.condition = expression + } + mergeAction.assignments = assignments.map { key, value in + var keyExpr = Spark_Connect_Expression() + var valueExpr = Spark_Connect_Expression() + + keyExpr.expressionString = key.toExpressionString + valueExpr.expressionString = value.toExpressionString + + var assignment = MergeAction.Assignment() + assignment.key = keyExpr + assignment.value = valueExpr + return assignment + } + var expression = Spark_Connect_Expression() + expression.mergeAction = mergeAction + return expression + } +} diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 537ea42..b061f32 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -16,6 +16,7 @@ // specific language governing permissions and limitations // under the License. +typealias ActionType = Spark_Connect_MergeAction.ActionType typealias Aggregate = Spark_Connect_Aggregate typealias AnalyzePlanRequest = Spark_Connect_AnalyzePlanRequest typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse @@ -38,6 +39,8 @@ typealias KeyValue = Spark_Connect_KeyValue typealias LateralJoin = Spark_Connect_LateralJoin typealias Limit = Spark_Connect_Limit typealias MapType = Spark_Connect_DataType.Map +typealias MergeAction = Spark_Connect_MergeAction +typealias MergeIntoTableCommand = Spark_Connect_MergeIntoTableCommand typealias NamedTable = Spark_Connect_Read.NamedTable typealias OneOf_Analyze = AnalyzePlanRequest.OneOf_Analyze typealias OneOf_CatType = Spark_Connect_Catalog.OneOf_CatType diff --git a/Tests/SparkConnectTests/IcebergTests.swift b/Tests/SparkConnectTests/IcebergTests.swift index 0483f5c..94c6a8a 100644 --- a/Tests/SparkConnectTests/IcebergTests.swift +++ b/Tests/SparkConnectTests/IcebergTests.swift @@ -83,8 +83,31 @@ struct IcebergTests { try await spark.table(t1).writeTo(t2).overwrite("id = 1") #expect(try await spark.table(t2).count() == 3) - }) + try await spark.sql( + """ + MERGE INTO \(t2) t + USING (SELECT * + FROM VALUES + (1, 'delete', null), + (2, 'update', 'updated'), + (4, null, 'new') AS T(id, op, data)) s + ON t.id = s.id + WHEN MATCHED AND s.op = 'delete' THEN DELETE + WHEN MATCHED AND s.op = 'update' THEN UPDATE SET t.data = s.data + WHEN NOT MATCHED THEN INSERT * + WHEN NOT MATCHED BY SOURCE THEN UPDATE SET data = 'invalid' + """ + ).count() + #if !os(Linux) + let expected = [ + Row(2, "updated"), + Row(3, "invalid"), + Row(4, "new"), + ] + #expect(try await spark.table(t2).collect() == expected) + #endif + }) await spark.stop() } diff --git a/Tests/SparkConnectTests/MergeIntoWriterTests.swift b/Tests/SparkConnectTests/MergeIntoWriterTests.swift new file mode 100644 index 0000000..d7d693d --- /dev/null +++ b/Tests/SparkConnectTests/MergeIntoWriterTests.swift @@ -0,0 +1,75 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +import Foundation +import SparkConnect +import Testing + +/// A test suite for `MergeIntoWriter` +/// Since this requires Apache Spark 4 with Iceberg support (SPARK-48794), this suite only tests syntaxes. +@Suite(.serialized) +struct MergeIntoWriterTests { + @Test + func whenMatched() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + let mergeInto = try await spark.range(1).mergeInto(tableName, "true") + try await #require(throws: Error.self) { + try await mergeInto.whenMatched().delete().merge() + } + try await #require(throws: Error.self) { + try await mergeInto.whenMatched("true").delete().merge() + } + }) + await spark.stop() + } + + @Test + func whenNotMatched() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + let mergeInto = try await spark.range(1).mergeInto(tableName, "true") + try await #require(throws: Error.self) { + try await mergeInto.whenNotMatched().insertAll().merge() + } + try await #require(throws: Error.self) { + try await mergeInto.whenNotMatched("true").insertAll().merge() + } + }) + await spark.stop() + } + + @Test + func whenNotMatchedBySource() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + let mergeInto = try await spark.range(1).mergeInto(tableName, "true") + try await #require(throws: Error.self) { + try await mergeInto.whenNotMatchedBySource().delete().merge() + } + try await #require(throws: Error.self) { + try await mergeInto.whenNotMatchedBySource("true").delete().merge() + } + }) + await spark.stop() + } +} From 334bf5e14339cae0287588f3c453116bf84c3e13 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 May 2025 10:55:30 -0700 Subject: [PATCH 2/3] Add comments --- Sources/SparkConnect/MergeIntoWriter.swift | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/Sources/SparkConnect/MergeIntoWriter.swift b/Sources/SparkConnect/MergeIntoWriter.swift index 4d825cd..923f5eb 100644 --- a/Sources/SparkConnect/MergeIntoWriter.swift +++ b/Sources/SparkConnect/MergeIntoWriter.swift @@ -75,6 +75,8 @@ public struct WhenNotMatched: Sendable { } } +/// A struct for defining actions to be performed when there is no match by source during a merge +/// operation in a ``MergeIntoWriter``. public struct WhenNotMatchedBySource: Sendable { let mergeIntoWriter: MergeIntoWriter let condition: String? @@ -84,14 +86,24 @@ public struct WhenNotMatchedBySource: Sendable { self.condition = condition } + /// Specifies an action to update all non-matched rows in the target ``DataFrame`` + /// when not matched by the source. + /// - Returns: A ``MergeIntoWriter`` instance. public func updateAll() async -> MergeIntoWriter { await mergeIntoWriter.updateAll(condition, true) } + /// Specifies an action to update non-matched rows in the target ``DataFrame`` + /// with the provided column assignments when not matched by the source. + /// - Parameter map: A dictionary from column names to expressions representing the updates to be applied + /// - Returns: A ``MergeIntoWriter`` instance. public func update(map: [String: String]) async -> MergeIntoWriter { await mergeIntoWriter.update(condition, map, true) } + /// Specifies an action to delete non-matched rows from the target ``DataFrame`` + /// when not matched by the source. + /// - Returns: A ``MergeIntoWriter`` instance. public func delete() async -> MergeIntoWriter { await mergeIntoWriter.delete(condition, true) } From 11e03a61a5f17183610b9c6bae1d48f48c376235 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 15 May 2025 12:50:54 -0700 Subject: [PATCH 3/3] fix Swift 6.0 compilation --- Sources/SparkConnect/DataFrame.swift | 2 +- Sources/SparkConnect/MergeIntoWriter.swift | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 34cceb7..f5acc02 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -1401,7 +1401,7 @@ public actor DataFrame: Sendable { /// - condition: A condition expression. /// - Returns: A ``MergeIntoWriter`` instance. public func mergeInto(_ table: String, _ condition: String) async -> MergeIntoWriter { - return await MergeIntoWriter(table, self, condition) + return MergeIntoWriter(table, self, condition) } /// Returns a ``DataStreamWriter`` that can be used to write streaming data. diff --git a/Sources/SparkConnect/MergeIntoWriter.swift b/Sources/SparkConnect/MergeIntoWriter.swift index 923f5eb..de15892 100644 --- a/Sources/SparkConnect/MergeIntoWriter.swift +++ b/Sources/SparkConnect/MergeIntoWriter.swift @@ -122,14 +122,13 @@ public actor MergeIntoWriter { var mergeIntoTableCommand = MergeIntoTableCommand() - init(_ table: String, _ df: DataFrame, _ condition: String) async { + init(_ table: String, _ df: DataFrame, _ condition: String) { self.table = table self.df = df self.condition = condition - mergeIntoTableCommand.targetTableName = table - mergeIntoTableCommand.sourceTablePlan = await (df.getPlan() as! Plan).root - mergeIntoTableCommand.mergeCondition.expressionString = condition.toExpressionString + self.mergeIntoTableCommand.targetTableName = table + self.mergeIntoTableCommand.mergeCondition.expressionString = condition.toExpressionString } public var schemaEvolutionEnabled: Bool { @@ -150,7 +149,7 @@ public actor MergeIntoWriter { } /// Initialize a `WhenMatched` action with a condition. - /// - Parameter condition: <#condition description#> + /// - Parameter condition: The condition to be evaluated for the action. /// - Returns: A `WhenMatched` instance configured with the specified condition. public func whenMatched(_ condition: String) -> WhenMatched { WhenMatched(self, condition) @@ -190,6 +189,7 @@ public actor MergeIntoWriter { { throw SparkConnectError.InvalidArgumentException } + self.mergeIntoTableCommand.sourceTablePlan = await (self.df.getPlan() as! Plan).root self.mergeIntoTableCommand.withSchemaEvolution = self.schemaEvolution var command = Spark_Connect_Command()