Skip to content

[SPARK-52301] Support Decimal type #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Sources/SparkConnect/ArrowArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder {
return try ArrowArrayHolderImpl(FixedArray<Double>(with))
case .float:
return try ArrowArrayHolderImpl(FixedArray<Float>(with))
case .decimal128:
return try ArrowArrayHolderImpl(FixedArray<Decimal>(with))
case .date32:
return try ArrowArrayHolderImpl(Date32Array(with))
case .date64:
Expand Down Expand Up @@ -247,6 +249,25 @@ public class Time32Array: FixedArray<Time32> {}
/// @nodoc
public class Time64Array: FixedArray<Time64> {}

/// @nodoc
/// Fixed-width array whose slots are decoded into Foundation `Decimal` values.
public class Decimal128Array: FixedArray<Decimal> {
  /// Returns the decimal stored at `index`, or `nil` when the slot is null.
  ///
  /// The slot is interpreted as a 128-bit little-endian two's-complement integer
  /// (the unscaled value) and divided by `10^scale`, where `scale` comes from the
  /// array's `.decimal128` type id (falling back to 18 for any other id).
  public override subscript(_ index: UInt) -> Decimal? {
    if self.arrowData.isNull(index) {
      return nil
    }
    let scale: Int32 =
      switch self.arrowData.type.id {
      case .decimal128(_, let scale):
        scale
      default:
        18
      }
    // NOTE(review): assumes `arrowData.stride` is 16 bytes for decimal128 so that
    // both 8-byte halves of the slot are inside the values buffer -- confirm.
    let byteOffset = self.arrowData.stride * Int(index)
    let slot = self.arrowData.buffers[1].rawPointer.advanced(by: byteOffset)
    let low = UInt64(littleEndian: slot.load(as: UInt64.self))
    let high = UInt64(littleEndian: slot.load(fromByteOffset: 8, as: UInt64.self))

    // The top bit of the high half is the sign bit of the 128-bit integer.
    let isNegative = (high >> 63) == 1
    var magnitudeLow = low
    var magnitudeHigh = high
    if isNegative {
      // Two's-complement negation across both halves; the +1 carries into the
      // high half only when the low half is zero.
      magnitudeLow = ~low &+ 1
      magnitudeHigh = ~high &+ (low == 0 ? 1 : 0)
    }

    let twoTo64 = Decimal(UInt64.max) + 1
    let magnitude = Decimal(magnitudeHigh) * twoTo64 + Decimal(magnitudeLow)
    let unscaled = isNegative ? -magnitude : magnitude
    return unscaled / pow(10, Int(scale))
  }
}

/// @nodoc
public class BinaryArray: ArrowArray<Data> {
public struct Options {
Expand Down
24 changes: 23 additions & 1 deletion Sources/SparkConnect/ArrowArrayBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ public class Time64ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Time64>, T
}
}

/// Array builder that collects `Decimal` values into a `Decimal128Array`.
public class Decimal128ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Decimal>, Decimal128Array> {
  /// Creates a builder whose Arrow type is an `ArrowTypeDecimal128` with the given
  /// precision and scale.
  fileprivate convenience init(precision: Int32, scale: Int32) throws {
    let decimalType = ArrowTypeDecimal128(precision: precision, scale: scale)
    try self.init(decimalType)
  }
}

