Skip to content

[SPARK-52523] Update arrow-swift code for Timestamp #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions Sources/SparkConnect/ArrowArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder {
return try ArrowArrayHolderImpl(Time32Array(with))
case .time64:
return try ArrowArrayHolderImpl(Time64Array(with))
case .timestamp:
return try ArrowArrayHolderImpl(TimestampArray(with))
case .string:
return try ArrowArrayHolderImpl(StringArray(with))
case .boolean:
Expand Down Expand Up @@ -269,6 +271,86 @@ public class Decimal128Array: FixedArray<Decimal> {
}
}

public class TimestampArray: FixedArray<Timestamp> {

public struct FormattingOptions: Equatable {
public var dateFormat: String = "yyyy-MM-dd HH:mm:ss.SSS"
public var locale: Locale = .current
public var includeTimezone: Bool = true
public var fallbackToRaw: Bool = true

public init(
dateFormat: String = "yyyy-MM-dd HH:mm:ss.SSS",
locale: Locale = .current,
includeTimezone: Bool = true,
fallbackToRaw: Bool = true
) {
self.dateFormat = dateFormat
self.locale = locale
self.includeTimezone = includeTimezone
self.fallbackToRaw = fallbackToRaw
}

public static func == (lhs: FormattingOptions, rhs: FormattingOptions) -> Bool {
return lhs.dateFormat == rhs.dateFormat && lhs.locale.identifier == rhs.locale.identifier
&& lhs.includeTimezone == rhs.includeTimezone && lhs.fallbackToRaw == rhs.fallbackToRaw
}
}

private var cachedFormatter: DateFormatter?
private var cachedOptions: FormattingOptions?

public func formattedDate(at index: UInt, options: FormattingOptions = FormattingOptions())
-> String?
{
guard let timestamp = self[index] else { return nil }

guard let timestampType = self.arrowData.type as? ArrowTypeTimestamp else {
return options.fallbackToRaw ? "\(timestamp)" : nil
}

let date = dateFromTimestamp(timestamp, unit: timestampType.unit)

if cachedFormatter == nil || cachedOptions != options {
let formatter = DateFormatter()
formatter.dateFormat = options.dateFormat
formatter.locale = options.locale
if options.includeTimezone, let timezone = timestampType.timezone {
formatter.timeZone = TimeZone(identifier: timezone)
}
cachedFormatter = formatter
cachedOptions = options
}

return cachedFormatter?.string(from: date)
}

private func dateFromTimestamp(_ timestamp: Int64, unit: ArrowTimestampUnit) -> Date {
let timeInterval: TimeInterval

switch unit {
case .seconds:
timeInterval = TimeInterval(timestamp)
case .milliseconds:
timeInterval = TimeInterval(timestamp) / 1_000
case .microseconds:
timeInterval = TimeInterval(timestamp) / 1_000_000
case .nanoseconds:
timeInterval = TimeInterval(timestamp) / 1_000_000_000
}

return Date(timeIntervalSince1970: timeInterval)
}

public override func asString(_ index: UInt) -> String {
if let formatted = formattedDate(at: index) {
return formatted
}

return super.asString(index)
}
}

/// @nodoc
public class BinaryArray: ArrowArray<Data> {
public struct Options {
Expand Down
17 changes: 17 additions & 0 deletions Sources/SparkConnect/ArrowArrayBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ public class Decimal128ArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Decima
}
}

public class TimestampArrayBuilder: ArrowArrayBuilder<FixedBufferBuilder<Int64>, TimestampArray> {
fileprivate convenience init(_ unit: ArrowTimestampUnit, timezone: String? = nil) throws {
try self.init(ArrowTypeTimestamp(unit, timezone: timezone))
}
}

public class StructArrayBuilder: ArrowArrayBuilder<StructBufferBuilder, StructArray> {
let builders: [any ArrowArrayHolderBuilder]
let fields: [ArrowField]
Expand Down Expand Up @@ -293,6 +299,11 @@ public class ArrowArrayBuilders {
throw ArrowError.invalid("Expected ArrowTypeDecimal128 for decimal128 type")
}
return try Decimal128ArrayBuilder(precision: decimalType.precision, scale: decimalType.scale)
case .timestamp:
guard let timestampType = arrowType as? ArrowTypeTimestamp else {
throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not found")
}
return try TimestampArrayBuilder(timestampType.unit)
default:
throw ArrowError.unknownType("Builder not found for arrow type: \(arrowType.id)")
}
Expand Down Expand Up @@ -355,6 +366,12 @@ public class ArrowArrayBuilders {
return try Time64ArrayBuilder(unit)
}

