Skip to content

[SPARK-51504] Support select/limit/sort/orderBy/isEmpty for DataFrame #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public actor DataFrame: Sendable {
/// - Parameters:
/// - spark: A ``SparkSession`` instance to use.
/// - plan: A plan to execute.
init(spark: SparkSession, plan: Plan) {
  // Stores the session and logical plan; no RPC happens until an action is called.
  self.spark = spark
  self.plan = plan
}
Expand Down Expand Up @@ -192,4 +192,38 @@ public actor DataFrame: Sendable {
print(table.render())
}
}

/// Projects a set of expressions and returns a new ``DataFrame``.
/// - Parameter cols: Column names to project.
/// - Returns: A ``DataFrame`` containing only the requested columns.
public func select(_ cols: String...) -> DataFrame {
  let projected = SparkConnectClient.getProject(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: projected)
}

/// Return a new ``DataFrame`` sorted by the specified column(s).
/// - Parameter cols: Column names to sort by.
/// - Returns: A sorted ``DataFrame``.
public func sort(_ cols: String...) -> DataFrame {
  let sorted = SparkConnectClient.getSort(self.plan.root, cols)
  return DataFrame(spark: self.spark, plan: sorted)
}

/// Return a new ``DataFrame`` sorted by the specified column(s).
/// This is an alias of the ``sort(_:)`` function.
/// - Parameter cols: Column names to sort by.
/// - Returns: A sorted ``DataFrame``.
public func orderBy(_ cols: String...) -> DataFrame {
  return DataFrame(spark: self.spark, plan: SparkConnectClient.getSort(self.plan.root, cols))
}

/// Limits the result count to the number specified.
/// - Parameter n: Number of records to return. Will return this number of records or
///   all records if the ``DataFrame`` contains less than this number of records.
/// - Returns: A subset of the records.
public func limit(_ n: Int32) -> DataFrame {
  let limited = SparkConnectClient.getLimit(self.plan.root, n)
  return DataFrame(spark: self.spark, plan: limited)
}

/// Checks if the ``DataFrame`` is empty and returns a boolean value.
/// - Returns: `true` if the ``DataFrame`` is empty, `false` otherwise.
public func isEmpty() async throws -> Bool {
  // Select no columns and fetch at most one row; counting that is cheaper
  // than counting the whole DataFrame.
  return try await select().limit(1).count() == 0
}
}
6 changes: 6 additions & 0 deletions Sources/SparkConnect/Extension.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ extension String {
keyValue.key = self
return keyValue
}

/// An ``UnresolvedAttribute`` expression whose unparsed identifier is this string.
var toUnresolvedAttribute: UnresolvedAttribute {
  var attribute = UnresolvedAttribute()
  attribute.unparsedIdentifier = self
  return attribute
}
}

extension [String: String] {
Expand Down
43 changes: 43 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,47 @@ public actor SparkConnectClient {
request.analyze = .schema(schema)
return request
}

/// Build a `Plan` whose root is a `Project` of `cols` over `child`.
/// - Parameters:
///   - child: The input relation to project from.
///   - cols: Column names, each turned into an unresolved attribute expression.
/// - Returns: A `Plan` rooted at the projection.
static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
  var project = Project()
  project.input = child
  project.expressions = cols.map { name in
    var expr = Spark_Connect_Expression()
    expr.exprType = .unresolvedAttribute(name.toUnresolvedAttribute)
    return expr
  }
  var relation = Relation()
  relation.project = project
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}

/// Build a `Plan` whose root is a `Sort` of `child` by `cols`.
/// - Parameters:
///   - child: The input relation to sort.
///   - cols: Column names, each wrapped in a `SortOrder` expression.
/// - Returns: A `Plan` rooted at the sort.
static func getSort(_ child: Relation, _ cols: [String]) -> Plan {
  var sort = Sort()
  sort.input = child
  sort.order = cols.map { name in
    var order = Spark_Connect_Expression.SortOrder()
    order.child.exprType = .unresolvedAttribute(name.toUnresolvedAttribute)
    return order
  }
  var relation = Relation()
  relation.sort = sort
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}

