From ee7fca36eeb5bcf0ecb9d1467a767dd6ab80bfee Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 10:45:18 -0700 Subject: [PATCH 1/7] [SPARK-51504] Support `DataFrame.select` --- Sources/SparkConnect/DataFrame.swift | 7 +++- Sources/SparkConnect/Extension.swift | 6 +++ Sources/SparkConnect/SparkConnectClient.swift | 16 ++++++++ Sources/SparkConnect/TypeAliases.swift | 2 + Tests/SparkConnectTests/DataFrameTests.swift | 39 +++++++++++++++++++ 5 files changed, 69 insertions(+), 1 deletion(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 44313f5..77e6322 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -36,7 +36,7 @@ public actor DataFrame: Sendable { /// - Parameters: /// - spark: A ``SparkSession`` instance to use. /// - plan: A plan to execute. - init(spark: SparkSession, plan: Plan) async throws { + init(spark: SparkSession, plan: Plan) { self.spark = spark self.plan = plan } @@ -192,4 +192,9 @@ public actor DataFrame: Sendable { print(table.render()) } } + + public func select(_ cols: String...) 
-> DataFrame { + let plan = SparkConnectClient.getProject(self.plan.root, cols) + return DataFrame(spark: self.spark, plan: plan) + } } diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 4e3c627..3b2f839 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -45,6 +45,12 @@ extension String { keyValue.key = self return keyValue } + + var toUnresolvedAttribute: UnresolvedAttribute { + var attribute = UnresolvedAttribute() + attribute.unparsedIdentifier = self + return attribute + } } extension [String: String] { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 92cf56e..e6cdbf2 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -252,4 +252,20 @@ public actor SparkConnectClient { request.analyze = .schema(schema) return request } + + static func getProject(_ child: Relation, _ cols: [String]) -> Plan { + var project = Project() + project.input = child + let expressions: [Spark_Connect_Expression] = cols.map { + var expression = Spark_Connect_Expression() + expression.exprType = .unresolvedAttribute($0.toUnresolvedAttribute) + return expression + } + project.expressions = expressions + var relation = Relation() + relation.project = project + var plan = Plan() + plan.opType = .root(relation) + return plan + } } diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index ad0e898..743f545 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -22,8 +22,10 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest typealias DataType = Spark_Connect_DataType typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest typealias Plan = Spark_Connect_Plan +typealias Project = Spark_Connect_Project typealias KeyValue = Spark_Connect_KeyValue typealias Range = Spark_Connect_Range 
typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService typealias UserContext = Spark_Connect_UserContext +typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index e3f90a6..83f0bbd 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -68,6 +68,45 @@ struct DataFrameTests { await spark.stop() } + @Test + func selectNone() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let emptySchema = try await spark.range(1).select().schema() + #expect(emptySchema == #"{"struct":{}}"#) + await spark.stop() + } + + @Test + func select() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let schema = try await spark.range(1).select("id").schema() + #expect( + schema + == #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"# + ) + await spark.stop() + } + + @Test + func selectMultipleColumns() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema() + #expect( + schema + == #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"# + ) + await spark.stop() + } + + @Test + func selectInvalidColumn() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await #require(throws: Error.self) { + let _ = try await spark.range(1).select("invalid").schema() + } + await spark.stop() + } + @Test func table() async throws { let spark = try await SparkSession.builder.getOrCreate() From 70f728078764c1b015c535d08ac744f7713f2da9 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 13:39:21 -0700 Subject: [PATCH 2/7] Add limit --- Sources/SparkConnect/DataFrame.swift | 5 +++++ 
Sources/SparkConnect/SparkConnectClient.swift | 11 +++++++++++ Sources/SparkConnect/TypeAliases.swift | 1 + Tests/SparkConnectTests/DataFrameTests.swift | 9 +++++++++ 4 files changed, 26 insertions(+) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 77e6322..a2ac3ed 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -197,4 +197,9 @@ public actor DataFrame: Sendable { let plan = SparkConnectClient.getProject(self.plan.root, cols) return DataFrame(spark: self.spark, plan: plan) } + + public func limit(_ n: Int32) -> DataFrame { + let plan = SparkConnectClient.getLimit(self.plan.root, n) + return DataFrame(spark: self.spark, plan: plan) + } } diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index e6cdbf2..bc87259 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -268,4 +268,15 @@ public actor SparkConnectClient { plan.opType = .root(relation) return plan } + + static func getLimit(_ child: Relation, _ n: Int32) -> Plan { + var limit = Limit() + limit.input = child + limit.limit = n + var relation = Relation() + relation.limit = limit + var plan = Plan() + plan.opType = .root(relation) + return plan + } } diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 743f545..5beeba8 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -24,6 +24,7 @@ typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest typealias Plan = Spark_Connect_Plan typealias Project = Spark_Connect_Project typealias KeyValue = Spark_Connect_KeyValue +typealias Limit = Spark_Connect_Limit typealias Range = Spark_Connect_Range typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService diff --git a/Tests/SparkConnectTests/DataFrameTests.swift 
b/Tests/SparkConnectTests/DataFrameTests.swift index 83f0bbd..7a10253 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -107,6 +107,15 @@ struct DataFrameTests { await spark.stop() } + @Test + func limit() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).limit(0).count() == 0) + #expect(try await spark.range(10).limit(1).count() == 1) + #expect(try await spark.range(10).limit(2).count() == 2) + await spark.stop() + } + @Test func table() async throws { let spark = try await SparkSession.builder.getOrCreate() From 8f10457609069d027843a3c90bed0f72107de8aa Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 14:16:07 -0700 Subject: [PATCH 3/7] Add sort, orderBy --- Sources/SparkConnect/DataFrame.swift | 14 ++++++++++---- Sources/SparkConnect/SparkConnectClient.swift | 16 ++++++++++++++++ Sources/SparkConnect/TypeAliases.swift | 1 + Tests/SparkConnectTests/DataFrameTests.swift | 14 ++++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index a2ac3ed..0251c42 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -194,12 +194,18 @@ public actor DataFrame: Sendable { } public func select(_ cols: String...) -> DataFrame { - let plan = SparkConnectClient.getProject(self.plan.root, cols) - return DataFrame(spark: self.spark, plan: plan) + return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) + } + + public func sort(_ cols: String...) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) + } + + public func orderBy(_ cols: String...) 
-> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) } public func limit(_ n: Int32) -> DataFrame { - let plan = SparkConnectClient.getLimit(self.plan.root, n) - return DataFrame(spark: self.spark, plan: plan) + return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n)) } } diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index bc87259..0ab1e0e 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -269,6 +269,22 @@ public actor SparkConnectClient { return plan } + static func getSort(_ child: Relation, _ cols: [String]) -> Plan { + var sort = Sort() + sort.input = child + let expressions: [Spark_Connect_Expression.SortOrder] = cols.map { + var expression = Spark_Connect_Expression.SortOrder() + expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute) + return expression + } + sort.order = expressions + var relation = Relation() + relation.sort = sort + var plan = Plan() + plan.opType = .root(relation) + return plan + } + static func getLimit(_ child: Relation, _ n: Int32) -> Plan { var limit = Limit() limit.input = child diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index 5beeba8..8e68723 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -28,5 +28,6 @@ typealias Limit = Spark_Connect_Limit typealias Range = Spark_Connect_Range typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService +typealias Sort = Spark_Connect_Sort typealias UserContext = Spark_Connect_UserContext typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 7a10253..5a35f8d 100644 --- 
a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -116,6 +116,20 @@ struct DataFrameTests { await spark.stop() } + @Test + func sort() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).sort("id").count() == 10) + await spark.stop() + } + + @Test + func orderBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).orderBy("id").count() == 10) + await spark.stop() + } + @Test func table() async throws { let spark = try await SparkSession.builder.getOrCreate() From 08961869d47956627c9b543e90595b97bfe1facd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 14:51:35 -0700 Subject: [PATCH 4/7] Add isEmpty --- Sources/SparkConnect/DataFrame.swift | 18 ++++++++++++++++++ Tests/SparkConnectTests/DataFrameTests.swift | 8 ++++++++ 2 files changed, 26 insertions(+) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 0251c42..c3995db 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -193,19 +193,37 @@ public actor DataFrame: Sendable { } } + /// Projects a set of expressions and returns a new ``DataFrame``. + /// - Parameter cols: Column names + /// - Returns: A ``DataFrame`` with a subset of columns. public func select(_ cols: String...) -> DataFrame { return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) } + /// Return a new ``DataFrame`` sorted by the specified column(s). + /// - Parameter cols: Column names. + /// - Returns: A sorted ``DataFrame`` public func sort(_ cols: String...) -> DataFrame { return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) } + /// <#Description#> + /// - Parameter cols: <#cols description#> + /// - Returns: <#description#> public func orderBy(_ cols: String...) 
-> DataFrame { return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) } + /// Limits the result count to the number specified. + /// - Parameter n: Number of records to return. Will return this number of records or all records if the ``DataFrame`` contains fewer than this number of records. + /// - Returns: A subset of the records public func limit(_ n: Int32) -> DataFrame { return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n)) } + + /// Chec if the ``DataFrame`` is empty and returns a boolean value. + /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise. + public func isEmpty() async throws -> Bool { + return try await select().limit(1).count() == 0 + } } diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 5a35f8d..95d7fde 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -116,6 +116,14 @@ struct DataFrameTests { await spark.stop() } + @Test + func isEmpty() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).isEmpty()) + #expect(!(try await spark.range(1).isEmpty())) + await spark.stop() + } + @Test func sort() async throws { let spark = try await SparkSession.builder.getOrCreate() From 8ba8dc2295819d179581f641a57e4f82efc0700e Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 15:15:07 -0700 Subject: [PATCH 5/7] Update Sources/SparkConnect/DataFrame.swift Co-authored-by: Liang-Chi Hsieh --- Sources/SparkConnect/DataFrame.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index c3995db..c76c91a 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -221,7 +221,7 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: 
SparkConnectClient.getLimit(self.plan.root, n)) } - /// Chec if the ``DataFrame`` is empty and returns a boolean value. + /// Checks if the ``DataFrame`` is empty and returns a boolean value. /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise. public func isEmpty() async throws -> Bool { return try await select().limit(1).count() == 0 From 453eccf814d80e419d87c3e57b682b54362743c3 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 15:15:22 -0700 Subject: [PATCH 6/7] Update Tests/SparkConnectTests/DataFrameTests.swift Co-authored-by: Liang-Chi Hsieh --- Tests/SparkConnectTests/DataFrameTests.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 95d7fde..2b57d2d 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -113,6 +113,7 @@ struct DataFrameTests { #expect(try await spark.range(10).limit(0).count() == 0) #expect(try await spark.range(10).limit(1).count() == 1) #expect(try await spark.range(10).limit(2).count() == 2) + #expect(try await spark.range(10).limit(15).count() == 10) await spark.stop() } From 68c02117fcbe9ab98ec0352aebc504558bb2b622 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Mar 2025 15:16:55 -0700 Subject: [PATCH 7/7] Address comment --- Sources/SparkConnect/DataFrame.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index c76c91a..7ab4fc6 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -207,9 +207,9 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) } - /// <#Description#> - /// - Parameter cols: <#cols description#> - /// - Returns: <#description#> + /// Return a new ``DataFrame`` sorted by the specified column(s). 
+ /// - Parameter cols: Column names. + /// - Returns: A sorted ``DataFrame`` public func orderBy(_ cols: String...) -> DataFrame { return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) }