diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 44313f5..7ab4fc6 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -36,7 +36,7 @@ public actor DataFrame: Sendable { /// - Parameters: /// - spark: A ``SparkSession`` instance to use. /// - plan: A plan to execute. - init(spark: SparkSession, plan: Plan) async throws { + init(spark: SparkSession, plan: Plan) { self.spark = spark self.plan = plan } @@ -192,4 +192,38 @@ public actor DataFrame: Sendable { print(table.render()) } } + + /// Projects a set of expressions and returns a new ``DataFrame``. + /// - Parameter cols: Column names. + /// - Returns: A ``DataFrame`` with a subset of columns. + public func select(_ cols: String...) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) + } + + /// Returns a new ``DataFrame`` sorted by the specified column(s). + /// - Parameter cols: Column names. + /// - Returns: A sorted ``DataFrame``. + public func sort(_ cols: String...) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) + } + + /// Returns a new ``DataFrame`` sorted by the specified column(s). + /// - Parameter cols: Column names. + /// - Returns: A sorted ``DataFrame``. + public func orderBy(_ cols: String...) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols)) + } + + /// Limits the result count to the number specified. + /// - Parameter n: Number of records to return. Will return this number of records or all records if the ``DataFrame`` contains fewer than this number of records. + /// - Returns: A subset of the records. + public func limit(_ n: Int32) -> DataFrame { + return DataFrame(spark: self.spark, plan: SparkConnectClient.getLimit(self.plan.root, n)) + } + + /// Checks if the ``DataFrame`` is empty and returns a boolean value.
+ /// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise. + public func isEmpty() async throws -> Bool { + return try await select().limit(1).count() == 0 + } } diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 4e3c627..3b2f839 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -45,6 +45,12 @@ extension String { keyValue.key = self return keyValue } + + var toUnresolvedAttribute: UnresolvedAttribute { + var attribute = UnresolvedAttribute() + attribute.unparsedIdentifier = self + return attribute + } } extension [String: String] { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 92cf56e..0ab1e0e 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -252,4 +252,47 @@ public actor SparkConnectClient { request.analyze = .schema(schema) return request } + + static func getProject(_ child: Relation, _ cols: [String]) -> Plan { + var project = Project() + project.input = child + let expressions: [Spark_Connect_Expression] = cols.map { + var expression = Spark_Connect_Expression() + expression.exprType = .unresolvedAttribute($0.toUnresolvedAttribute) + return expression + } + project.expressions = expressions + var relation = Relation() + relation.project = project + var plan = Plan() + plan.opType = .root(relation) + return plan + } + + static func getSort(_ child: Relation, _ cols: [String]) -> Plan { + var sort = Sort() + sort.input = child + let expressions: [Spark_Connect_Expression.SortOrder] = cols.map { + var expression = Spark_Connect_Expression.SortOrder() + expression.child.exprType = .unresolvedAttribute($0.toUnresolvedAttribute) + return expression + } + sort.order = expressions + var relation = Relation() + relation.sort = sort + var plan = Plan() + plan.opType = .root(relation) + return plan + } + + static func getLimit(_ child: 
Relation, _ n: Int32) -> Plan { + var limit = Limit() + limit.input = child + limit.limit = n + var relation = Relation() + relation.limit = limit + var plan = Plan() + plan.opType = .root(relation) + return plan + } } diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index ad0e898..8e68723 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -22,8 +22,12 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest typealias DataType = Spark_Connect_DataType typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest typealias Plan = Spark_Connect_Plan +typealias Project = Spark_Connect_Project typealias KeyValue = Spark_Connect_KeyValue +typealias Limit = Spark_Connect_Limit typealias Range = Spark_Connect_Range typealias Relation = Spark_Connect_Relation typealias SparkConnectService = Spark_Connect_SparkConnectService +typealias Sort = Spark_Connect_Sort typealias UserContext = Spark_Connect_UserContext +typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index e3f90a6..2b57d2d 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -68,6 +68,77 @@ struct DataFrameTests { await spark.stop() } + @Test + func selectNone() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let emptySchema = try await spark.range(1).select().schema() + #expect(emptySchema == #"{"struct":{}}"#) + await spark.stop() + } + + @Test + func select() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let schema = try await spark.range(1).select("id").schema() + #expect( + schema + == #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"# + ) + await spark.stop() + } + + @Test + func selectMultipleColumns() async throws { + let spark = try await 
SparkSession.builder.getOrCreate() + let schema = try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema() + #expect( + schema + == #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"# + ) + await spark.stop() + } + + @Test + func selectInvalidColumn() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await #require(throws: Error.self) { + let _ = try await spark.range(1).select("invalid").schema() + } + await spark.stop() + } + + @Test + func limit() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).limit(0).count() == 0) + #expect(try await spark.range(10).limit(1).count() == 1) + #expect(try await spark.range(10).limit(2).count() == 2) + #expect(try await spark.range(10).limit(15).count() == 10) + await spark.stop() + } + + @Test + func isEmpty() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).isEmpty()) + #expect(!(try await spark.range(1).isEmpty())) + await spark.stop() + } + + @Test + func sort() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).sort("id").count() == 10) + await spark.stop() + } + + @Test + func orderBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).orderBy("id").count() == 10) + await spark.stop() + } + @Test func table() async throws { let spark = try await SparkSession.builder.getOrCreate()