public class StructArrayBuilder: ArrowArrayBuilder<StructBufferBuilder, StructArray> {
let builders: [any ArrowArrayHolderBuilder]
let fields: [ArrowField]
Expand Down Expand Up @@ -202,6 +208,8 @@ public class ArrowArrayBuilders {
return try ArrowArrayBuilders.loadBoolArrayBuilder()
} else if builderType == Date.self || builderType == Date?.self {
return try ArrowArrayBuilders.loadDate64ArrayBuilder()
} else if builderType == Decimal.self || builderType == Decimal?.self {
return try ArrowArrayBuilders.loadDecimal128ArrayBuilder(38, 18)
} else {
throw ArrowError.invalid("Invalid type for builder: \(builderType)")
}
Expand All @@ -214,7 +222,7 @@ public class ArrowArrayBuilders {
|| type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self
|| type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self
|| type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self
|| type == Float.self || type == Date.self
|| type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self
}

public static func loadStructArrayBuilderForType<T>(_ obj: T) throws -> StructArrayBuilder {
Expand Down Expand Up @@ -279,6 +287,11 @@ public class ArrowArrayBuilders {
throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not found")
}
return try Time64ArrayBuilder(timeType.unit)
case .decimal128:
guard let decimalType = arrowType as? ArrowTypeDecimal128 else {
throw ArrowError.invalid("Expected ArrowTypeDecimal128 for decimal128 type")
}
return try Decimal128ArrayBuilder(precision: decimalType.precision, scale: decimalType.scale)
default:
throw ArrowError.unknownType("Builder not found for arrow type: \(arrowType.id)")
}
Expand Down Expand Up @@ -306,6 +319,8 @@ public class ArrowArrayBuilders {
return try NumberArrayBuilder<T>()
} else if type == Double.self {
return try NumberArrayBuilder<T>()
} else if type == Decimal.self {
return try NumberArrayBuilder<T>()
} else {
throw ArrowError.unknownType("Type is invalid for NumberArrayBuilder")
}
Expand Down Expand Up @@ -338,4 +353,11 @@ public class ArrowArrayBuilders {
public static func loadTime64ArrayBuilder(_ unit: ArrowTime64Unit) throws -> Time64ArrayBuilder {
return try Time64ArrayBuilder(unit)
}

/// Creates a `Decimal128ArrayBuilder` for the given precision and scale.
///
/// - Parameters:
///   - precision: Decimal precision passed to the builder; defaults to 38.
///   - scale: Decimal scale passed to the builder; defaults to 18.
/// - Returns: A new, empty builder.
/// - Throws: Whatever the builder initializer throws.
public static func loadDecimal128ArrayBuilder(
  _ precision: Int32 = 38,
  _ scale: Int32 = 18
) throws -> Decimal128ArrayBuilder {
  try Decimal128ArrayBuilder(precision: precision, scale: scale)
}
}
2 changes: 2 additions & 0 deletions Sources/SparkConnect/ArrowBufferBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ public class FixedBufferBuilder<T>: ValuesBufferBuilder<T>, ArrowBufferBuilder {
return Float(0) as! T // swiftlint:disable:this force_cast
} else if type == Double.self {
return Double(0) as! T // swiftlint:disable:this force_cast
} else if type == Decimal.self {
return Decimal(0) as! T // swiftlint:disable:this force_cast
}

throw ArrowError.unknownType("Unable to determine default value")
Expand Down
14 changes: 11 additions & 3 deletions Sources/SparkConnect/ArrowDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer {
|| type == Bool?.self || type == Bool.self || type == Int8.self || type == Int16.self
|| type == Int32.self || type == Int64.self || type == UInt8.self || type == UInt16.self
|| type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self
|| type == Float.self || type == Date.self
|| type == Float.self || type == Date.self || type == Decimal.self || type == Decimal?.self
{
defer { increment() }
return try self.decoder.doDecode(self.currentIndex)!
Expand Down Expand Up @@ -260,8 +260,12 @@ private struct ArrowKeyedDecoding<Key: CodingKey>: KeyedDecodingContainerProtoco
return try self.decoder.doDecode(key)!
}

/// Decodes the `Decimal` stored under `key`.
///
/// - Throws: `ArrowError.invalid` when the decoder yields no `Decimal` for `key`
///   (previously this case crashed on a force-unwrap), or any error thrown by
///   the underlying decoder.
func decode(_ type: Decimal.Type, forKey key: Key) throws -> Decimal {
  guard let value: Decimal = try self.decoder.doDecode(key) else {
    throw ArrowError.invalid("No Decimal value available for key \(key.stringValue)")
  }
  return value
}

func decode<T>(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable {
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self {
return try self.decoder.doDecode(key)!
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
Expand Down Expand Up @@ -363,8 +367,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer {
return try self.decoder.doDecode(self.decoder.singleRBCol)!
}

/// Decodes the single `Decimal` value from the decoder's `singleRBCol` column.
///
/// - Throws: `ArrowError.invalid` when the decoder yields no `Decimal`
///   (previously this case crashed on a force-unwrap), or any error thrown by
///   the underlying decoder.
func decode(_ type: Decimal.Type) throws -> Decimal {
  guard let value: Decimal = try self.decoder.doDecode(self.decoder.singleRBCol) else {
    throw ArrowError.invalid("No Decimal value available in single-value container")
  }
  return value
}

func decode<T>(_ type: T.Type) throws -> T where T: Decodable {
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self {
if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self {
return try self.decoder.doDecode(self.decoder.singleRBCol)!
} else {
throw ArrowError.invalid("Type \(type) is currently not supported")
Expand Down
25 changes: 24 additions & 1 deletion Sources/SparkConnect/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ private func makeStringHolder(
}
}

/// Builds an array holder backed by a `Decimal128Array` for `field`.
///
/// - Parameters:
///   - field: Field whose type is used to construct the `ArrowData`.
///   - buffers: Buffers handed to `ArrowData`.
///   - nullCount: Null count handed to `ArrowData`.
/// - Returns: `.success` with the holder, or `.failure` carrying the thrown
///   `ArrowError` (any other error is wrapped in `.unknownError`).
private func makeDecimalHolder(
  _ field: ArrowField,
  buffers: [ArrowBuffer],
  nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
  do {
    let data = try ArrowData(field.type, buffers: buffers, nullCount: nullCount)
    let decimals = try Decimal128Array(data)
    return .success(ArrowArrayHolderImpl(decimals))
  } catch {
    // Pass ArrowError through unchanged; stringify anything else.
    guard let arrowError = error as? ArrowError else {
      return .failure(.unknownError("\(error)"))
    }
    return .failure(arrowError)
  }
}

private func makeDateHolder(
_ field: ArrowField,
buffers: [ArrowBuffer],
Expand Down Expand Up @@ -183,6 +198,8 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
return makeFixedHolder(Int64.self, field: field, buffers: buffers, nullCount: nullCount)
case .uint64:
return makeFixedHolder(UInt64.self, field: field, buffers: buffers, nullCount: nullCount)
case .decimal128:
return makeDecimalHolder(field, buffers: buffers, nullCount: nullCount)
case .boolean:
return makeBoolHolder(buffers, nullCount: nullCount)
case .float:
Expand Down Expand Up @@ -217,7 +234,7 @@ func makeBuffer(

func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
switch type {
case .int, .bool, .floatingpoint, .date, .time:
case .int, .bool, .floatingpoint, .date, .time, .decimal:
return true
default:
return false
Expand Down Expand Up @@ -266,6 +283,12 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_bo
default:
return ArrowType(ArrowType.ArrowUnknown)
}
case .decimal:
let dataType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)!
if dataType.bitWidth == 128 {
return ArrowType(ArrowType.ArrowDecimal128)
}
return ArrowType(ArrowType.ArrowUnknown)
case .utf8:
return ArrowType(ArrowType.ArrowString)
case .binary:
Expand Down
35 changes: 33 additions & 2 deletions Sources/SparkConnect/ArrowType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@ public enum ArrowError: Error {
case invalid(String)
}

public enum ArrowTypeId: Sendable {
public enum ArrowTypeId: Sendable, Equatable {
case binary
case boolean
case date32
case date64
case dateType
case decimal128
case decimal128(_ precision: Int32, _ scale: Int32)
case decimal256
case dictionary
case double
Expand Down Expand Up @@ -129,6 +129,23 @@ public class ArrowTypeTime64: ArrowType {
}
}

/// Arrow type for 128-bit decimals that carries an explicit precision and scale.
public class ArrowTypeDecimal128: ArrowType {
  let precision: Int32
  let scale: Int32

  /// Creates a decimal128 type.
  ///
  /// The precision and scale are embedded in the type info passed to the
  /// superclass (rather than the fixed `ArrowType.ArrowDecimal128` constant,
  /// which always carries `(38, 18)`), so code that pattern-matches
  /// `.decimal128(precision, scale)` sees the values requested here.
  public init(precision: Int32, scale: Int32) {
    self.precision = precision
    self.scale = scale
    super.init(ArrowType.Info.primitiveInfo(ArrowTypeId.decimal128(precision, scale)))
  }

  /// C data interface format string, `d:<precision>,<scale>`.
  public override var cDataFormatId: String {
    get throws {
      return "d:\(precision),\(scale)"
    }
  }
}

/// @nodoc
public class ArrowNestedType: ArrowType {
let fields: [ArrowField]
Expand Down Expand Up @@ -156,6 +173,7 @@ public class ArrowType {
public static let ArrowBool = Info.primitiveInfo(ArrowTypeId.boolean)
public static let ArrowDate32 = Info.primitiveInfo(ArrowTypeId.date32)
public static let ArrowDate64 = Info.primitiveInfo(ArrowTypeId.date64)
public static let ArrowDecimal128 = Info.primitiveInfo(ArrowTypeId.decimal128(38, 18))
public static let ArrowBinary = Info.variableInfo(ArrowTypeId.binary)
public static let ArrowTime32 = Info.timeInfo(ArrowTypeId.time32)
public static let ArrowTime64 = Info.timeInfo(ArrowTypeId.time64)
Expand Down Expand Up @@ -216,6 +234,8 @@ public class ArrowType {
return ArrowType.ArrowFloat
} else if type == Double.self {
return ArrowType.ArrowDouble
} else if type == Decimal.self {
return ArrowType.ArrowDecimal128
} else {
return ArrowType.ArrowUnknown
}
Expand All @@ -242,6 +262,8 @@ public class ArrowType {
return ArrowType.ArrowFloat
} else if type == Double.self {
return ArrowType.ArrowDouble
} else if type == Decimal.self {
return ArrowType.ArrowDecimal128
} else {
return ArrowType.ArrowUnknown
}
Expand Down Expand Up @@ -271,6 +293,8 @@ public class ArrowType {
return MemoryLayout<Float>.stride
case .double:
return MemoryLayout<Double>.stride
case .decimal128:
return 16 // Decimal 128 (= 16 * 8) bits
case .boolean:
return MemoryLayout<Bool>.stride
case .date32:
Expand Down Expand Up @@ -315,6 +339,8 @@ public class ArrowType {
return "f"
case ArrowTypeId.double:
return "g"
case ArrowTypeId.decimal128(let precision, let scale):
return "d:\(precision),\(scale)"
case ArrowTypeId.boolean:
return "b"
case ArrowTypeId.date32:
Expand Down Expand Up @@ -344,6 +370,7 @@ public class ArrowType {
public static func fromCDataFormatId( // swiftlint:disable:this cyclomatic_complexity
_ from: String
) throws -> ArrowType {
let REGEX_DECIMAL_TYPE = /^d:(\d+),(\d+)$/
if from == "c" {
return ArrowType(ArrowType.ArrowInt8)
} else if from == "s" {
Expand All @@ -364,6 +391,10 @@ public class ArrowType {
return ArrowType(ArrowType.ArrowFloat)
} else if from == "g" {
return ArrowType(ArrowType.ArrowDouble)
} else if from.contains(REGEX_DECIMAL_TYPE) {
let match = from.firstMatch(of: REGEX_DECIMAL_TYPE)!
let decimalType = ArrowTypeId.decimal128(Int32(match.1)!, Int32(match.2)!)
return ArrowType(Info.primitiveInfo(decimalType))
} else if from == "b" {
return ArrowType(ArrowType.ArrowBool)
} else if from == "tdD" {
Expand Down
2 changes: 2 additions & 0 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ public actor DataFrame: Sendable {
values.append(array.asAny(i) as? Float)
case .primitiveInfo(.double):
values.append(array.asAny(i) as? Double)
case .primitiveInfo(.decimal128):
values.append(array.asAny(i) as? Decimal)
case .primitiveInfo(.date32):
values.append(array.asAny(i) as! Date)
case ArrowType.ArrowBinary:
Expand Down
9 changes: 9 additions & 0 deletions Sources/SparkConnect/ProtoUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity
} else if floatType.precision == .double {
arrowType = ArrowType(ArrowType.ArrowDouble)
}
case .decimal:
let decimalType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)!
if decimalType.bitWidth == 128 && decimalType.precision <= 38 {
let arrowDecimal128 = ArrowTypeId.decimal128(decimalType.precision, decimalType.scale)
arrowType = ArrowType(ArrowType.Info.primitiveInfo(arrowDecimal128))
} else {
// Not supported yet
arrowType = ArrowType(ArrowType.ArrowUnknown)
}
case .utf8:
arrowType = ArrowType(ArrowType.ArrowString)
case .binary:
Expand Down
2 changes: 2 additions & 0 deletions Sources/SparkConnect/Row.swift
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ public struct Row: Sendable, Equatable {
return a == b
} else if let a = x as? Double, let b = y as? Double {
return a == b
} else if let a = x as? Decimal, let b = y as? Decimal {
return a == b
} else if let a = x as? String, let b = y as? String {
return a == b
} else {
Expand Down
19 changes: 19 additions & 0 deletions Tests/SparkConnectTests/DataFrameTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,25 @@ struct DataFrameTests {
}
await spark.stop()
}

/// Collects decimal columns (including NULLs) and checks both the reported
/// dtypes and the decoded `Decimal` values.
@Test
func decimal() async throws {
  let spark = try await SparkSession.builder.getOrCreate()
  let df = try await spark.sql(
    """
    SELECT * FROM VALUES
    (1.0, 3.4, CAST(NULL AS DECIMAL), CAST(0 AS DECIMAL)),
    (2.0, 34.56, CAST(0 AS DECIMAL), CAST(NULL AS DECIMAL))
    """)
  #expect(
    try await df.dtypes.map { $0.1 }
      == ["decimal(2,1)", "decimal(4,2)", "decimal(10,0)", "decimal(10,0)"])
  // Build fractional expectations from decimal strings: `Decimal(Double)` goes
  // through binary floating point and is not guaranteed to be exact.
  let threePointFour = try #require(Decimal(string: "3.40"))
  let thirtyFourPointFiveSix = try #require(Decimal(string: "34.56"))
  let expected = [
    Row(Decimal(1), threePointFour, nil, Decimal(0)),
    Row(Decimal(2), thirtyFourPointFiveSix, Decimal(0), nil),
  ]
  #expect(try await df.collect() == expected)
  await spark.stop()
}
#endif

@Test
Expand Down
3 changes: 3 additions & 0 deletions Tests/SparkConnectTests/Resources/queries/decimal.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT * FROM VALUES
(1.0, 3.4, NULL::decimal, 0::decimal),
(2.0, 34.56, 0::decimal, NULL::decimal)
6 changes: 6 additions & 0 deletions Tests/SparkConnectTests/Resources/queries/decimal.sql.answer
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
+----+-----+----+----+
|col1| col2|col3|col4|
+----+-----+----+----+
| 1.0| 3.40|NULL| 0|
| 2.0|34.56| 0|NULL|
+----+-----+----+----+
Loading
Loading