public static func loadTimestampArrayBuilder(_ unit: ArrowTimestampUnit, timezone: String? = nil)
throws -> TimestampArrayBuilder
{
return try TimestampArrayBuilder(unit, timezone: timezone)
}

public static func loadDecimal128ArrayBuilder(
_ precision: Int32 = 38,
_ scale: Int32 = 18
Expand Down
39 changes: 38 additions & 1 deletion Sources/SparkConnect/ArrowReaderHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ private func makeTimeHolder(
}
}

private func makeTimestampHolder(
_ field: ArrowField,
buffers: [ArrowBuffer],
nullCount: UInt
) -> Result<ArrowArrayHolder, ArrowError> {
do {
if let arrowType = field.type as? ArrowTypeTimestamp {
let arrowData = try ArrowData(arrowType, buffers: buffers, nullCount: nullCount)
return .success(ArrowArrayHolderImpl(try TimestampArray(arrowData)))
} else {
return .failure(.invalid("Incorrect field type for timestamp: \(field.type)"))
}
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("\(error)"))
}
}

private func makeBoolHolder(
_ buffers: [ArrowBuffer],
nullCount: UInt
Expand Down Expand Up @@ -214,6 +233,8 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity
return makeDateHolder(field, buffers: buffers, nullCount: nullCount)
case .time32, .time64:
return makeTimeHolder(field, buffers: buffers, nullCount: nullCount)
case .timestamp:
return makeTimestampHolder(field, buffers: buffers, nullCount: nullCount)
case .strct:
return makeStructHolder(
field, buffers: buffers, nullCount: nullCount, children: children!, rbLength: rbLength)
Expand All @@ -234,7 +255,7 @@ func makeBuffer(

func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool {
switch type {
case .int, .bool, .floatingpoint, .date, .time, .decimal:
case .int, .bool, .floatingpoint, .date, .time, .timestamp, .decimal:
return true
default:
return false
Expand Down Expand Up @@ -307,6 +328,22 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_bo
}

return ArrowTypeTime64(timeType.unit == .microsecond ? .microseconds : .nanoseconds)
case .timestamp:
let timestampType = field.type(type: org_apache_arrow_flatbuf_Timestamp.self)!
let arrowUnit: ArrowTimestampUnit
switch timestampType.unit {
case .second:
arrowUnit = .seconds
case .millisecond:
arrowUnit = .milliseconds
case .microsecond:
arrowUnit = .microseconds
case .nanosecond:
arrowUnit = .nanoseconds
}

let timezone = timestampType.timezone
return ArrowTypeTimestamp(arrowUnit, timezone: timezone)
case .struct_:
_ = field.type(type: org_apache_arrow_flatbuf_Struct_.self)!
var fields = [ArrowField]()
Expand Down
72 changes: 72 additions & 0 deletions Sources/SparkConnect/ArrowType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public typealias Time64 = Int64
public typealias Date32 = Int32
/// @nodoc
public typealias Date64 = Int64
public typealias Timestamp = Int64

func FlatBuffersVersion_23_1_4() { // swiftlint:disable:this identifier_name
}
Expand Down Expand Up @@ -70,6 +71,7 @@ public enum ArrowTypeId: Sendable, Equatable {
case strct
case time32
case time64
case timestamp
case time
case uint16
case uint32
Expand Down Expand Up @@ -146,6 +148,47 @@ public class ArrowTypeDecimal128: ArrowType {
}
}

public enum ArrowTimestampUnit {
case seconds
case milliseconds
case microseconds
case nanoseconds
}

public class ArrowTypeTimestamp: ArrowType {
let unit: ArrowTimestampUnit
let timezone: String?

public init(_ unit: ArrowTimestampUnit, timezone: String? = nil) {
self.unit = unit
self.timezone = timezone

super.init(ArrowType.ArrowTimestamp)
}

public convenience init(type: ArrowTypeId) {
self.init(.milliseconds, timezone: nil)
}

public override var cDataFormatId: String {
get throws {
let unitChar: String
switch self.unit {
case .seconds: unitChar = "s"
case .milliseconds: unitChar = "m"
case .microseconds: unitChar = "u"
case .nanoseconds: unitChar = "n"
}

if let timezone = self.timezone {
return "ts\(unitChar):\(timezone)"
} else {
return "ts\(unitChar)"
}
}
}
}

/// @nodoc
public class ArrowNestedType: ArrowType {
let fields: [ArrowField]
Expand Down Expand Up @@ -177,6 +220,7 @@ public class ArrowType {
public static let ArrowBinary = Info.variableInfo(ArrowTypeId.binary)
public static let ArrowTime32 = Info.timeInfo(ArrowTypeId.time32)
public static let ArrowTime64 = Info.timeInfo(ArrowTypeId.time64)
public static let ArrowTimestamp = Info.timeInfo(ArrowTypeId.timestamp)
public static let ArrowStruct = Info.complexInfo(ArrowTypeId.strct)

public init(_ info: ArrowType.Info) {
Expand Down Expand Up @@ -305,6 +349,8 @@ public class ArrowType {
return MemoryLayout<Time32>.stride
case .time64:
return MemoryLayout<Time64>.stride
case .timestamp:
return MemoryLayout<Timestamp>.stride
case .binary:
return MemoryLayout<Int8>.stride
case .string:
Expand Down Expand Up @@ -357,6 +403,11 @@ public class ArrowType {
return try time64.cDataFormatId
}
return "ttu"
case ArrowTypeId.timestamp:
if let timestamp = self as? ArrowTypeTimestamp {
return try timestamp.cDataFormatId
}
return "tsu"
case ArrowTypeId.binary:
return "z"
case ArrowTypeId.string:
Expand Down Expand Up @@ -409,6 +460,27 @@ public class ArrowType {
return ArrowTypeTime64(.microseconds)
} else if from == "ttn" {
return ArrowTypeTime64(.nanoseconds)
} else if from.starts(with: "ts") {
let components = from.split(separator: ":", maxSplits: 1)
guard let unitPart = components.first, unitPart.count == 3 else {
throw ArrowError.invalid(
"Invalid timestamp format '\(from)'. Expected format 'ts[s|m|u|n][:timezone]'")
}

let unitChar = unitPart.suffix(1)
let unit: ArrowTimestampUnit
switch unitChar {
case "s": unit = .seconds
case "m": unit = .milliseconds
case "u": unit = .microseconds
case "n": unit = .nanoseconds
default:
throw ArrowError.invalid(
"Unrecognized timestamp unit '\(unitChar)'. Expected 's', 'm', 'u', or 'n'.")
}

let timezone = components.count > 1 ? String(components[1]) : nil
return ArrowTypeTimestamp(unit, timezone: timezone)
} else if from == "z" {
return ArrowType(ArrowType.ArrowBinary)
} else if from == "u" {
Expand Down
28 changes: 28 additions & 0 deletions Sources/SparkConnect/ArrowWriterHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ func toFBTypeEnum(_ arrowType: ArrowType) -> Result<org_apache_arrow_flatbuf_Typ
return .success(org_apache_arrow_flatbuf_Type_.date)
case .time32, .time64:
return .success(org_apache_arrow_flatbuf_Type_.time)
case .timestamp:
return .success(org_apache_arrow_flatbuf_Type_.timestamp)
case .strct:
return .success(org_apache_arrow_flatbuf_Type_.struct_)
default:
Expand Down Expand Up @@ -114,6 +116,32 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_le
}

return .failure(.invalid("Unable to case to Time64"))
case .timestamp:
if let timestampType = arrowType as? ArrowTypeTimestamp {
let startOffset = org_apache_arrow_flatbuf_Timestamp.startTimestamp(&fbb)

let fbUnit: org_apache_arrow_flatbuf_TimeUnit
switch timestampType.unit {
case .seconds:
fbUnit = .second
case .milliseconds:
fbUnit = .millisecond
case .microseconds:
fbUnit = .microsecond
case .nanoseconds:
fbUnit = .nanosecond
}
org_apache_arrow_flatbuf_Timestamp.add(unit: fbUnit, &fbb)

if let timezone = timestampType.timezone {
let timezoneOffset = fbb.create(string: timezone)
org_apache_arrow_flatbuf_Timestamp.add(timezone: timezoneOffset, &fbb)
}

return .success(org_apache_arrow_flatbuf_Timestamp.endTimestamp(&fbb, start: startOffset))
}

return .failure(.invalid("Unable to cast to Timestamp"))
case .strct:
let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb)
return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset))
Expand Down
16 changes: 16 additions & 0 deletions Sources/SparkConnect/ProtoUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_l
let arrowUnit: ArrowTime64Unit = timeType.unit == .microsecond ? .microseconds : .nanoseconds
arrowType = ArrowTypeTime64(arrowUnit)
}
case .timestamp:
let timestampType = field.type(type: org_apache_arrow_flatbuf_Timestamp.self)!
let arrowUnit: ArrowTimestampUnit
switch timestampType.unit {
case .second:
arrowUnit = .seconds
case .millisecond:
arrowUnit = .milliseconds
case .microsecond:
arrowUnit = .microseconds
case .nanosecond:
arrowUnit = .nanoseconds
}

let timezone = timestampType.timezone
arrowType = ArrowTypeTimestamp(arrowUnit, timezone: timezone?.isEmpty == true ? nil : timezone)
case .struct_:
var children = [ArrowField]()
for index in 0..<field.childrenCount {
Expand Down
Loading