From b9ea9283ecdf7a75edfa77f753a71ecfa80ed53e Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Thu, 15 May 2025 15:17:06 -0700
Subject: [PATCH] [SPARK-52167] Support `hint` for `DataFrame`

---
 Sources/SparkConnect/DataFrame.swift          | 12 +++++++
 Sources/SparkConnect/SparkConnectClient.swift | 35 +++++++++++++++++++
 Tests/SparkConnectTests/DataFrameTests.swift  | 22 ++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index f5acc02..a696158 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -115,6 +115,7 @@ import Synchronization
 /// - ``melt(_:_:_:_:)``
 /// - ``transpose()``
 /// - ``transpose(_:)``
+/// - ``hint(_:_:)``
 ///
 /// ### Join Operations
 /// - ``join(_:)``
@@ -1349,6 +1350,17 @@ public actor DataFrame: Sendable {
     return GroupedData(self, GroupType.cube, cols)
   }
 
+  /// Specifies a hint on the current ``DataFrame``.
+  /// - Parameters:
+  ///   - name: The hint name.
+  ///   - parameters: The parameters of the hint.
+  /// - Returns: A ``DataFrame`` with the given hint applied.
+  @discardableResult
+  public func hint(_ name: String, _ parameters: Sendable...) -> DataFrame {
+    let plan = SparkConnectClient.getHint(self.plan.root, name, parameters)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
   /// Creates a local temporary view using the given name. The lifetime of this temporary view is
   /// tied to the `SparkSession` that was used to create this ``DataFrame``.
   /// - Parameter viewName: A view name.
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 5c48c55..743a41a 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -950,6 +950,41 @@ public actor SparkConnectClient {
     return plan
   }
 
+  static func getHint(_ child: Relation, _ name: String, _ parameters: [Sendable]) -> Plan {
+    var hint = Spark_Connect_Hint()
+    hint.input = child
+    hint.name = name
+    hint.parameters = parameters.map {
+      var literal = ExpressionLiteral()
+      switch $0 {
+      case let value as Bool:
+        literal.boolean = value
+      case let value as Int8:
+        literal.byte = Int32(value)
+      case let value as Int16:
+        literal.short = Int32(value)
+      case let value as Int32:
+        literal.integer = value
+      case let value as Int64:  // Spark rejects Int64 hint parameters, so narrow to Int32.
+        literal.integer = Int32(value)
+      case let value as Int:
+        literal.integer = Int32(value)
+      case let value as String:
+        literal.string = value
+      default:
+        literal.string = $0 as! String
String + } + var expr = Spark_Connect_Expression() + expr.literal = literal + return expr + } + var relation = Relation() + relation.hint = hint + var plan = Plan() + plan.opType = .root(relation) + return plan + } + func createTempView( _ child: Relation, _ viewName: String, replace: Bool, isGlobal: Bool ) async throws { diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 7b6c4c2..4ee6ae9 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -852,4 +852,26 @@ struct DataFrameTests { await spark.stop() } + + @Test + func hint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df1 = try await spark.range(1) + let df2 = try await spark.range(1) + + try await df1.join(df2.hint("broadcast")).count() + try await df1.join(df2.hint("coalesce", 10)).count() + try await df1.join(df2.hint("rebalance", 10)).count() + try await df1.join(df2.hint("rebalance", 10, "id")).count() + try await df1.join(df2.hint("repartition", 10)).count() + try await df1.join(df2.hint("repartition", 10, "id")).count() + try await df1.join(df2.hint("repartition", "id")).count() + try await df1.join(df2.hint("repartition_by_range")).count() + try await df1.join(df2.hint("merge")).count() + try await df1.join(df2.hint("shuffle_hash")).count() + try await df1.join(df2.hint("shuffle_replicate_nl")).count() + try await df1.join(df2.hint("shuffle_merge")).count() + + await spark.stop() + } }