diff --git a/Sources/SparkConnect/ArrowArray.swift b/Sources/SparkConnect/ArrowArray.swift index e3d61bb..a767b6e 100644 --- a/Sources/SparkConnect/ArrowArray.swift +++ b/Sources/SparkConnect/ArrowArray.swift @@ -101,6 +101,8 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder { return try ArrowArrayHolderImpl(FixedArray(with)) case .float: return try ArrowArrayHolderImpl(FixedArray(with)) + case .decimal128: + return try ArrowArrayHolderImpl(FixedArray(with)) case .date32: return try ArrowArrayHolderImpl(Date32Array(with)) case .date64: @@ -247,6 +249,25 @@ public class Time32Array: FixedArray {} /// @nodoc public class Time64Array: FixedArray {} +/// @nodoc +public class Decimal128Array: FixedArray { + public override subscript(_ index: UInt) -> Decimal? { + if self.arrowData.isNull(index) { + return nil + } + let scale: Int32 = switch self.arrowData.type.id { + case .decimal128(_, let scale): + scale + default: + 18 + } + let byteOffset = self.arrowData.stride * Int(index) + let value = self.arrowData.buffers[1].rawPointer.advanced(by: byteOffset).load( + as: UInt64.self) + return Decimal(value) / pow(10, Int(scale)) + } +} + /// @nodoc public class BinaryArray: ArrowArray { public struct Options { diff --git a/Sources/SparkConnect/ArrowArrayBuilder.swift b/Sources/SparkConnect/ArrowArrayBuilder.swift index 5062127..20b3f27 100644 --- a/Sources/SparkConnect/ArrowArrayBuilder.swift +++ b/Sources/SparkConnect/ArrowArrayBuilder.swift @@ -122,6 +122,12 @@ public class Time64ArrayBuilder: ArrowArrayBuilder, T } } +public class Decimal128ArrayBuilder: ArrowArrayBuilder, Decimal128Array> { + fileprivate convenience init(precision: Int32, scale: Int32) throws { + try self.init(ArrowTypeDecimal128(precision: precision, scale: scale)) + } +} + public class StructArrayBuilder: ArrowArrayBuilder { let builders: [any ArrowArrayHolderBuilder] let fields: [ArrowField] @@ -202,6 +208,8 @@ public class ArrowArrayBuilders { return try ArrowArrayBuilders.loadBoolArrayBuilder() } else if builderType == Date.self || builderType == Date?.self { return try ArrowArrayBuilders.loadDate64ArrayBuilder() + } else if builderType == Decimal.self || builderType == Decimal?.self { + return try ArrowArrayBuilders.loadDecimal128ArrayBuilder(38, 18) } else { throw ArrowError.invalid("Invalid type for builder: \(builderType)") } @@ -214,7 +222,7 @@ public class ArrowArrayBuilders { || type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self || type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self - || type == Float.self || type == Date.self + || type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self } public static func loadStructArrayBuilderForType(_ obj: T) throws -> StructArrayBuilder { @@ -279,6 +287,11 @@ public class ArrowArrayBuilders { throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not found") } return try Time64ArrayBuilder(timeType.unit) + case .decimal128: + guard let decimalType = arrowType as? ArrowTypeDecimal128 else { + throw ArrowError.invalid("Expected ArrowTypeDecimal128 for decimal128 type") + } + return try Decimal128ArrayBuilder(precision: decimalType.precision, scale: decimalType.scale) default: throw ArrowError.unknownType("Builder not found for arrow type: \(arrowType.id)") } @@ -306,6 +319,8 @@ public class ArrowArrayBuilders { return try NumberArrayBuilder() } else if type == Double.self { return try NumberArrayBuilder() + } else if type == Decimal.self { + return try NumberArrayBuilder() } else { throw ArrowError.unknownType("Type is invalid for NumberArrayBuilder") } @@ -338,4 +353,11 @@ public class ArrowArrayBuilders { public static func loadTime64ArrayBuilder(_ unit: ArrowTime64Unit) throws -> Time64ArrayBuilder { return try Time64ArrayBuilder(unit) } + + public static func loadDecimal128ArrayBuilder( + _ precision: Int32 = 38, + _ scale: Int32 = 18 + ) throws -> Decimal128ArrayBuilder { + return try Decimal128ArrayBuilder(precision: precision, scale: scale) + } } diff --git a/Sources/SparkConnect/ArrowBufferBuilder.swift b/Sources/SparkConnect/ArrowBufferBuilder.swift index 3c38c1d..b20e964 100644 --- a/Sources/SparkConnect/ArrowBufferBuilder.swift +++ b/Sources/SparkConnect/ArrowBufferBuilder.swift @@ -142,6 +142,8 @@ public class FixedBufferBuilder: ValuesBufferBuilder, ArrowBufferBuilder { return Float(0) as! T // swiftlint:disable:this force_cast } else if type == Double.self { return Double(0) as! T // swiftlint:disable:this force_cast + } else if type == Decimal.self { + return Decimal(0) as! T // swiftlint:disable:this force_cast } throw ArrowError.unknownType("Unable to determine default value") diff --git a/Sources/SparkConnect/ArrowDecoder.swift b/Sources/SparkConnect/ArrowDecoder.swift index 1f12b8b..e4875f6 100644 --- a/Sources/SparkConnect/ArrowDecoder.swift +++ b/Sources/SparkConnect/ArrowDecoder.swift @@ -160,7 +160,7 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { || type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self || type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self - || type == Float.self || type == Date.self + || type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self { defer { increment() } return try self.decoder.doDecode(self.currentIndex)! @@ -260,8 +260,12 @@ private struct ArrowKeyedDecoding: KeyedDecodingContainerProtoco return try self.decoder.doDecode(key)! } + func decode(_ type: Decimal.Type, forKey key: Key) throws -> Decimal { + return try self.decoder.doDecode(key)! + } + func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self { return try self.decoder.doDecode(key)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") @@ -363,8 +367,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { return try self.decoder.doDecode(self.decoder.singleRBCol)! } + func decode(_ type: Decimal.Type) throws -> Decimal { + return try self.decoder.doDecode(self.decoder.singleRBCol)! + } + func decode(_ type: T.Type) throws -> T where T: Decodable { - if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self { return try self.decoder.doDecode(self.decoder.singleRBCol)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") diff --git a/Sources/SparkConnect/ArrowReaderHelper.swift b/Sources/SparkConnect/ArrowReaderHelper.swift index baa4e93..c955d16 100644 --- a/Sources/SparkConnect/ArrowReaderHelper.swift +++ b/Sources/SparkConnect/ArrowReaderHelper.swift @@ -49,6 +49,21 @@ private func makeStringHolder( } } +private func makeDecimalHolder( + _ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt +) -> Result { + do { + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolderImpl(try Decimal128Array(arrowData))) + } catch let error as ArrowError { + return .failure(error) + } catch { + return .failure(.unknownError("\(error)")) + } +} + private func makeDateHolder( _ field: ArrowField, buffers: [ArrowBuffer], @@ -183,6 +198,8 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity return makeFixedHolder(Int64.self, field: field, buffers: buffers, nullCount: nullCount) case .uint64: return makeFixedHolder(UInt64.self, field: field, buffers: buffers, nullCount: nullCount) + case .decimal128: + return makeDecimalHolder(field, buffers: buffers, nullCount: nullCount) case .boolean: return makeBoolHolder(buffers, nullCount: nullCount) case .float: @@ -217,7 +234,7 @@ func makeBuffer( func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool { switch type { - case .int, .bool, .floatingpoint, .date, .time: + case .int, .bool, .floatingpoint, .date, .time, .decimal: return true default: return false @@ -266,6 +283,12 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_bo default: return ArrowType(ArrowType.ArrowUnknown) } + case .decimal: + let dataType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)! + if dataType.bitWidth == 128 { + return ArrowType(ArrowType.ArrowDecimal128) + } + return ArrowType(ArrowType.ArrowUnknown) case .utf8: return ArrowType(ArrowType.ArrowString) case .binary: diff --git a/Sources/SparkConnect/ArrowType.swift b/Sources/SparkConnect/ArrowType.swift index cdf9772..39555f3 100644 --- a/Sources/SparkConnect/ArrowType.swift +++ b/Sources/SparkConnect/ArrowType.swift @@ -42,13 +42,13 @@ public enum ArrowError: Error { case invalid(String) } -public enum ArrowTypeId: Sendable { +public enum ArrowTypeId: Sendable, Equatable { case binary case boolean case date32 case date64 case dateType - case decimal128 + case decimal128(_ precision: Int32, _ scale: Int32) case decimal256 case dictionary case double @@ -129,6 +129,23 @@ public class ArrowTypeTime64: ArrowType { } } +public class ArrowTypeDecimal128: ArrowType { + let precision: Int32 + let scale: Int32 + + public init(precision: Int32, scale: Int32) { + self.precision = precision + self.scale = scale + super.init(ArrowType.ArrowDecimal128) + } + + public override var cDataFormatId: String { + get throws { + return "d:\(precision),\(scale)" + } + } +} + /// @nodoc public class ArrowNestedType: ArrowType { let fields: [ArrowField] @@ -156,6 +173,7 @@ public class ArrowType { public static let ArrowBool = Info.primitiveInfo(ArrowTypeId.boolean) public static let ArrowDate32 = Info.primitiveInfo(ArrowTypeId.date32) public static let ArrowDate64 = Info.primitiveInfo(ArrowTypeId.date64) + public static let ArrowDecimal128 = Info.primitiveInfo(ArrowTypeId.decimal128(38, 18)) public static let ArrowBinary = Info.variableInfo(ArrowTypeId.binary) public static let ArrowTime32 = Info.timeInfo(ArrowTypeId.time32) public static let ArrowTime64 = Info.timeInfo(ArrowTypeId.time64) @@ -216,6 +234,8 @@ public class ArrowType { return ArrowType.ArrowFloat } else if type == Double.self { return ArrowType.ArrowDouble + } else if type == Decimal.self { + return ArrowType.ArrowDecimal128 } else { return ArrowType.ArrowUnknown } @@ -242,6 +262,8 @@ public class ArrowType { return ArrowType.ArrowFloat } else if type == Double.self { return ArrowType.ArrowDouble + } else if type == Decimal.self { + return ArrowType.ArrowDecimal128 } else { return ArrowType.ArrowUnknown } @@ -271,6 +293,8 @@ public class ArrowType { return MemoryLayout.stride case .double: return MemoryLayout.stride + case .decimal128: + return 16 // Decimal 128 (= 16 * 8) bits case .boolean: return MemoryLayout.stride case .date32: @@ -315,6 +339,8 @@ public class ArrowType { return "f" case ArrowTypeId.double: return "g" + case ArrowTypeId.decimal128(let precision, let scale): + return "d:\(precision),\(scale)" case ArrowTypeId.boolean: return "b" case ArrowTypeId.date32: @@ -344,6 +370,7 @@ public class ArrowType { public static func fromCDataFormatId( // swiftlint:disable:this cyclomatic_complexity _ from: String ) throws -> ArrowType { + let REGEX_DECIMAL_TYPE = /^d:(\d+),(\d+)$/ if from == "c" { return ArrowType(ArrowType.ArrowInt8) } else if from == "s" { @@ -364,6 +391,10 @@ public class ArrowType { return ArrowType(ArrowType.ArrowFloat) } else if from == "g" { return ArrowType(ArrowType.ArrowDouble) + } else if from.contains(REGEX_DECIMAL_TYPE) { + let match = from.firstMatch(of: REGEX_DECIMAL_TYPE)! + let decimalType = ArrowTypeId.decimal128(Int32(match.1)!, Int32(match.2)!) + return ArrowType(Info.primitiveInfo(decimalType)) } else if from == "b" { return ArrowType(ArrowType.ArrowBool) } else if from == "tdD" { diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift index 945651e..760ece3 100644 --- a/Sources/SparkConnect/DataFrame.swift +++ b/Sources/SparkConnect/DataFrame.swift @@ -408,6 +408,8 @@ public actor DataFrame: Sendable { values.append(array.asAny(i) as? Float) case .primitiveInfo(.double): values.append(array.asAny(i) as? Double) + case .primitiveInfo(.decimal128): + values.append(array.asAny(i) as? Decimal) case .primitiveInfo(.date32): values.append(array.asAny(i) as! Date) case ArrowType.ArrowBinary: diff --git a/Sources/SparkConnect/ProtoUtil.swift b/Sources/SparkConnect/ProtoUtil.swift index 5a16269..4f27b20 100644 --- a/Sources/SparkConnect/ProtoUtil.swift +++ b/Sources/SparkConnect/ProtoUtil.swift @@ -44,6 +44,15 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity } else if floatType.precision == .double { arrowType = ArrowType(ArrowType.ArrowDouble) } + case .decimal: + let decimalType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)! + if decimalType.bitWidth == 128 && decimalType.precision <= 38 { + let arrowDecimal128 = ArrowTypeId.decimal128(decimalType.precision, decimalType.scale) + arrowType = ArrowType(ArrowType.Info.primitiveInfo(arrowDecimal128)) + } else { + // Unsupport yet + arrowType = ArrowType(ArrowType.ArrowUnknown) + } case .utf8: arrowType = ArrowType(ArrowType.ArrowString) case .binary: diff --git a/Sources/SparkConnect/Row.swift b/Sources/SparkConnect/Row.swift index 67cfcfd..79ade14 100644 --- a/Sources/SparkConnect/Row.swift +++ b/Sources/SparkConnect/Row.swift @@ -69,6 +69,8 @@ public struct Row: Sendable, Equatable { return a == b } else if let a = x as? Double, let b = y as? Double { return a == b + } else if let a = x as? Decimal, let b = y as? Decimal { + return a == b } else if let a = x as? String, let b = y as? String { return a == b } else { diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index 09aab96..d85c607 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -893,6 +893,25 @@ struct DataFrameTests { } await spark.stop() } + + @Test + func decimal() async throws { + let spark = try await SparkSession.builder.getOrCreate() + let df = try await spark.sql( + """ + SELECT * FROM VALUES + (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)), + (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL)) + """) + #expect(try await df.dtypes.map { $0.1 } == + ["decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)"]) + let expected = [ + Row(Decimal(1.0), Decimal(3.40), nil, Decimal(0)), + Row(Decimal(2.0), Decimal(34.56), Decimal(0), nil) + ] + #expect(try await df.collect() == expected) + await spark.stop() + } #endif @Test diff --git a/Tests/SparkConnectTests/Resources/queries/decimal.sql b/Tests/SparkConnectTests/Resources/queries/decimal.sql new file mode 100644 index 0000000..19cf2a5 --- /dev/null +++ b/Tests/SparkConnectTests/Resources/queries/decimal.sql @@ -0,0 +1,3 @@ +SELECT * FROM VALUES +(1.0, 3.4, NULL::decimal, 0::decimal), +(2.0, 34.56, 0::decimal, NULL::decimal) diff --git a/Tests/SparkConnectTests/Resources/queries/decimal.sql.answer b/Tests/SparkConnectTests/Resources/queries/decimal.sql.answer new file mode 100644 index 0000000..0c22049 --- /dev/null +++ b/Tests/SparkConnectTests/Resources/queries/decimal.sql.answer @@ -0,0 +1,6 @@ ++----+-----+----+----+ +|col1| col2|col3|col4| ++----+-----+----+----+ +| 1.0| 3.40|NULL| 0| +| 2.0|34.56| 0|NULL| ++----+-----+----+----+ \ No newline at end of file diff --git a/Tests/SparkConnectTests/RowTests.swift b/Tests/SparkConnectTests/RowTests.swift index 8472b94..02da050 100644 --- a/Tests/SparkConnectTests/RowTests.swift +++ b/Tests/SparkConnectTests/RowTests.swift @@ -38,6 +38,7 @@ struct RowTests { #expect(Row(nil).size == 1) #expect(Row(1).size == 1) #expect(Row(1.1).size == 1) + #expect(Row(Decimal(1.1)).size == 1) #expect(Row("a").size == 1) #expect(Row(nil, 1, 1.1, "a", true).size == 5) #expect(Row(valueArray: [nil, 1, 1.1, "a", true]).size == 5) @@ -50,11 +51,12 @@ struct RowTests { @Test func get() throws { - let row = Row(1, 1.1, "a", true) + let row = Row(1, 1.1, "a", true, Decimal(1.2)) #expect(try row.get(0) as! Int == 1) #expect(try row.get(1) as! Double == 1.1) #expect(try row.get(2) as! String == "a") #expect(try row.get(3) as! Bool == true) + #expect(try row.get(4) as! Decimal == Decimal(1.2)) #expect(throws: SparkConnectError.InvalidArgumentException) { try Row.empty.get(-1) } @@ -73,6 +75,9 @@ struct RowTests { #expect(Row(1.0) == Row(1.0)) #expect(Row(1.0) != Row(2.0)) + #expect(Row(Decimal(1.0)) == Row(Decimal(1.0))) + #expect(Row(Decimal(1.0)) != Row(Decimal(2.0))) + #expect(Row("a") == Row("a")) #expect(Row("a") != Row("b")) diff --git a/Tests/SparkConnectTests/SQLTests.swift b/Tests/SparkConnectTests/SQLTests.swift index 7ac589c..5c5efb2 100644 --- a/Tests/SparkConnectTests/SQLTests.swift +++ b/Tests/SparkConnectTests/SQLTests.swift @@ -84,6 +84,7 @@ struct SQLTests { "create_scala_function.sql", "create_table_function.sql", "cast.sql", + "decimal.sql", "pipesyntax.sql", "explain.sql", "variant.sql",