diff --git a/Sources/SparkConnect/ArrowArray.swift b/Sources/SparkConnect/ArrowArray.swift index a767b6e..cc348ed 100644 --- a/Sources/SparkConnect/ArrowArray.swift +++ b/Sources/SparkConnect/ArrowArray.swift @@ -255,12 +255,13 @@ public class Decimal128Array: FixedArray { if self.arrowData.isNull(index) { return nil } - let scale: Int32 = switch self.arrowData.type.id { - case .decimal128(_, let scale): - scale - default: - 18 - } + let scale: Int32 = + switch self.arrowData.type.id { + case .decimal128(_, let scale): + scale + default: + 18 + } let byteOffset = self.arrowData.stride * Int(index) let value = self.arrowData.buffers[1].rawPointer.advanced(by: byteOffset).load( as: UInt64.self) diff --git a/Sources/SparkConnect/ArrowArrayBuilder.swift b/Sources/SparkConnect/ArrowArrayBuilder.swift index 20b3f27..da7074c 100644 --- a/Sources/SparkConnect/ArrowArrayBuilder.swift +++ b/Sources/SparkConnect/ArrowArrayBuilder.swift @@ -122,7 +122,8 @@ public class Time64ArrayBuilder: ArrowArrayBuilder, T } } -public class Decimal128ArrayBuilder: ArrowArrayBuilder, Decimal128Array> { +public class Decimal128ArrayBuilder: ArrowArrayBuilder, Decimal128Array> +{ fileprivate convenience init(precision: Int32, scale: Int32) throws { try self.init(ArrowTypeDecimal128(precision: precision, scale: scale)) } diff --git a/Sources/SparkConnect/ArrowType.swift b/Sources/SparkConnect/ArrowType.swift index 39555f3..a617b3a 100644 --- a/Sources/SparkConnect/ArrowType.swift +++ b/Sources/SparkConnect/ArrowType.swift @@ -294,7 +294,7 @@ public class ArrowType { case .double: return MemoryLayout.stride case .decimal128: - return 16 // Decimal 128 (= 16 * 8) bits + return 16 // Decimal 128 (= 16 * 8) bits case .boolean: return MemoryLayout.stride case .date32: @@ -429,7 +429,7 @@ extension ArrowType.Info: Equatable { case (.timeInfo(let lhsId), .timeInfo(let rhsId)): return lhsId == rhsId case (.complexInfo(let lhsId), .complexInfo(let rhsId)): - return lhsId == rhsId + return lhsId == rhsId default: return false } diff --git a/Sources/SparkConnect/Catalog.swift b/Sources/SparkConnect/Catalog.swift index 4f3c917..c1b23d4 100644 --- a/Sources/SparkConnect/Catalog.swift +++ b/Sources/SparkConnect/Catalog.swift @@ -40,15 +40,13 @@ public struct SparkTable: Sendable, Equatable { public var tableType: String public var isTemporary: Bool public var database: String? { - get { - guard let namespace else { - return nil - } - if namespace.count == 1 { - return namespace[0] - } else { - return nil - } + guard let namespace else { + return nil + } + if namespace.count == 1 { + return namespace[0] + } else { + return nil } } } @@ -173,7 +171,9 @@ public actor Catalog: Sendable { return catalog }) return try await df.collect().map { - try Database(name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2] as? String, locationUri: $0[3] as! String) + try Database( + name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2] as? String, + locationUri: $0[3] as! String) } } @@ -189,7 +189,9 @@ public actor Catalog: Sendable { return catalog }) return try await df.collect().map { - try Database(name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2] as? String, locationUri: $0[3] as! String) + try Database( + name: $0[0] as! String, catalog: $0[1] as? String, description: $0[2] as? String, + locationUri: $0[3] as! String) }.first! 
} diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index db720ba..2f590de 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -494,11 +494,12 @@ public actor DataFrame: Sendable { /// - Parameter cols: Column names /// - Returns: A ``DataFrame`` with subset of columns. public func toDF(_ cols: String...) -> DataFrame { - let df = if cols.isEmpty { - DataFrame(spark: self.spark, plan: self.plan) - } else { - DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) - } + let df = + if cols.isEmpty { + DataFrame(spark: self.spark, plan: self.plan) + } else { + DataFrame(spark: self.spark, plan: SparkConnectClient.getProject(self.plan.root, cols)) + } return df } @@ -507,7 +508,8 @@ public actor DataFrame: Sendable { /// - Returns: A ``DataFrame`` with the given schema. public func to(_ schema: String) async throws -> DataFrame { let dataType = try await sparkSession.client.ddlParse(schema) - return DataFrame(spark: self.spark, plan: SparkConnectClient.getToSchema(self.plan.root, dataType)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getToSchema(self.plan.root, dataType)) } /// Returns the content of the Dataset as a Dataset of JSON strings. @@ -520,7 +522,8 @@ public actor DataFrame: Sendable { /// - Parameter exprs: Expression strings /// - Returns: A ``DataFrame`` with subset of columns. public func selectExpr(_ exprs: String...) -> DataFrame { - return DataFrame(spark: self.spark, plan: SparkConnectClient.getProjectExprs(self.plan.root, exprs)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getProjectExprs(self.plan.root, exprs)) } /// Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column name. @@ -564,7 +567,8 @@ public actor DataFrame: Sendable { /// - Parameter statistics: Statistics names. /// - Returns: A ``DataFrame`` containing specified statistics. public func summary(_ statistics: String...) -> DataFrame { - return DataFrame(spark: self.spark, plan: SparkConnectClient.getSummary(self.plan.root, statistics)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getSummary(self.plan.root, statistics)) } /// Returns a new Dataset with a column renamed. This is a no-op if schema doesn't contain existingName. @@ -583,14 +587,16 @@ public actor DataFrame: Sendable { /// - Returns: A ``DataFrame`` with the renamed columns. public func withColumnRenamed(_ colNames: [String], _ newColNames: [String]) -> DataFrame { let dic = Dictionary(uniqueKeysWithValues: zip(colNames, newColNames)) - return DataFrame(spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, dic)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, dic)) } /// Returns a new Dataset with columns renamed. This is a no-op if schema doesn't contain existingName. /// - Parameter colsMap: A dictionary of existing column name and new column name. /// - Returns: A ``DataFrame`` with the renamed columns. public func withColumnRenamed(_ colsMap: [String: String]) -> DataFrame { - return DataFrame(spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getWithColumnRenamed(self.plan.root, colsMap)) } /// Filters rows using the given condition. 
@@ -611,7 +617,8 @@ public actor DataFrame: Sendable { /// - Parameter conditionExpr: A SQL expression string for filtering /// - Returns: A new DataFrame containing only rows that match the condition public func filter(_ conditionExpr: String) -> DataFrame { - return DataFrame(spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr)) + return DataFrame( + spark: self.spark, plan: SparkConnectClient.getFilter(self.plan.root, conditionExpr)) } /// Filters rows using the given condition (alias for filter). @@ -691,7 +698,9 @@ public actor DataFrame: Sendable { /// - seed: Seed for sampling. /// - Returns: A subset of the records. public func sample(_ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> DataFrame { - return DataFrame(spark: self.spark, plan: SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed)) + return DataFrame( + spark: self.spark, + plan: SparkConnectClient.getSample(self.plan.root, withReplacement, fraction, seed)) } /// Returns a new ``Dataset`` by sampling a fraction of rows, using a random seed. @@ -765,7 +774,7 @@ public actor DataFrame: Sendable { /// - Parameter n: The number of rows. /// - Returns: ``[Row]`` public func tail(_ n: Int32) async throws -> [Row] { - let lastN = DataFrame(spark:spark, plan: SparkConnectClient.getTail(self.plan.root, n)) + let lastN = DataFrame(spark: spark, plan: SparkConnectClient.getTail(self.plan.root, n)) return try await lastN.collect() } @@ -786,7 +795,8 @@ public actor DataFrame: Sendable { public func isStreaming() async throws -> Bool { try await withGPRC { client in let service = Spark_Connect_SparkConnectService.Client(wrapping: client) - let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan)) + let response = try await service.analyzePlan( + spark.client.getIsStreaming(spark.sessionID, plan)) return response.isStreaming.isStreaming } } @@ -850,8 +860,10 @@ public actor DataFrame: Sendable { get async throws { try await withGPRC { client in let service = Spark_Connect_SparkConnectService.Client(wrapping: client) - return try await service - .analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel + return + try await service + .analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel + .storageLevel.toStorageLevel } } } @@ -878,7 +890,7 @@ public actor DataFrame: Sendable { /// Prints the plans (logical and physical) to the console for debugging purposes. /// - Parameter extended: If `false`, prints only the physical plan. 
public func explain(_ extended: Bool) async throws { - if (extended) { + if extended { try await explain("extended") } else { try await explain("simple") @@ -891,7 +903,8 @@ public actor DataFrame: Sendable { public func explain(_ mode: String) async throws { try await withGPRC { client in let service = Spark_Connect_SparkConnectService.Client(wrapping: client) - let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode)) + let response = try await service.analyzePlan( + spark.client.getExplain(spark.sessionID, plan, mode)) print(response.explain.explainString) } } @@ -903,7 +916,8 @@ public actor DataFrame: Sendable { public func inputFiles() async throws -> [String] { try await withGPRC { client in let service = Spark_Connect_SparkConnectService.Client(wrapping: client) - let response = try await service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan)) + let response = try await service.analyzePlan( + spark.client.getInputFiles(spark.sessionID, plan)) return response.inputFiles.files } } @@ -918,7 +932,8 @@ public actor DataFrame: Sendable { public func printSchema(_ level: Int32) async throws { try await withGPRC { client in let service = Spark_Connect_SparkConnectService.Client(wrapping: client) - let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level)) + let response = try await service.analyzePlan( + spark.client.getTreeString(spark.sessionID, plan, level)) print(response.treeString.treeString) } } @@ -964,7 +979,9 @@ public actor DataFrame: Sendable { /// - usingColumn: Column name that exists in both DataFrames /// - joinType: Type of join (default: "inner") /// - Returns: A new DataFrame with the join result - public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame { + public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async + -> DataFrame + { await join(right, [usingColumn], joinType) } @@ -974,7 +991,9 @@ public actor DataFrame: Sendable { /// - usingColumn: Names of the columns to join on. These columns must exist on both sides. /// - joinType: A join type name. /// - Returns: A `DataFrame`. - public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType: String = "inner") async -> DataFrame { + public func join(_ other: DataFrame, _ usingColumns: [String], _ joinType: String = "inner") async + -> DataFrame + { let right = await (other.getPlan() as! Plan).root let plan = SparkConnectClient.getJoin( self.plan.root, @@ -1112,7 +1131,8 @@ public actor DataFrame: Sendable { /// - Returns: A `DataFrame`. public func exceptAll(_ other: DataFrame) async -> DataFrame { let right = await (other.getPlan() as! Plan).root - let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.except, isAll: true) + let plan = SparkConnectClient.getSetOperation( + self.plan.root, right, SetOpType.except, isAll: true) return DataFrame(spark: self.spark, plan: plan) } @@ -1132,7 +1152,8 @@ public actor DataFrame: Sendable { /// - Returns: A `DataFrame`. public func intersectAll(_ other: DataFrame) async -> DataFrame { let right = await (other.getPlan() as! 
Plan).root - let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.intersect, isAll: true) + let plan = SparkConnectClient.getSetOperation( + self.plan.root, right, SetOpType.intersect, isAll: true) return DataFrame(spark: self.spark, plan: plan) } @@ -1144,7 +1165,8 @@ public actor DataFrame: Sendable { /// - Returns: A `DataFrame`. public func union(_ other: DataFrame) async -> DataFrame { let right = await (other.getPlan() as! Plan).root - let plan = SparkConnectClient.getSetOperation(self.plan.root, right, SetOpType.union, isAll: true) + let plan = SparkConnectClient.getSetOperation( + self.plan.root, right, SetOpType.union, isAll: true) return DataFrame(spark: self.spark, plan: plan) } @@ -1164,7 +1186,9 @@ public actor DataFrame: Sendable { /// of this `DataFrame` will be added at the end in the schema of the union result /// - Parameter other: A `DataFrame` to union with. /// - Returns: A `DataFrame`. - public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async -> DataFrame { + public func unionByName(_ other: DataFrame, _ allowMissingColumns: Bool = false) async + -> DataFrame + { let right = await (other.getPlan() as! Plan).root let plan = SparkConnectClient.getSetOperation( self.plan.root, @@ -1182,8 +1206,11 @@ public actor DataFrame: Sendable { return DataFrame(spark: self.spark, plan: plan) } - private func buildRepartitionByExpression(numPartitions: Int32?, partitionExprs: [String]) -> DataFrame { - let plan = SparkConnectClient.getRepartitionByExpression(self.plan.root, partitionExprs, numPartitions) + private func buildRepartitionByExpression(numPartitions: Int32?, partitionExprs: [String]) + -> DataFrame + { + let plan = SparkConnectClient.getRepartitionByExpression( + self.plan.root, partitionExprs, numPartitions) return DataFrame(spark: self.spark, plan: plan) } @@ -1211,7 +1238,8 @@ public actor DataFrame: Sendable { /// - partitionExprs: The partition expression strings. /// - Returns: A `DataFrame`. public func repartition(_ numPartitions: Int32, _ partitionExprs: String...) -> DataFrame { - return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs) + return buildRepartitionByExpression( + numPartitions: numPartitions, partitionExprs: partitionExprs) } /// Returns a new ``DataFrame`` partitioned by the given partitioning expressions, using @@ -1219,8 +1247,11 @@ public actor DataFrame: Sendable { /// partitioned. /// - Parameter partitionExprs: The partition expression strings. /// - Returns: A `DataFrame`. - public func repartitionByExpression(_ numPartitions: Int32?, _ partitionExprs: String...) -> DataFrame { - return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs) + public func repartitionByExpression(_ numPartitions: Int32?, _ partitionExprs: String...) 
+ -> DataFrame + { + return buildRepartitionByExpression( + numPartitions: numPartitions, partitionExprs: partitionExprs) } /// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions, when the fewer partitions @@ -1322,7 +1353,8 @@ public actor DataFrame: Sendable { _ variableColumnName: String, _ valueColumnName: String ) -> DataFrame { - let plan = SparkConnectClient.getUnpivot(self.plan.root, ids, values, variableColumnName, valueColumnName) + let plan = SparkConnectClient.getUnpivot( + self.plan.root, ids, values, variableColumnName, valueColumnName) return DataFrame(spark: self.spark, plan: plan) } @@ -1421,7 +1453,8 @@ public actor DataFrame: Sendable { } func createTempView(_ viewName: String, replace: Bool, global: Bool) async throws { - try await spark.client.createTempView(self.plan.root, viewName, replace: replace, isGlobal: global) + try await spark.client.createTempView( + self.plan.root, viewName, replace: replace, isGlobal: global) } /// Eagerly checkpoint a ``DataFrame`` and return the new ``DataFrame``. @@ -1439,7 +1472,8 @@ public actor DataFrame: Sendable { _ reliableCheckpoint: Bool = true, _ storageLevel: StorageLevel? = nil ) async throws -> DataFrame { - let plan = try await spark.client.getCheckpoint(self.plan.root, eager, reliableCheckpoint, storageLevel) + let plan = try await spark.client.getCheckpoint( + self.plan.root, eager, reliableCheckpoint, storageLevel) return DataFrame(spark: self.spark, plan: plan) } @@ -1474,9 +1508,7 @@ public actor DataFrame: Sendable { /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data. public var write: DataFrameWriter { - get { - DataFrameWriter(df: self) - } + DataFrameWriter(df: self) } /// Create a write configuration builder for v2 sources. @@ -1485,7 +1517,7 @@ public actor DataFrame: Sendable { public func writeTo(_ table: String) -> DataFrameWriterV2 { return DataFrameWriterV2(table, self) } - + /// Merges a set of updates, insertions, and deletions based on a source table into a target table. /// - Parameters: /// - table: A target table name. @@ -1497,8 +1529,6 @@ public actor DataFrame: Sendable { /// Returns a ``DataStreamWriter`` that can be used to write streaming data. public var writeStream: DataStreamWriter { - get { - DataStreamWriter(df: self) - } + DataStreamWriter(df: self) } } diff --git a/Sources/SparkConnect/DataFrameReader.swift b/Sources/SparkConnect/DataFrameReader.swift index 274efdf..9c2076e 100644 --- a/Sources/SparkConnect/DataFrameReader.swift +++ b/Sources/SparkConnect/DataFrameReader.swift @@ -261,7 +261,9 @@ public actor DataFrameReader: Sendable { /// - table: The JDBC table that should be read from or written into. /// - properties: A string-string dictionary for connection properties. /// - Returns: A `DataFrame`. - public func jdbc(_ url: String, _ table: String, _ properties: [String: String] = [:]) -> DataFrame { + public func jdbc(_ url: String, _ table: String, _ properties: [String: String] = [:]) + -> DataFrame + { for (key, value) in properties { self.extraOptions[key] = value } diff --git a/Sources/SparkConnect/DataFrameWriter.swift b/Sources/SparkConnect/DataFrameWriter.swift index 11a5fa8..38492ac 100644 --- a/Sources/SparkConnect/DataFrameWriter.swift +++ b/Sources/SparkConnect/DataFrameWriter.swift @@ -236,7 +236,9 @@ public actor DataFrameWriter: Sendable { /// - url: The JDBC URL of the form `jdbc:subprotocol:subname` to connect to. /// - table: Name of the table in the external database. 
/// - properties:JDBC database connection arguments, a list of arbitrary string tag/value. - public func jdbc(_ url: String, _ table: String, _ properties: [String: String] = [:]) async throws { + public func jdbc(_ url: String, _ table: String, _ properties: [String: String] = [:]) + async throws + { for (key, value) in properties { self.extraOptions[key] = value } diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index f7d869e..4307a94 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -135,14 +135,15 @@ extension String { } var toExplainMode: ExplainMode { - let mode = switch self { - case "codegen": ExplainMode.codegen - case "cost": ExplainMode.cost - case "extended": ExplainMode.extended - case "formatted": ExplainMode.formatted - case "simple": ExplainMode.simple - default: ExplainMode.simple - } + let mode = + switch self { + case "codegen": ExplainMode.codegen + case "cost": ExplainMode.cost + case "extended": ExplainMode.extended + case "formatted": ExplainMode.formatted + case "simple": ExplainMode.simple + default: ExplainMode.simple + } return mode } @@ -220,13 +221,14 @@ extension YearMonthInterval { func toString() throws -> String { let startFieldName = try fieldToString(self.startField) let endFieldName = try fieldToString(self.endField) - let interval = if startFieldName == endFieldName { - "interval \(startFieldName)" - } else if startFieldName < endFieldName { - "interval \(startFieldName) to \(endFieldName)" - } else { - throw SparkConnectError.InvalidType - } + let interval = + if startFieldName == endFieldName { + "interval \(startFieldName)" + } else if startFieldName < endFieldName { + "interval \(startFieldName) to \(endFieldName)" + } else { + throw SparkConnectError.InvalidType + } return interval } } @@ -246,13 +248,14 @@ extension DayTimeInterval { func toString() throws -> String { let startFieldName = try fieldToString(self.startField) let endFieldName = try fieldToString(self.endField) - let interval = if startFieldName == endFieldName { - "interval \(startFieldName)" - } else if startFieldName < endFieldName { - "interval \(startFieldName) to \(endFieldName)" - } else { - throw SparkConnectError.InvalidType - } + let interval = + if startFieldName == endFieldName { + "interval \(startFieldName)" + } else if startFieldName < endFieldName { + "interval \(startFieldName) to \(endFieldName)" + } else { + throw SparkConnectError.InvalidType + } return interval } } diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index c1c9bd1..208601f 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -44,11 +44,11 @@ public actor SparkConnectClient { self.port = self.url.port ?? 15002 var token: String? = nil let processInfo = ProcessInfo.processInfo -#if os(macOS) || os(Linux) - var userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName -#else - var userName = processInfo.environment["SPARK_USER"] ?? "" -#endif + #if os(macOS) || os(Linux) + var userName = processInfo.environment["SPARK_USER"] ?? processInfo.userName + #else + var userName = processInfo.environment["SPARK_USER"] ?? 
"" + #endif for param in self.url.path.split(separator: ";").dropFirst().filter({ !$0.isEmpty }) { let kv = param.split(separator: "=") switch String(kv[0]).lowercased() { @@ -109,9 +109,11 @@ public actor SparkConnectClient { self.sessionID = sessionID let service = SparkConnectService.Client(wrapping: client) - let request = analyze(self.sessionID!, { - return OneOf_Analyze.sparkVersion(AnalyzePlanRequest.SparkVersion()) - }) + let request = analyze( + self.sessionID!, + { + return OneOf_Analyze.sparkVersion(AnalyzePlanRequest.SparkVersion()) + }) let response = try await service.analyzePlan(request) return response } @@ -193,7 +195,7 @@ public actor SparkConnectClient { request.operation.opType = .unset(unset) return request } - + /// Request the server to unset keys /// - Parameter keys: An array of keys /// - Returns: Always return true @@ -263,11 +265,12 @@ public actor SparkConnectClient { request.userContext = userContext request.sessionID = self.sessionID! let response = try await service.config(request) - let result = if response.pairs[0].hasValue { - response.pairs[0].value - } else { - value - } + let result = + if response.pairs[0].hasValue { + response.pairs[0].value + } else { + value + } return result } } @@ -295,11 +298,12 @@ public actor SparkConnectClient { request.userContext = userContext request.sessionID = self.sessionID! let response = try await service.config(request) - let result: String? = if response.pairs[0].hasValue { - response.pairs[0].value - } else { - nil - } + let result: String? = + if response.pairs[0].hasValue { + response.pairs[0].value + } else { + nil + } return result } } @@ -414,11 +418,13 @@ public actor SparkConnectClient { func getAnalyzePlanRequest(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest { - return analyze(sessionID, { - var schema = AnalyzePlanRequest.Schema() - schema.plan = plan - return OneOf_Analyze.schema(schema) - }) + return analyze( + sessionID, + { + var schema = AnalyzePlanRequest.Schema() + schema.plan = plan + return OneOf_Analyze.schema(schema) + }) } private func analyze(_ sessionID: String, _ f: () -> OneOf_Analyze) -> AnalyzePlanRequest { @@ -456,8 +462,7 @@ public actor SparkConnectClient { }) } - func getStorageLevel(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest - { + func getStorageLevel(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest { return analyze( sessionID, { @@ -467,8 +472,7 @@ public actor SparkConnectClient { }) } - func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async -> AnalyzePlanRequest - { + func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async -> AnalyzePlanRequest { return analyze( sessionID, { @@ -479,8 +483,7 @@ public actor SparkConnectClient { }) } - func getInputFiles(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest - { + func getInputFiles(_ sessionID: String, _ plan: Plan) async -> AnalyzePlanRequest { return analyze( sessionID, { @@ -670,7 +673,9 @@ public actor SparkConnectClient { return plan } - static func getSample(_ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64) -> Plan { + static func getSample( + _ child: Relation, _ withReplacement: Bool, _ fraction: Double, _ seed: Int64 + ) -> Plan { var sample = Sample() sample.input = child sample.withReplacement = withReplacement @@ -762,9 +767,10 @@ public actor SparkConnectClient { addArtifactsRequest.clientType = self.clientType addArtifactsRequest.batch = batch let request = addArtifactsRequest - _ = try await 
service.addArtifacts(request: StreamingClientRequest { x in - try await x.write(contentsOf: [request]) - }) + _ = try await service.addArtifacts( + request: StreamingClientRequest { x in + try await x.write(contentsOf: [request]) + }) } } @@ -846,11 +852,13 @@ public actor SparkConnectClient { func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType { try await withGPRC { client in let service = SparkConnectService.Client(wrapping: client) - let request = analyze(self.sessionID!, { - var ddlParse = AnalyzePlanRequest.DDLParse() - ddlParse.ddlString = ddlString - return OneOf_Analyze.ddlParse(ddlParse) - }) + let request = analyze( + self.sessionID!, + { + var ddlParse = AnalyzePlanRequest.DDLParse() + ddlParse.ddlString = ddlString + return OneOf_Analyze.ddlParse(ddlParse) + }) do { let response = try await service.analyzePlan(request) return response.ddlParse.parsed @@ -871,11 +879,13 @@ public actor SparkConnectClient { func jsonToDdl(_ jsonString: String) async throws -> String { try await withGPRC { client in let service = SparkConnectService.Client(wrapping: client) - let request = analyze(self.sessionID!, { - var jsonToDDL = AnalyzePlanRequest.JsonToDDL() - jsonToDDL.jsonString = jsonString - return OneOf_Analyze.jsonToDdl(jsonToDDL) - }) + let request = analyze( + self.sessionID!, + { + var jsonToDDL = AnalyzePlanRequest.JsonToDDL() + jsonToDDL.jsonString = jsonString + return OneOf_Analyze.jsonToDdl(jsonToDDL) + }) let response = try await service.analyzePlan(request) return response.jsonToDdl.ddlString } @@ -884,12 +894,14 @@ public actor SparkConnectClient { func sameSemantics(_ plan: Plan, _ otherPlan: Plan) async throws -> Bool { try await withGPRC { client in let service = SparkConnectService.Client(wrapping: client) - let request = analyze(self.sessionID!, { - var sameSemantics = AnalyzePlanRequest.SameSemantics() - sameSemantics.targetPlan = plan - sameSemantics.otherPlan = otherPlan - return OneOf_Analyze.sameSemantics(sameSemantics) - }) + let request = analyze( + self.sessionID!, + { + var sameSemantics = AnalyzePlanRequest.SameSemantics() + sameSemantics.targetPlan = plan + sameSemantics.otherPlan = otherPlan + return OneOf_Analyze.sameSemantics(sameSemantics) + }) let response = try await service.analyzePlan(request) return response.sameSemantics.result } @@ -898,11 +910,13 @@ public actor SparkConnectClient { func semanticHash(_ plan: Plan) async throws -> Int32 { try await withGPRC { client in let service = SparkConnectService.Client(wrapping: client) - let request = analyze(self.sessionID!, { - var semanticHash = AnalyzePlanRequest.SemanticHash() - semanticHash.plan = plan - return OneOf_Analyze.semanticHash(semanticHash) - }) + let request = analyze( + self.sessionID!, + { + var semanticHash = AnalyzePlanRequest.SemanticHash() + semanticHash.plan = plan + return OneOf_Analyze.semanticHash(semanticHash) + }) let response = try await service.analyzePlan(request) return response.semanticHash.result } @@ -986,7 +1000,9 @@ public actor SparkConnectClient { }) } - static func getRepartition(_ child: Relation, _ numPartitions: Int32, _ shuffle: Bool = false) -> Plan { + static func getRepartition(_ child: Relation, _ numPartitions: Int32, _ shuffle: Bool = false) + -> Plan + { var repartition = Repartition() repartition.input = child repartition.numPartitions = numPartitions @@ -1064,7 +1080,7 @@ public actor SparkConnectClient { literal.short = Int32(value) case let value as Int32: literal.integer = value - case let value as Int64: // Hint parameter 
raises exceptions for Int64 + case let value as Int64: // Hint parameter raises exceptions for Int64 literal.integer = Int32(value) case let value as Int: literal.integer = Int32(value) diff --git a/Sources/SparkConnect/SparkSession.swift b/Sources/SparkConnect/SparkSession.swift index 7e7326c..5203cba 100644 --- a/Sources/SparkConnect/SparkSession.swift +++ b/Sources/SparkConnect/SparkSession.swift @@ -89,16 +89,14 @@ public actor SparkSession { /// Interface through which the user may create, drop, alter or query underlying databases, tables, functions etc. public var catalog: Catalog { - get { - return Catalog(spark: self) - } + return Catalog(spark: self) } /// Stop the current client. public func stop() async { await client.stop() } - + /// Returns a ``DataFrame`` with no rows or columns. public var emptyDataFrame: DataFrame { get async { @@ -222,9 +220,7 @@ public actor SparkSession { /// /// - Returns: A ``DataFrameReader`` instance configured for this session public var read: DataFrameReader { - get { - DataFrameReader(sparkSession: self) - } + DataFrameReader(sparkSession: self) } /// Returns a ``DataStreamReader`` that can be used to read streaming data in as a ``DataFrame``. @@ -239,9 +235,7 @@ public actor SparkSession { /// /// - Returns: A ``DataFrameReader`` instance configured for this session public var readStream: DataStreamReader { - get { - DataStreamReader(sparkSession: self) - } + DataStreamReader(sparkSession: self) } /// Returns a ``DataFrame`` representing the specified table or view. @@ -337,11 +331,11 @@ public actor SparkSession { /// ```swift /// // Add a tag for a specific operation /// try await spark.addTag("etl_job_2024") - /// + /// /// // Perform operations that will be tagged /// let df = try await spark.sql("SELECT * FROM source_table") /// try await df.write.saveAsTable("processed_table") - /// + /// /// // Remove the tag when done /// try await spark.removeTag("etl_job_2024") /// ``` @@ -422,9 +416,7 @@ public actor SparkSession { /// Returns a `StreamingQueryManager` that allows managing all the `StreamingQuery`s active on /// `this`. public var streams: StreamingQueryManager { - get { - StreamingQueryManager(self) - } + StreamingQueryManager(self) } /// This is defined as the return type of `SparkSession.sparkContext` method. 
diff --git a/Tests/SparkConnectTests/CatalogTests.swift b/Tests/SparkConnectTests/CatalogTests.swift index 0888fdd..24ae1f6 100644 --- a/Tests/SparkConnectTests/CatalogTests.swift +++ b/Tests/SparkConnectTests/CatalogTests.swift @@ -25,288 +25,297 @@ import Testing /// A test suite for `Catalog` @Suite(.serialized) struct CatalogTests { -#if !os(Linux) - @Test - func currentCatalog() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.currentCatalog() == "spark_catalog") - await spark.stop() - } + #if !os(Linux) + @Test + func currentCatalog() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.currentCatalog() == "spark_catalog") + await spark.stop() + } - @Test - func setCurrentCatalog() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.catalog.setCurrentCatalog("spark_catalog") - if await spark.version >= "4.0.0" { - try await #require(throws: SparkConnectError.CatalogNotFound) { - try await spark.catalog.setCurrentCatalog("not_exist_catalog") - } - } else { - try await #require(throws: Error.self) { - try await spark.catalog.setCurrentCatalog("not_exist_catalog") + @Test + func setCurrentCatalog() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.catalog.setCurrentCatalog("spark_catalog") + if await spark.version >= "4.0.0" { + try await #require(throws: SparkConnectError.CatalogNotFound) { + try await spark.catalog.setCurrentCatalog("not_exist_catalog") + } + } else { + try await #require(throws: Error.self) { + try await spark.catalog.setCurrentCatalog("not_exist_catalog") + } } + await spark.stop() } - await spark.stop() - } - - @Test - func listCatalogs() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")]) - #expect(try await spark.catalog.listCatalogs(pattern: "*") == [CatalogMetadata(name: "spark_catalog")]) - #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0) - await spark.stop() - } - @Test - func currentDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.currentDatabase() == "default") - await spark.stop() - } + @Test + func listCatalogs() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.listCatalogs() == [CatalogMetadata(name: "spark_catalog")]) + #expect( + try await spark.catalog.listCatalogs(pattern: "*") == [ + CatalogMetadata(name: "spark_catalog") + ]) + #expect(try await spark.catalog.listCatalogs(pattern: "non_exist").count == 0) + await spark.stop() + } - @Test - func setCurrentDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.catalog.setCurrentDatabase("default") - try await #require(throws: SparkConnectError.SchemaNotFound) { - try await spark.catalog.setCurrentDatabase("not_exist_database") + @Test + func currentDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.currentDatabase() == "default") + await spark.stop() } - await spark.stop() - } - @Test - func listDatabases() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let dbs = try await spark.catalog.listDatabases() - #expect(dbs.count == 1) - #expect(dbs[0].name == "default") - #expect(dbs[0].catalog == 
"spark_catalog") - #expect(dbs[0].description == "default database") - #expect(dbs[0].locationUri.hasSuffix("spark-warehouse")) - #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs) - #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0) - await spark.stop() - } + @Test + func setCurrentDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.catalog.setCurrentDatabase("default") + try await #require(throws: SparkConnectError.SchemaNotFound) { + try await spark.catalog.setCurrentDatabase("not_exist_database") + } + await spark.stop() + } - @Test - func getDatabase() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let db = try await spark.catalog.getDatabase("default") - #expect(db.name == "default") - #expect(db.catalog == "spark_catalog") - #expect(db.description == "default database") - #expect(db.locationUri.hasSuffix("spark-warehouse")) - try await #require(throws: SparkConnectError.SchemaNotFound) { - try await spark.catalog.getDatabase("not_exist_database") + @Test + func listDatabases() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let dbs = try await spark.catalog.listDatabases() + #expect(dbs.count == 1) + #expect(dbs[0].name == "default") + #expect(dbs[0].catalog == "spark_catalog") + #expect(dbs[0].description == "default database") + #expect(dbs[0].locationUri.hasSuffix("spark-warehouse")) + #expect(try await spark.catalog.listDatabases(pattern: "*") == dbs) + #expect(try await spark.catalog.listDatabases(pattern: "non_exist").count == 0) + await spark.stop() } - await spark.stop() - } - @Test - func databaseExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.databaseExists("default")) + @Test + func getDatabase() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let db = try await spark.catalog.getDatabase("default") + #expect(db.name == "default") + #expect(db.catalog == "spark_catalog") + #expect(db.description == "default database") + #expect(db.locationUri.hasSuffix("spark-warehouse")) + try await #require(throws: SparkConnectError.SchemaNotFound) { + try await spark.catalog.getDatabase("not_exist_database") + } + await spark.stop() + } - let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - #expect(try await spark.catalog.databaseExists(dbName) == false) - try await SQLHelper.withDatabase(spark, dbName) ({ - try await spark.sql("CREATE DATABASE \(dbName)").count() - #expect(try await spark.catalog.databaseExists(dbName)) - }) - #expect(try await spark.catalog.databaseExists(dbName) == false) - await spark.stop() - } + @Test + func databaseExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.databaseExists("default")) + + let dbName = "DB_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + #expect(try await spark.catalog.databaseExists(dbName) == false) + try await SQLHelper.withDatabase(spark, dbName)({ + try await spark.sql("CREATE DATABASE \(dbName)").count() + #expect(try await spark.catalog.databaseExists(dbName)) + }) + #expect(try await spark.catalog.databaseExists(dbName) == false) + await spark.stop() + } - @Test - func createTable() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTable(spark, 
tableName)({ - try await spark.range(1).write.orc("/tmp/\(tableName)") - #expect(try await spark.catalog.createTable(tableName, "/tmp/\(tableName)", source: "orc").count() == 1) - #expect(try await spark.catalog.tableExists(tableName)) - }) - await spark.stop() - } + @Test + func createTable() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(1).write.orc("/tmp/\(tableName)") + #expect( + try await spark.catalog.createTable(tableName, "/tmp/\(tableName)", source: "orc").count() + == 1) + #expect(try await spark.catalog.tableExists(tableName)) + }) + await spark.stop() + } - @Test - func tableExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTable(spark, tableName)({ - try await spark.range(1).write.parquet("/tmp/\(tableName)") + @Test + func tableExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(1).write.parquet("/tmp/\(tableName)") + #expect(try await spark.catalog.tableExists(tableName) == false) + #expect(try await spark.catalog.createTable(tableName, "/tmp/\(tableName)").count() == 1) + #expect(try await spark.catalog.tableExists(tableName)) + #expect(try await spark.catalog.tableExists("default", tableName)) + #expect(try await spark.catalog.tableExists("default2", tableName) == false) + }) #expect(try await spark.catalog.tableExists(tableName) == false) - #expect(try await spark.catalog.createTable(tableName, "/tmp/\(tableName)").count() == 1) - #expect(try await spark.catalog.tableExists(tableName)) - #expect(try await spark.catalog.tableExists("default", tableName)) - #expect(try await spark.catalog.tableExists("default2", tableName) == false) - }) - #expect(try await spark.catalog.tableExists(tableName) == false) - try await #require(throws: SparkConnectError.ParseSyntaxError) { - try await spark.catalog.tableExists("invalid table name") + try await #require(throws: SparkConnectError.ParseSyntaxError) { + try await spark.catalog.tableExists("invalid table name") + } + await spark.stop() } - await spark.stop() - } - @Test - func listColumns() async throws { - let spark = try await SparkSession.builder.getOrCreate() + @Test + func listColumns() async throws { + let spark = try await SparkSession.builder.getOrCreate() + + // Table + let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + let path = "/tmp/\(tableName)" + try await SQLHelper.withTable(spark, tableName)({ + try await spark.range(2).write.orc(path) + let expected = + if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", true, false, false, false)] + } else { + [Row("id", nil, "bigint", true, false, false)] + } + #expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2) + #expect(try await spark.catalog.listColumns(tableName).collect() == expected) + #expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected) + }) + + // View + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + try await 
spark.range(1).createTempView(viewName) + let expected = + if await spark.version.starts(with: "4.") { + [Row("id", nil, "bigint", false, false, false, false)] + } else { + [Row("id", nil, "bigint", false, false, false)] + } + #expect(try await spark.catalog.listColumns(viewName).collect() == expected) + }) + + await spark.stop() + } - // Table - let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - let path = "/tmp/\(tableName)" - try await SQLHelper.withTable(spark, tableName)({ - try await spark.range(2).write.orc(path) - let expected = if await spark.version.starts(with: "4.") { - [Row("id", nil, "bigint", true, false, false, false)] - } else { - [Row("id", nil, "bigint", true, false, false)] - } - #expect(try await spark.catalog.createTable(tableName, path, source: "orc").count() == 2) - #expect(try await spark.catalog.listColumns(tableName).collect() == expected) - #expect(try await spark.catalog.listColumns("default.\(tableName)").collect() == expected) - }) + @Test + func functionExists() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.catalog.functionExists("base64")) + #expect(try await spark.catalog.functionExists("non_exist_function") == false) - // View - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - try await spark.range(1).createTempView(viewName) - let expected = if await spark.version.starts(with: "4.") { - [Row("id", nil, "bigint", false, false, false, false)] - } else { - [Row("id", nil, "bigint", false, false, false)] + try await #require(throws: SparkConnectError.ParseSyntaxError) { + try await spark.catalog.functionExists("invalid function name") } - #expect(try await spark.catalog.listColumns(viewName).collect() == expected) - }) - - await spark.stop() - } - - @Test - func functionExists() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.catalog.functionExists("base64")) - #expect(try await spark.catalog.functionExists("non_exist_function") == false) - - try await #require(throws: SparkConnectError.ParseSyntaxError) { - try await spark.catalog.functionExists("invalid function name") + await spark.stop() } - await spark.stop() - } - @Test - func createTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName)) - - try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { + @Test + func createTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) try await spark.range(1).createTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName)) + + try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { + try await spark.range(1).createTempView(viewName) + } + }) + + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createTempView("invalid view name") } - }) - try await #require(throws: 
SparkConnectError.InvalidViewName) { - try await spark.range(1).createTempView("invalid view name") + await spark.stop() } - await spark.stop() - } - - @Test - func createOrReplaceTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ - #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createOrReplaceTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName)) - try await spark.range(1).createOrReplaceTempView(viewName) - }) + @Test + func createOrReplaceTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createOrReplaceTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName)) + try await spark.range(1).createOrReplaceTempView(viewName) + }) + + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createOrReplaceTempView("invalid view name") + } - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createOrReplaceTempView("invalid view name") + await spark.stop() } - await spark.stop() - } + @Test + func createGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withGlobalTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + try await spark.range(1).createGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - @Test - func createGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withGlobalTempView(spark, viewName)({ + try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { + try await spark.range(1).createGlobalTempView(viewName) + } + }) #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - try await spark.range(1).createGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await #require(throws: SparkConnectError.TableOrViewAlreadyExists) { - try await spark.range(1).createGlobalTempView(viewName) + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createGlobalTempView("invalid view name") } - }) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createGlobalTempView("invalid view name") + await spark.stop() } - await spark.stop() - } - - @Test - func createOrReplaceGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withGlobalTempView(spark, viewName)({ + @Test + func createOrReplaceGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + 
UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withGlobalTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + try await spark.range(1).createOrReplaceGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) + try await spark.range(1).createOrReplaceGlobalTempView(viewName) + }) #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - try await spark.range(1).createOrReplaceGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await spark.range(1).createOrReplaceGlobalTempView(viewName) - }) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - - try await #require(throws: SparkConnectError.InvalidViewName) { - try await spark.range(1).createOrReplaceGlobalTempView("invalid view name") - } - await spark.stop() - } + try await #require(throws: SparkConnectError.InvalidViewName) { + try await spark.range(1).createOrReplaceGlobalTempView("invalid view name") + } - @Test - func dropTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createTempView(viewName) - try await spark.catalog.dropTempView(viewName) - #expect(try await spark.catalog.tableExists(viewName) == false) - }) + await spark.stop() + } - #expect(try await spark.catalog.dropTempView("non_exist_view") == false) - #expect(try await spark.catalog.dropTempView("invalid view name") == false) - await spark.stop() - } + @Test + func dropTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == false) + try await spark.range(1).createTempView(viewName) + try await spark.catalog.dropTempView(viewName) + #expect(try await spark.catalog.tableExists(viewName) == false) + }) - @Test - func dropGlobalTempView() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") - try await SQLHelper.withTempView(spark, viewName)({ #expect(try await spark.catalog.tableExists(viewName) == false) - try await spark.range(1).createGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) - try await spark.catalog.dropGlobalTempView(viewName) - #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) - }) + #expect(try await spark.catalog.dropTempView("non_exist_view") == false) + #expect(try await spark.catalog.dropTempView("invalid view name") == false) + await spark.stop() + } - #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") == false) - #expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false) - await spark.stop() - } -#endif + @Test + func dropGlobalTempView() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let viewName = "VIEW_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") + try await SQLHelper.withTempView(spark, viewName)({ + #expect(try await spark.catalog.tableExists(viewName) == 
false) + try await spark.range(1).createGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)")) + try await spark.catalog.dropGlobalTempView(viewName) + #expect(try await spark.catalog.tableExists("global_temp.\(viewName)") == false) + }) + + #expect(try await spark.catalog.dropGlobalTempView("non_exist_view") == false) + #expect(try await spark.catalog.dropGlobalTempView("invalid view name") == false) + await spark.stop() + } + #endif @Test func cacheTable() async throws { diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift index 96e8fc2..6c843c3 100644 --- a/Tests/SparkConnectTests/DataFrameInternalTests.swift +++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift @@ -25,63 +25,63 @@ import Testing @Suite(.serialized) struct DataFrameInternalTests { -#if !os(Linux) - @Test - func showString() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(10).showString(2, 0, false).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - #expect( - try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ - +---+ - |id | - +---+ - |0 | - |1 | - +---+ - only showing top 2 rows - """) - await spark.stop() - } + #if !os(Linux) + @Test + func showString() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + #expect( + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ + +---+ + |id | + +---+ + |0 | + |1 | + +---+ + only showing top 2 rows + """) + await spark.stop() + } - @Test - func showStringTruncate() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')") - .showString(2, 2, false).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - print(try rows[0].get(0) as! String) - #expect( - try rows[0].get(0) as! String == """ - +----+----+ - |col1|col2| - +----+----+ - | ab| de| - | gh| jk| - +----+----+ + @Test + func showStringTruncate() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')") + .showString(2, 2, false).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! String) + #expect( + try rows[0].get(0) as! String == """ + +----+----+ + |col1|col2| + +----+----+ + | ab| de| + | gh| jk| + +----+----+ - """) - await spark.stop() - } + """) + await spark.stop() + } - @Test - func showStringVertical() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(10).showString(2, 0, true).collect() - #expect(rows.count == 1) - #expect(rows[0].length == 1) - print(try rows[0].get(0) as! String) - #expect( - try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ - -RECORD 0-- - id | 0 - -RECORD 1-- - id | 1 - only showing top 2 rows - """) - await spark.stop() - } -#endif + @Test + func showStringVertical() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(10).showString(2, 0, true).collect() + #expect(rows.count == 1) + #expect(rows[0].length == 1) + print(try rows[0].get(0) as! 
String) + #expect( + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ + -RECORD 0-- + id | 0 + -RECORD 1-- + id | 1 + only showing top 2 rows + """) + await spark.stop() + } + #endif } diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift index 0dfd04b..bcee038 100644 --- a/Tests/SparkConnectTests/DataFrameReaderTests.swift +++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift @@ -18,9 +18,8 @@ // import Foundation -import Testing - import SparkConnect +import Testing /// A test suite for `DataFrameReader` @Suite(.serialized) @@ -95,8 +94,14 @@ struct DataFrameReaderTests { let path = "../examples/src/main/resources/people.json" #expect(try await spark.read.schema("age SHORT").json(path).dtypes.count == 1) #expect(try await spark.read.schema("age SHORT").json(path).dtypes[0] == ("age", "smallint")) - #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[0] == ("age", "smallint")) - #expect(try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[1] == ("name", "string")) + #expect( + try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[0] == ( + "age", "smallint" + )) + #expect( + try await spark.read.schema("age SHORT, name STRING").json(path).dtypes[1] == ( + "name", "string" + )) await spark.stop() } diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index f5c6eeb..2edd5f8 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -18,9 +18,8 @@ // import Foundation -import Testing - import SparkConnect +import Testing /// A test suite for `DataFrame` @Suite(.serialized) @@ -70,19 +69,21 @@ struct DataFrameTests { let spark = try await SparkSession.builder.getOrCreate() let schema1 = try await spark.sql("SELECT 'a' as col1").schema - let answer1 = if await spark.version.starts(with: "4.") { - #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# - } else { - #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}}]}}"# - } + let answer1 = + if await spark.version.starts(with: "4.") { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# + } else { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}}]}}"# + } #expect(schema1 == answer1) let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema - let answer2 = if await spark.version.starts(with: "4.") { - #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# - } else { - #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}},{"name":"col2","dataType":{"string":{}}}]}}"# - } + let answer2 = + if await spark.version.starts(with: "4.") { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# + } else { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}},{"name":"col2","dataType":{"string":{}}}]}}"# + } #expect(schema2 == answer2) let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema @@ -208,14 +209,14 @@ struct DataFrameTests { let schema1 = try await spark.range(1).to("shortID SHORT").schema #expect( schema1 - == #"{"struct":{"fields":[{"name":"shortID","dataType":{"short":{}},"nullable":true}]}}"# 
+ == #"{"struct":{"fields":[{"name":"shortID","dataType":{"short":{}},"nullable":true}]}}"# ) let schema2 = try await spark.sql("SELECT '1'").to("id INT").schema print(schema2) #expect( schema2 - == #"{"struct":{"fields":[{"name":"id","dataType":{"integer":{}},"nullable":true}]}}"# + == #"{"struct":{"fields":[{"name":"id","dataType":{"integer":{}},"nullable":true}]}}"# ) await spark.stop() @@ -344,23 +345,23 @@ struct DataFrameTests { await spark.stop() } -#if !os(Linux) - @Test - func sort() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = Array((1...10).map{ Row($0) }) - #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected) - await spark.stop() - } + #if !os(Linux) + @Test + func sort() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = Array((1...10).map { Row($0) }) + #expect(try await spark.range(10, 0, -1).sort("id").collect() == expected) + await spark.stop() + } - @Test - func orderBy() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = Array((1...10).map{ Row($0) }) - #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected) - await spark.stop() - } -#endif + @Test + func orderBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = Array((1...10).map { Row($0) }) + #expect(try await spark.range(10, 0, -1).orderBy("id").collect() == expected) + await spark.stop() + } + #endif @Test func table() async throws { @@ -376,204 +377,167 @@ struct DataFrameTests { await spark.stop() } -#if !os(Linux) - @Test - func collect() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).collect().isEmpty) - #expect( - try await spark.sql( - "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" - ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false, "def")]) - await spark.stop() - } - - @Test - func collectMultiple() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1) - #expect(try await df.collect().count == 1) - #expect(try await df.collect().count == 1) - await spark.stop() - } + #if !os(Linux) + @Test + func collect() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).collect().isEmpty) + #expect( + try await spark.sql( + "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" + ).collect() == [Row(1, true, "abc"), Row(nil, nil, nil), Row(3, false, "def")]) + await spark.stop() + } - @Test - func first() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(2).sort("id").first() == Row(0)) - #expect(try await spark.range(2).sort("id").head() == Row(0)) - await spark.stop() - } + @Test + func collectMultiple() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1) + #expect(try await df.collect().count == 1) + #expect(try await df.collect().count == 1) + await spark.stop() + } - @Test - func head() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).head(1).isEmpty) - #expect(try await spark.range(2).sort("id").head() == Row(0)) - #expect(try await spark.range(2).sort("id").head(1) == [Row(0)]) - #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)]) - #expect(try await 
spark.range(2).sort("id").head(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func first() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(2).sort("id").first() == Row(0)) + #expect(try await spark.range(2).sort("id").head() == Row(0)) + await spark.stop() + } - @Test - func take() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).take(1).isEmpty) - #expect(try await spark.range(2).sort("id").take(1) == [Row(0)]) - #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)]) - #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func head() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).head(1).isEmpty) + #expect(try await spark.range(2).sort("id").head() == Row(0)) + #expect(try await spark.range(2).sort("id").head(1) == [Row(0)]) + #expect(try await spark.range(2).sort("id").head(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").head(3) == [Row(0), Row(1)]) + await spark.stop() + } - @Test - func tail() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(0).tail(1).isEmpty) - #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)]) - #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)]) - #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)]) - await spark.stop() - } + @Test + func take() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).take(1).isEmpty) + #expect(try await spark.range(2).sort("id").take(1) == [Row(0)]) + #expect(try await spark.range(2).sort("id").take(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").take(3) == [Row(0), Row(1)]) + await spark.stop() + } - @Test - func show() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql("SHOW TABLES").show() - try await spark.sql("SELECT * FROM VALUES (true, false)").show() - try await spark.sql("SELECT * FROM VALUES (1, 2)").show() - try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show() - - // Check all signatures - try await spark.range(1000).show() - try await spark.range(1000).show(1) - try await spark.range(1000).show(true) - try await spark.range(1000).show(false) - try await spark.range(1000).show(1, true) - try await spark.range(1000).show(1, false) - try await spark.range(1000).show(1, 20) - try await spark.range(1000).show(1, 20, true) - try await spark.range(1000).show(1, 20, false) + @Test + func tail() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(0).tail(1).isEmpty) + #expect(try await spark.range(2).sort("id").tail(1) == [Row(1)]) + #expect(try await spark.range(2).sort("id").tail(2) == [Row(0), Row(1)]) + #expect(try await spark.range(2).sort("id").tail(3) == [Row(0), Row(1)]) + await spark.stop() + } - await spark.stop() - } + @Test + func show() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql("SHOW TABLES").show() + try await spark.sql("SELECT * FROM VALUES (true, false)").show() + try await spark.sql("SELECT * FROM VALUES (1, 2)").show() + try await spark.sql("SELECT * FROM VALUES ('abc', 'def'), ('ghi', 'jkl')").show() + + // Check all signatures + try await spark.range(1000).show() + 
try await spark.range(1000).show(1) + try await spark.range(1000).show(true) + try await spark.range(1000).show(false) + try await spark.range(1000).show(1, true) + try await spark.range(1000).show(1, false) + try await spark.range(1000).show(1, 20) + try await spark.range(1000).show(1, 20, true) + try await spark.range(1000).show(1, 20, false) + + await spark.stop() + } - @Test - func showNull() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql( - "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" - ).show() - await spark.stop() - } + @Test + func showNull() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql( + "SELECT * FROM VALUES (1, true, 'abc'), (null, null, null), (3, false, 'def')" + ).show() + await spark.stop() + } - @Test - func showCommand() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await spark.sql("DROP TABLE IF EXISTS t").show() - await spark.stop() - } + @Test + func showCommand() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await spark.sql("DROP TABLE IF EXISTS t").show() + await spark.stop() + } - @Test - func cache() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(10).cache().count() == 10) - await spark.stop() - } + @Test + func cache() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(10).cache().count() == 10) + await spark.stop() + } - @Test - func checkpoint() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version >= "4.0.0" { - // By default, reliable checkpoint location is required. - try await #require(throws: Error.self) { - try await spark.range(10).checkpoint() + @Test + func checkpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + // By default, reliable checkpoint location is required. 
+ try await #require(throws: Error.self) { + try await spark.range(10).checkpoint() + } + // Checkpointing with unreliable checkpoint + let df = try await spark.range(10).checkpoint(true, false) + #expect(try await df.count() == 10) } - // Checkpointing with unreliable checkpoint - let df = try await spark.range(10).checkpoint(true, false) - #expect(try await df.count() == 10) + await spark.stop() } - await spark.stop() - } - @Test - func localCheckpoint() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version >= "4.0.0" { - #expect(try await spark.range(10).localCheckpoint().count() == 10) + @Test + func localCheckpoint() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version >= "4.0.0" { + #expect(try await spark.range(10).localCheckpoint().count() == 10) + } + await spark.stop() } - await spark.stop() - } - - @Test - func persist() async throws { - let spark = try await SparkSession.builder.getOrCreate() - #expect(try await spark.range(20).persist().count() == 20) - #expect(try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21) - await spark.stop() - } - @Test - func persistInvalidStorageLevel() async throws { - let spark = try await SparkSession.builder.getOrCreate() - try await #require(throws: Error.self) { - var invalidLevel = StorageLevel.DISK_ONLY - invalidLevel.replication = 0 - try await spark.range(9999).persist(storageLevel: invalidLevel).count() + @Test + func persist() async throws { + let spark = try await SparkSession.builder.getOrCreate() + #expect(try await spark.range(20).persist().count() == 20) + #expect( + try await spark.range(21).persist(storageLevel: StorageLevel.MEMORY_ONLY).count() == 21) + await spark.stop() } - await spark.stop() - } - @Test - func unpersist() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(30) - #expect(try await df.persist().count() == 30) - #expect(try await df.unpersist().count() == 30) - await spark.stop() - } - - @Test - func join() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") - let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") - let expectedCross = [ - Row("a", 1, "c", 2), - Row("a", 1, "d", 3), - Row("b", 2, "c", 2), - Row("b", 2, "d", 3), - ] - #expect(try await df1.join(df2).collect() == expectedCross) - #expect(try await df1.crossJoin(df2).collect() == expectedCross) - - #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")]) - #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")]) - - #expect(try await df1.join(df2, "b", "left").collect() == [Row(1, "a", nil), Row(2, "b", "c")]) - #expect(try await df1.join(df2, "b", "right").collect() == [Row(2, "b", "c"), Row(3, nil, "d")]) - #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")]) - #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")]) - - let expectedOuter = [ - Row(1, "a", nil), - Row(2, "b", "c"), - Row(3, nil, "d"), - ] - #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter) - #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter) - #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter) + @Test + func persistInvalidStorageLevel() async throws { + let spark = try await SparkSession.builder.getOrCreate() + try await 
#require(throws: Error.self) { + var invalidLevel = StorageLevel.DISK_ONLY + invalidLevel.replication = 0 + try await spark.range(9999).persist(storageLevel: invalidLevel).count() + } + await spark.stop() + } - let expected = [Row("b", 2, "c", 2)] - #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected) - #expect(try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) - await spark.stop() - } + @Test + func unpersist() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(30) + #expect(try await df.persist().count() == 30) + #expect(try await df.unpersist().count() == 30) + await spark.stop() + } - @Test - func lateralJoin() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version.starts(with: "4.") { + @Test + func join() async throws { + let spark = try await SparkSession.builder.getOrCreate() let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") let df2 = try await spark.sql("SELECT * FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") let expectedCross = [ @@ -582,337 +546,393 @@ struct DataFrameTests { Row("b", 2, "c", 2), Row("b", 2, "d", 3), ] - #expect(try await df1.lateralJoin(df2).collect() == expectedCross) - #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) + #expect(try await df1.join(df2).collect() == expectedCross) + #expect(try await df1.crossJoin(df2).collect() == expectedCross) + + #expect(try await df1.join(df2, "b").collect() == [Row(2, "b", "c")]) + #expect(try await df1.join(df2, ["b"]).collect() == [Row(2, "b", "c")]) + + #expect( + try await df1.join(df2, "b", "left").collect() == [Row(1, "a", nil), Row(2, "b", "c")]) + #expect( + try await df1.join(df2, "b", "right").collect() == [Row(2, "b", "c"), Row(3, nil, "d")]) + #expect(try await df1.join(df2, "b", "semi").collect() == [Row(2, "b")]) + #expect(try await df1.join(df2, "b", "anti").collect() == [Row(1, "a")]) + + let expectedOuter = [ + Row(1, "a", nil), + Row(2, "b", "c"), + Row(3, nil, "d"), + ] + #expect(try await df1.join(df2, "b", "outer").collect() == expectedOuter) + #expect(try await df1.join(df2, "b", "full").collect() == expectedOuter) + #expect(try await df1.join(df2, ["b"], "full").collect() == expectedOuter) let expected = [Row("b", 2, "c", 2)] - #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) - #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) + #expect(try await df1.join(df2, joinExprs: "T.b = S.b").collect() == expected) + #expect( + try await df1.join(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) + await spark.stop() } - await spark.stop() - } - @Test - func except() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.except(spark.range(1, 5)).collect() == []) - #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)]) - #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), Row(2)]) - #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0) - await spark.stop() - } + @Test + func lateralJoin() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version.starts(with: "4.") { + let df1 = try await spark.sql("SELECT * FROM VALUES ('a', 1), ('b', 2) AS T(a, b)") + let df2 = try await spark.sql("SELECT * 
FROM VALUES ('c', 2), ('d', 3) AS S(c, b)") + let expectedCross = [ + Row("a", 1, "c", 2), + Row("a", 1, "d", 3), + Row("b", 2, "c", 2), + Row("b", 2, "d", 3), + ] + #expect(try await df1.lateralJoin(df2).collect() == expectedCross) + #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) + + let expected = [Row("b", 2, "c", 2)] + #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) + #expect( + try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() + == expected) + } + await spark.stop() + } - @Test - func exceptAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.exceptAll(spark.range(1, 5)).collect() == []) - #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)]) - #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), Row(2)]) - #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1) - await spark.stop() - } + @Test + func except() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.except(spark.range(1, 5)).collect() == []) + #expect(try await df.except(spark.range(2, 5)).collect() == [Row(1)]) + #expect(try await df.except(spark.range(3, 5)).collect() == [Row(1), Row(2)]) + #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").except(df).count() == 0) + await spark.stop() + } - @Test - func intersect() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), Row(2)]) - #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)]) - #expect(try await df.intersect(spark.range(3, 5)).collect() == []) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.intersect(df2).count() == 1) - await spark.stop() - } + @Test + func exceptAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.exceptAll(spark.range(1, 5)).collect() == []) + #expect(try await df.exceptAll(spark.range(2, 5)).collect() == [Row(1)]) + #expect(try await df.exceptAll(spark.range(3, 5)).collect() == [Row(1), Row(2)]) + #expect(try await spark.sql("SELECT * FROM VALUES 1, 1").exceptAll(df).count() == 1) + await spark.stop() + } - @Test - func intersectAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 3) - #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), Row(2)]) - #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)]) - #expect(try await df.intersectAll(spark.range(3, 5)).collect() == []) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.intersectAll(df2).count() == 2) - await spark.stop() - } + @Test + func intersect() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.intersect(spark.range(1, 5)).collect() == [Row(1), Row(2)]) + #expect(try await df.intersect(spark.range(2, 5)).collect() == [Row(2)]) + #expect(try await df.intersect(spark.range(3, 5)).collect() == []) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.intersect(df2).count() == 1) + await spark.stop() + } - @Test 
- func union() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 2) - #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) - #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), Row(2)]) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.union(df2).count() == 4) - await spark.stop() - } + @Test + func intersectAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 3) + #expect(try await df.intersectAll(spark.range(1, 5)).collect() == [Row(1), Row(2)]) + #expect(try await df.intersectAll(spark.range(2, 5)).collect() == [Row(2)]) + #expect(try await df.intersectAll(spark.range(3, 5)).collect() == []) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.intersectAll(df2).count() == 2) + await spark.stop() + } - @Test - func unionAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(1, 2) - #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) - #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), Row(2)]) - let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df2.unionAll(df2).count() == 4) - await spark.stop() - } + @Test + func union() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 2) + #expect(try await df.union(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) + #expect(try await df.union(spark.range(2, 3)).collect() == [Row(1), Row(2)]) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.union(df2).count() == 4) + await spark.stop() + } - @Test - func unionByName() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df1 = try await spark.sql("SELECT 1 a, 2 b") - let df2 = try await spark.sql("SELECT 4 b, 3 a") - #expect(try await df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)]) - #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)]) - let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1") - #expect(try await df3.unionByName(df3).count() == 4) - await spark.stop() - } + @Test + func unionAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(1, 2) + #expect(try await df.unionAll(spark.range(1, 3)).collect() == [Row(1), Row(1), Row(2)]) + #expect(try await df.unionAll(spark.range(2, 3)).collect() == [Row(1), Row(2)]) + let df2 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df2.unionAll(df2).count() == 4) + await spark.stop() + } - @Test - func repartition() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 3, 5] as [Int32] { - try await df.repartition(n).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func unionByName() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df1 = try await spark.sql("SELECT 1 a, 2 b") + let df2 = try await spark.sql("SELECT 4 b, 3 a") + #expect(try await 
df1.unionByName(df2).collect() == [Row(1, 2), Row(3, 4)]) + #expect(try await df1.union(df2).collect() == [Row(1, 2), Row(4, 3)]) + let df3 = try await spark.sql("SELECT * FROM VALUES 1, 1") + #expect(try await df3.unionByName(df3).count() == 4) + await spark.stop() + } - @Test - func repartitionByExpression() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 3, 5] as [Int32] { - try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func repartition() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 3, 5] as [Int32] { + try await df.repartition(n).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func coalesce() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let tmpDir = "/tmp/" + UUID().uuidString - let df = try await spark.range(2025) - for n in [1, 2, 3] as [Int32] { - try await df.coalesce(n).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) - } - try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir) - #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) - await spark.stop() - } + @Test + func repartitionByExpression() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 3, 5] as [Int32] { + try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func distinct() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.distinct().count() == 3) - await spark.stop() - } + @Test + func coalesce() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let tmpDir = "/tmp/" + UUID().uuidString + let df = try await spark.range(2025) + for n in [1, 2, 3] as [Int32] { + try await 
df.coalesce(n).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count == n) + } + try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir) + #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10) + await spark.stop() + } - @Test - func dropDuplicates() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.dropDuplicates().count() == 3) - #expect(try await df.dropDuplicates("a").count() == 3) - await spark.stop() - } + @Test + func distinct() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.distinct().count() == 3) + await spark.stop() + } - @Test - func dropDuplicatesWithinWatermark() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") - #expect(try await df.dropDuplicatesWithinWatermark().count() == 3) - #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3) - await spark.stop() - } + @Test + func dropDuplicates() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.dropDuplicates().count() == 3) + #expect(try await df.dropDuplicates("a").count() == 3) + await spark.stop() + } - @Test - func withWatermark() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark - .sql(""" - SELECT * FROM VALUES - (1, now()), - (1, now() - INTERVAL 1 HOUR), - (1, now() - INTERVAL 2 HOUR) - T(data, eventTime) - """) - .withWatermark("eventTime", "1 minute") // This tests only API for now - #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1) - await spark.stop() - } + @Test + func dropDuplicatesWithinWatermark() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql("SELECT * FROM VALUES (1), (2), (3), (1), (3) T(a)") + #expect(try await df.dropDuplicatesWithinWatermark().count() == 3) + #expect(try await df.dropDuplicatesWithinWatermark("a").count() == 3) + await spark.stop() + } - @Test - func describe() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(10) - let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"), Row("0"), Row("9")] - #expect(try await df.describe().select("id").collect() == expected) - #expect(try await df.describe("id").select("id").collect() == expected) - await spark.stop() - } + @Test + func withWatermark() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = + try await spark + .sql( + """ + SELECT * FROM VALUES + (1, now()), + (1, now() - INTERVAL 1 HOUR), + (1, now() - INTERVAL 2 HOUR) + T(data, eventTime) + """ + ) + .withWatermark("eventTime", "1 minute") // This tests only API for now + #expect(try await df.dropDuplicatesWithinWatermark("data").count() == 1) + await spark.stop() + } - @Test - func summary() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let expected = [ - Row("10"), Row("4.5"), Row("3.0276503540974917"), - Row("0"), Row("2"), Row("4"), Row("7"), Row("9") - ] - #expect(try await spark.range(10).summary().select("id").collect() == expected) - #expect(try 
await spark.range(10).summary("min", "max").select("id").collect() == [Row("0"), Row("9")]) - await spark.stop() - } + @Test + func describe() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(10) + let expected = [Row("10"), Row("4.5"), Row("3.0276503540974917"), Row("0"), Row("9")] + #expect(try await df.describe().select("id").collect() == expected) + #expect(try await df.describe("id").select("id").collect() == expected) + await spark.stop() + } - @Test - func groupBy() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)").collect() - #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)]) - await spark.stop() - } + @Test + func summary() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let expected = [ + Row("10"), Row("4.5"), Row("3.0276503540974917"), + Row("0"), Row("2"), Row("4"), Row("7"), Row("9"), + ] + #expect(try await spark.range(10).summary().select("id").collect() == expected) + #expect( + try await spark.range(10).summary("min", "max").select("id").collect() == [ + Row("0"), Row("9"), + ]) + await spark.stop() + } - @Test - func rollup() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model") - .agg("sum(quantity) sum").orderBy("city", "car_model").collect() - #expect(rows == [ - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Dublin", nil, 33), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("Fremont", nil, 32), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - Row("San Jose", nil, 13), - Row(nil, nil, 78), - ]) - await spark.stop() - } + @Test + func groupBy() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.range(3).groupBy("id").agg("count(*)", "sum(*)", "avg(*)") + .collect() + #expect(rows == [Row(0, 1, 0, 0.0), Row(1, 1, 1, 1.0), Row(2, 1, 2, 2.0)]) + await spark.stop() + } - @Test - func cube() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model") - .agg("sum(quantity) sum").orderBy("city", "car_model").collect() - #expect(rows == [ - Row("Dublin", "Honda Accord", 10), - Row("Dublin", "Honda CRV", 3), - Row("Dublin", "Honda Civic", 20), - Row("Dublin", nil, 33), - Row("Fremont", "Honda Accord", 15), - Row("Fremont", "Honda CRV", 7), - Row("Fremont", "Honda Civic", 10), - Row("Fremont", nil, 32), - Row("San Jose", "Honda Accord", 8), - Row("San Jose", "Honda Civic", 5), - Row("San Jose", nil, 13), - Row(nil, "Honda Accord", 33), - Row(nil, "Honda CRV", 10), - Row(nil, "Honda Civic", 35), - Row(nil, nil, 78), - ]) - await spark.stop() - } + @Test + func rollup() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql(DEALER_TABLE).rollup("city", "car_model") + .agg("sum(quantity) sum").orderBy("city", "car_model").collect() + #expect( + rows == [ + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Dublin", nil, 33), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("Fremont", nil, 32), + Row("San 
Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + Row("San Jose", nil, 13), + Row(nil, nil, 78), + ]) + await spark.stop() + } - @Test - func toJSON() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.range(2).toJSON() - #expect(try await df.columns == ["to_json(struct(id))"]) - #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")]) + @Test + func cube() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let rows = try await spark.sql(DEALER_TABLE).cube("city", "car_model") + .agg("sum(quantity) sum").orderBy("city", "car_model").collect() + #expect( + rows == [ + Row("Dublin", "Honda Accord", 10), + Row("Dublin", "Honda CRV", 3), + Row("Dublin", "Honda Civic", 20), + Row("Dublin", nil, 33), + Row("Fremont", "Honda Accord", 15), + Row("Fremont", "Honda CRV", 7), + Row("Fremont", "Honda Civic", 10), + Row("Fremont", nil, 32), + Row("San Jose", "Honda Accord", 8), + Row("San Jose", "Honda Civic", 5), + Row("San Jose", nil, 13), + Row(nil, "Honda Accord", 33), + Row(nil, "Honda CRV", 10), + Row(nil, "Honda Civic", 35), + Row(nil, nil, 78), + ]) + await spark.stop() + } - let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")] - #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected) - await spark.stop() - } + @Test + func toJSON() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.range(2).toJSON() + #expect(try await df.columns == ["to_json(struct(id))"]) + #expect(try await df.collect() == [Row("{\"id\":0}"), Row("{\"id\":1}")]) - @Test - func unpivot() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df = try await spark.sql( - """ - SELECT * FROM - VALUES (1, 11, 12L), - (2, 21, 22L) - T(id, int, long) - """) - let expected = [ - Row(1, "int", 11), - Row(1, "long", 12), - Row(2, "int", 21), - Row(2, "long", 22), - ] - #expect(try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expected) - #expect(try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expected) - await spark.stop() - } + let expected = [Row("{\"a\":1,\"b\":2,\"c\":3}")] + #expect(try await spark.sql("SELECT 1 a, 2 b, 3 c").toJSON().collect() == expected) + await spark.stop() + } - @Test - func transpose() async throws { - let spark = try await SparkSession.builder.getOrCreate() - if await spark.version.starts(with: "4.") { - #expect(try await spark.range(1).transpose().columns == ["key", "0"]) - #expect(try await spark.range(1).transpose().count() == 0) - + @Test + func unpivot() async throws { + let spark = try await SparkSession.builder.getOrCreate() let df = try await spark.sql( - """ - SELECT * FROM - VALUES ('A', 1, 2), - ('B', 3, 4) - T(id, val1, val2) - """) + """ + SELECT * FROM + VALUES (1, 11, 12L), + (2, 21, 22L) + T(id, int, long) + """) let expected = [ - Row("val1", 1, 3), - Row("val2", 2, 4), + Row(1, "int", 11), + Row(1, "long", 12), + Row(2, "int", 21), + Row(2, "long", 22), ] - #expect(try await df.transpose().collect() == expected) - #expect(try await df.transpose("id").collect() == expected) + #expect( + try await df.unpivot(["id"], ["int", "long"], "variable", "value").collect() == expected) + #expect( + try await df.melt(["id"], ["int", "long"], "variable", "value").collect() == expected) + await spark.stop() } - await spark.stop() - } - @Test - func decimal() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let df 
= try await spark.sql( - """ - SELECT * FROM VALUES - (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)), - (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL)) - """) - #expect(try await df.dtypes.map { $0.1 } == - ["decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)"]) - let expected = [ - Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)), - Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil) - ] - #expect(try await df.collect() == expected) - await spark.stop() - } -#endif + @Test + func transpose() async throws { + let spark = try await SparkSession.builder.getOrCreate() + if await spark.version.starts(with: "4.") { + #expect(try await spark.range(1).transpose().columns == ["key", "0"]) + #expect(try await spark.range(1).transpose().count() == 0) + + let df = try await spark.sql( + """ + SELECT * FROM + VALUES ('A', 1, 2), + ('B', 3, 4) + T(id, val1, val2) + """) + let expected = [ + Row("val1", 1, 3), + Row("val2", 2, 4), + ] + #expect(try await df.transpose().collect() == expected) + #expect(try await df.transpose("id").collect() == expected) + } + await spark.stop() + } + + @Test + func decimal() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql( + """ + SELECT * FROM VALUES + (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)), + (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL)) + """) + #expect( + try await df.dtypes.map { $0.1 } == [ + "decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)", + ]) + let expected = [ + Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)), + Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil), + ] + #expect(try await df.collect() == expected) + await spark.stop() + } + #endif @Test func storageLevel() async throws { diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift index 5228667..7e91a30 100644 --- a/Tests/SparkConnectTests/DataFrameWriterTests.swift +++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift @@ -18,9 +18,8 @@ // import Foundation -import Testing - import SparkConnect +import Testing /// A test suite for `DataFrameWriter` @Suite(.serialized) diff --git a/Tests/SparkConnectTests/SQLTests.swift b/Tests/SparkConnectTests/SQLTests.swift index 5c5efb2..808c27b 100644 --- a/Tests/SparkConnectTests/SQLTests.swift +++ b/Tests/SparkConnectTests/SQLTests.swift @@ -27,7 +27,8 @@ import Testing struct SQLTests { let fm = FileManager.default let path = Bundle.module.path(forResource: "queries", ofType: "")! - let regenerateGoldenFiles = ProcessInfo.processInfo.environment["SPARK_GENERATE_GOLDEN_FILES"] == "1" + let regenerateGoldenFiles = + ProcessInfo.processInfo.environment["SPARK_GENERATE_GOLDEN_FILES"] == "1" let regexID = /#\d+L?/ let regexPlanId = /plan_id=\d+/ @@ -90,35 +91,39 @@ struct SQLTests { "variant.sql", ] -#if !os(Linux) - @Test - func runAll() async throws { - let spark = try await SparkSession.builder.getOrCreate() - let MAX = Int32.max - for name in try! fm.contentsOfDirectory(atPath: path).sorted() { - guard name.hasSuffix(".sql") else { continue } - print(name) - if await !spark.version.starts(with: "4.") && queriesForSpark4Only.contains(name) { - print("Skip query \(name) due to the difference between Spark 3 and 4.") - continue - } + #if !os(Linux) + @Test + func runAll() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let MAX = Int32.max + for name in try! 
fm.contentsOfDirectory(atPath: path).sorted() {
+        guard name.hasSuffix(".sql") else { continue }
+        print(name)
+        if await !spark.version.starts(with: "4.") && queriesForSpark4Only.contains(name) {
+          print("Skip query \(name) due to the difference between Spark 3 and 4.")
+          continue
+        }

-      let sql = try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name)"), encoding: .utf8)
-      let result = try await spark.sql(sql).showString(MAX, MAX, false).collect()[0].get(0) as! String
-      let answer = cleanUp(result.trimmingCharacters(in: .whitespacesAndNewlines))
-      if (regenerateGoldenFiles) {
-        let path = "\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer"
-        fm.createFile(atPath: path, contents: answer.data(using: .utf8)!, attributes: nil)
-      } else {
-        let expected = cleanUp(try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name).answer"), encoding: .utf8))
+        let sql = try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name)"), encoding: .utf8)
+        let result =
+          try await spark.sql(sql).showString(MAX, MAX, false).collect()[0].get(0) as! String
+        let answer = cleanUp(result.trimmingCharacters(in: .whitespacesAndNewlines))
+        if regenerateGoldenFiles {
+          let path =
+            "\(FileManager.default.currentDirectoryPath)/Tests/SparkConnectTests/Resources/queries/\(name).answer"
+          fm.createFile(atPath: path, contents: answer.data(using: .utf8)!, attributes: nil)
+        } else {
+          let expected = cleanUp(
+            try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name).answer"), encoding: .utf8)
+          )
           .trimmingCharacters(in: .whitespacesAndNewlines)