/// Build a `Plan` whose root is a `Limit` of `child` to at most `n` rows.
/// - Parameters:
///   - child: The input relation to limit.
///   - n: The maximum number of rows.
/// - Returns: A `Plan` rooted at the limit.
static func getLimit(_ child: Relation, _ n: Int32) -> Plan {
  var limitNode = Limit()
  limitNode.input = child
  limitNode.limit = n
  var relation = Relation()
  relation.limit = limitNode
  var plan = Plan()
  plan.opType = .root(relation)
  return plan
}
}
4 changes: 4 additions & 0 deletions Sources/SparkConnect/TypeAliases.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest
// Keep these aliases in alphabetical order.
typealias DataType = Spark_Connect_DataType
typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
typealias KeyValue = Spark_Connect_KeyValue
typealias Limit = Spark_Connect_Limit
typealias Plan = Spark_Connect_Plan
typealias Project = Spark_Connect_Project
typealias Range = Spark_Connect_Range
typealias Relation = Spark_Connect_Relation
typealias Sort = Spark_Connect_Sort
typealias SparkConnectService = Spark_Connect_SparkConnectService
typealias UnresolvedAttribute = Spark_Connect_Expression.UnresolvedAttribute
typealias UserContext = Spark_Connect_UserContext
70 changes: 70 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,76 @@ struct DataFrameTests {
await spark.stop()
}

@Test
func selectNone() async throws {
  // Selecting no columns yields an empty struct schema.
  let expected = #"{"struct":{}}"#
  let spark = try await SparkSession.builder.getOrCreate()
  #expect(try await spark.range(1).select().schema() == expected)
  await spark.stop()
}

@Test
func select() async throws {
  // Selecting `id` from a range keeps only that long-typed column.
  let expected = #"{"struct":{"fields":[{"name":"id","dataType":{"long":{}}}]}}"#
  let spark = try await SparkSession.builder.getOrCreate()
  #expect(try await spark.range(1).select("id").schema() == expected)
  await spark.stop()
}

@Test
func selectMultipleColumns() async throws {
  // Selecting columns in reverse order must preserve the requested order.
  let expected =
    #"{"struct":{"fields":[{"name":"col2","dataType":{"integer":{}}},{"name":"col1","dataType":{"integer":{}}}]}}"#
  let spark = try await SparkSession.builder.getOrCreate()
  #expect(try await spark.sql("SELECT * FROM VALUES (1, 2)").select("col2", "col1").schema() == expected)
  await spark.stop()
}

@Test
func selectInvalidColumn() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // Resolving a non-existent column name must surface an error.
  try await #require(throws: Error.self) {
    _ = try await spark.range(1).select("invalid").schema()
  }
  await spark.stop()
}

@Test
func limit() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // `limit(n)` returns exactly n records when the source has more.
  #expect(try await spark.range(10).limit(2).count() == 2)
  #expect(try await spark.range(10).limit(1).count() == 1)
  // A zero limit yields an empty result.
  #expect(try await spark.range(10).limit(0).count() == 0)
  await spark.stop()
}

@Test
func isEmpty() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // An empty range is empty; a one-row range is not.
  #expect(try await spark.range(0).isEmpty() == true)
  #expect(try await spark.range(1).isEmpty() == false)
  await spark.stop()
}

@Test
func sort() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // Sorting must not change the number of rows.
  let count = try await spark.range(10).sort("id").count()
  #expect(count == 10)
  await spark.stop()
}

@Test
func orderBy() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  // `orderBy` (alias of `sort`) must not change the number of rows.
  let count = try await spark.range(10).orderBy("id").count()
  #expect(count == 10)
  await spark.stop()
}

@Test
func table() async throws {
let spark = try await SparkSession.builder.getOrCreate()
Expand Down
Loading