-        if (answer != expected) {
-          print("Try to compare normalized result.")
-          #expect(normalize(answer) == normalize(expected))
+          if answer != expected {
+            print("Try to compare normalized result.")
+            #expect(normalize(answer) == normalize(expected))
+          }
         }
       }
+      await spark.stop()
     }
-    await spark.stop()
-  }
-#endif
+  #endif
 }
diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift
index e47eab6..cd57905 100644
--- a/Tests/SparkConnectTests/SparkConnectClientTests.swift
+++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift
@@ -35,7 +35,8 @@ struct SparkConnectClientTests {

   @Test
   func parameters() async throws {
-    let client = SparkConnectClient(remote: "sc://host1:123/;tOkeN=abcd;user_ID=test;USER_agent=myagent")
+    let client = SparkConnectClient(
+      remote: "sc://host1:123/;tOkeN=abcd;user_ID=test;USER_agent=myagent")
     #expect(await client.token == "abcd")
     #expect(await client.userContext.userID == "test")
     #expect(await client.clientType == "myagent")
diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift
index 1b4a658..326f37d 100644
--- a/Tests/SparkConnectTests/SparkSessionTests.swift
+++ b/Tests/SparkConnectTests/SparkSessionTests.swift
@@ -58,7 +58,8 @@ struct SparkSessionTests {
     await SparkSession.builder.clear()
     let spark1 = try await SparkSession.builder.getOrCreate()
     let remote = ProcessInfo.processInfo.environment["SPARK_REMOTE"] ?? "sc://localhost"
-    let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)").getOrCreate()
+    let spark2 = try await SparkSession.builder.remote("\(remote)/;session_id=\(spark1.sessionID)")
+      .getOrCreate()
     await spark2.stop()
     #expect(spark1.sessionID == spark2.sessionID)
     #expect(spark1 == spark2)
@@ -81,11 +82,11 @@ struct SparkSessionTests {
   @Test func userContext() async throws {
     await SparkSession.builder.clear()
     let spark = try await SparkSession.builder.getOrCreate()
-#if os(macOS) || os(Linux)
-    let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
-#else
-    let defaultUserContext = "".toUserContext
-#endif
+    #if os(macOS) || os(Linux)
+      let defaultUserContext = ProcessInfo.processInfo.userName.toUserContext
+    #else
+      let defaultUserContext = "".toUserContext
+    #endif
     #expect(await spark.client.userContext == defaultUserContext)
     await spark.stop()
   }
@@ -129,74 +130,76 @@ struct SparkSessionTests {
     await spark.stop()
   }

-#if !os(Linux)
-  @Test
-  func sql() async throws {
-    await SparkSession.builder.clear()
-    let spark = try await SparkSession.builder.getOrCreate()
-    let expected = [Row(true, 1, "a")]
-    if await spark.version.starts(with: "4.") {
-      #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected)
-      #expect(try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect() == expected)
+  #if !os(Linux)
+    @Test
+    func sql() async throws {
+      await SparkSession.builder.clear()
+      let spark = try await SparkSession.builder.getOrCreate()
+      let expected = [Row(true, 1, "a")]
+      if await spark.version.starts(with: "4.") {
+        #expect(try await spark.sql("SELECT ?, ?, ?", true, 1, "a").collect() == expected)
+        #expect(
+          try await spark.sql("SELECT :x, :y, :z", args: ["x": true, "y": 1, "z": "a"]).collect()
+            == expected)
+      }
+      await spark.stop()
     }
-    await spark.stop()
-  }

-  @Test
-  func addInvalidArtifact() async throws {
-    await SparkSession.builder.clear()
-    let spark = try await SparkSession.builder.getOrCreate()
-    await #expect(throws: SparkConnectError.InvalidArgument) {
-      try await spark.addArtifact("x.txt")
+    @Test
+    func addInvalidArtifact() async throws {
+      await SparkSession.builder.clear()
+      let spark = try await SparkSession.builder.getOrCreate()
+      await #expect(throws: SparkConnectError.InvalidArgument) {
+        try await spark.addArtifact("x.txt")
+      }
+      await spark.stop()
     }
-    await spark.stop()
-  }

-  @Test
-  func addArtifact() async throws {
-    let fm = FileManager()
-    let path = "my.jar"
-    let url = URL(fileURLWithPath: path)
+    @Test
+    func addArtifact() async throws {
+      let fm = FileManager()
+      let path = "my.jar"
+      let url = URL(fileURLWithPath: path)

-    await SparkSession.builder.clear()
-    let spark = try await SparkSession.builder.getOrCreate()
-    #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
-    if await spark.version.starts(with: "4.") {
-      try await spark.addArtifact(path)
-      try await spark.addArtifact(url)
+      await SparkSession.builder.clear()
+      let spark = try await SparkSession.builder.getOrCreate()
+      #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
+      if await spark.version.starts(with: "4.") {
+        try await spark.addArtifact(path)
+        try await spark.addArtifact(url)
+      }
+      try fm.removeItem(atPath: path)
+      await spark.stop()
    }
-    try fm.removeItem(atPath: path)
-    await spark.stop()
-  }

-  @Test
-  func addArtifacts() async throws {
-    let fm = FileManager()
-    let path = "my.jar"
-    let url = URL(fileURLWithPath: path)
+    @Test
+    func addArtifacts() async throws {
+      let fm = FileManager()
+      let path = "my.jar"
+      let url = URL(fileURLWithPath: path)

-    await SparkSession.builder.clear()
-    let spark = try await SparkSession.builder.getOrCreate()
-    #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
-    if await spark.version.starts(with: "4.") {
-      try await spark.addArtifacts(url, url)
+      await SparkSession.builder.clear()
+      let spark = try await SparkSession.builder.getOrCreate()
+      #expect(fm.createFile(atPath: path, contents: "abc".data(using: .utf8)))
+      if await spark.version.starts(with: "4.") {
+        try await spark.addArtifacts(url, url)
+      }
+      try fm.removeItem(atPath: path)
+      await spark.stop()
    }
-    try fm.removeItem(atPath: path)
-    await spark.stop()
-  }

-  @Test
-  func executeCommand() async throws {
-    await SparkSession.builder.clear()
-    let spark = try await SparkSession.builder.getOrCreate()
-    if await spark.version.starts(with: "4.") {
-      await #expect(throws: SparkConnectError.DataSourceNotFound) {
-        try await spark.executeCommand("runner", "command", [:]).show()
+    @Test
+    func executeCommand() async throws {
+      await SparkSession.builder.clear()
+      let spark = try await SparkSession.builder.getOrCreate()
+      if await spark.version.starts(with: "4.") {
+        await #expect(throws: SparkConnectError.DataSourceNotFound) {
+          try await spark.executeCommand("runner", "command", [:]).show()
+        }
       }
+      await spark.stop()
     }
-    await spark.stop()
-  }
-#endif
+  #endif

   @Test
   func table() async throws {
@@ -215,10 +218,10 @@ struct SparkSessionTests {
     await SparkSession.builder.clear()
     let spark = try await SparkSession.builder.getOrCreate()
     #expect(try await spark.time(spark.range(1000).count) == 1000)
-#if !os(Linux)
-    #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
-    try await spark.time(spark.range(10).show)
-#endif
+    #if !os(Linux)
+      #expect(try await spark.time(spark.range(1).collect) == [Row(0)])
+      try await spark.time(spark.range(10).show)
+    #endif
     await spark.stop()
   }