diff --git a/Sources/SparkConnect/ArrowWriter.swift b/Sources/SparkConnect/ArrowWriter.swift index ca3601d..9d44f9e 100644 --- a/Sources/SparkConnect/ArrowWriter.swift +++ b/Sources/SparkConnect/ArrowWriter.swift @@ -18,13 +18,11 @@ import FlatBuffers import Foundation -/// @nodoc public protocol DataWriter { var count: Int { get } func append(_ data: Data) } -/// @nodoc public class ArrowWriter { // swiftlint:disable:this type_body_length public class InMemDataWriter: DataWriter { public private(set) var data: Data @@ -77,11 +75,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result< Offset, ArrowError > { + var fieldsOffset: Offset? + if let nestedField = field.type as? ArrowNestedType { + var offsets = [Offset]() + for field in nestedField.fields { + switch writeField(&fbb, field: field) { + case .success(let offset): + offsets.append(offset) + case .failure(let error): + return .failure(error) + } + } + + fieldsOffset = fbb.createVector(ofOffsets: offsets) + } + let nameOffset = fbb.create(string: field.name) let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type) let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb) org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb) org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb) + if let childrenOffset = fieldsOffset { + org_apache_arrow_flatbuf_Field.addVectorOf(children: childrenOffset, &fbb) + } + switch toFBTypeEnum(field.type) { case .success(let type): org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb) @@ -109,7 +126,6 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length case .failure(let error): return .failure(error) } - } let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets) @@ -135,7 +151,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) } withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) } writer.append(rbResult.0) - switch writeRecordBatchData(&writer, batch: batch) { + switch writeRecordBatchData(&writer, fields: batch.schema.fields, columns: batch.columns) { case .success: rbBlocks.append( org_apache_arrow_flatbuf_Block( @@ -153,30 +169,34 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(rbBlocks) } - private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> { - let schema = batch.schema - var fbb = FlatBufferBuilder() - - // write out field nodes - var fieldNodeOffsets = [Offset]() - fbb.startVector( - schema.fields.count, elementSize: MemoryLayout.size) - for index in (0.. Result<(Data, Offset), ArrowError> { + let schema = batch.schema + var fbb = FlatBufferBuilder() + + // write out field nodes + var fieldNodeOffsets = [Offset]() + fbb.startVector( + schema.fields.count, elementSize: MemoryLayout.size) + writeFieldNodes(schema.fields, columns: batch.columns, offsets: &fieldNodeOffsets, fbb: &fbb) + let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count) + + // write out buffers + var buffers = [org_apache_arrow_flatbuf_Buffer]() + var bufferOffset = Int(0) + writeBufferInfo( + schema.fields, columns: batch.columns, + bufferOffset: &bufferOffset, buffers: &buffers, + fbb: &fbb) org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb) for buffer in buffers.reversed() { fbb.create(struct: buffer) @@ -210,15 +255,32 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success((fbb.data, Offset(offset: UInt32(fbb.data.count)))) } - private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result< - Bool, ArrowError - > { - for index in 0.. Result + { + for index in 0.. Result< @@ -285,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(true) } - public func writeSteaming(_ info: ArrowWriter.Info) -> Result { + public func writeStreaming(_ info: ArrowWriter.Info) -> Result { let writer: any DataWriter = InMemDataWriter() switch toMessage(info.schema) { case .success(let schemaData): @@ -363,7 +424,8 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length writer.append(message.0) addPadForAlignment(&writer) var dataWriter: any DataWriter = InMemDataWriter() - switch writeRecordBatchData(&dataWriter, batch: batch) { + switch writeRecordBatchData(&dataWriter, fields: batch.schema.fields, columns: batch.columns) + { case .success: return .success([ (writer as! InMemDataWriter).data, // swiftlint:disable:this force_cast @@ -397,3 +459,4 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length return .success(fbb.data) } } +// swiftlint:disable:this file_length diff --git a/Sources/SparkConnect/ArrowWriterHelper.swift b/Sources/SparkConnect/ArrowWriterHelper.swift index 7702e48..13856af 100644 --- a/Sources/SparkConnect/ArrowWriterHelper.swift +++ b/Sources/SparkConnect/ArrowWriterHelper.swift @@ -25,77 +25,78 @@ extension Data { } func toFBTypeEnum(_ arrowType: ArrowType) -> Result { - let infoType = arrowType.info - if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16 - || infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8 - || infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32 - || infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32 - { + let typeId = arrowType.id + switch typeId { + case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64: return .success(org_apache_arrow_flatbuf_Type_.int) - } else if infoType == ArrowType.ArrowFloat || infoType == ArrowType.ArrowDouble { + case .float, .double: return .success(org_apache_arrow_flatbuf_Type_.floatingpoint) - } else if infoType == ArrowType.ArrowString { + case .string: return .success(org_apache_arrow_flatbuf_Type_.utf8) - } else if infoType == ArrowType.ArrowBinary { + case .binary: return .success(org_apache_arrow_flatbuf_Type_.binary) - } else if infoType == ArrowType.ArrowBool { + case .boolean: return .success(org_apache_arrow_flatbuf_Type_.bool) - } else if infoType == ArrowType.ArrowDate32 || infoType == ArrowType.ArrowDate64 { + case .date32, .date64: return .success(org_apache_arrow_flatbuf_Type_.date) - } else if infoType == ArrowType.ArrowTime32 || infoType == ArrowType.ArrowTime64 { + case .time32, .time64: return .success(org_apache_arrow_flatbuf_Type_.time) + case .strct: + return .success(org_apache_arrow_flatbuf_Type_.struct_) + default: + return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(typeId)")) } - return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(infoType)")) } -func toFBType( // swiftlint:disable:this cyclomatic_complexity +func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_length _ fbb: inout FlatBufferBuilder, arrowType: ArrowType ) -> Result { let infoType = arrowType.info - if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 { + switch arrowType.id { + case .int8, .uint8: return .success( org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8)) - } else if infoType == ArrowType.ArrowInt16 || infoType == ArrowType.ArrowUInt16 { + case .int16, .uint16: return .success( org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16)) - } else if infoType == ArrowType.ArrowInt32 || infoType == ArrowType.ArrowUInt32 { + case .int32, .uint32: return .success( org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32)) - } else if infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt64 { + case .int64, .uint64: return .success( org_apache_arrow_flatbuf_Int.createInt( &fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64)) - } else if infoType == ArrowType.ArrowFloat { + case .float: return .success( org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .single)) - } else if infoType == ArrowType.ArrowDouble { + case .double: return .success( org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .double)) - } else if infoType == ArrowType.ArrowString { + case .string: return .success( org_apache_arrow_flatbuf_Utf8.endUtf8( &fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb))) - } else if infoType == ArrowType.ArrowBinary { + case .binary: return .success( org_apache_arrow_flatbuf_Binary.endBinary( &fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb))) - } else if infoType == ArrowType.ArrowBool { + case .boolean: return .success( org_apache_arrow_flatbuf_Bool.endBool( &fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb))) - } else if infoType == ArrowType.ArrowDate32 { + case .date32: let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb) org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb) return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset)) - } else if infoType == ArrowType.ArrowDate64 { + case .date64: let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb) org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb) return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset)) - } else if infoType == ArrowType.ArrowTime32 { + case .time32: let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb) if let timeType = arrowType as? ArrowTypeTime32 { org_apache_arrow_flatbuf_Time.add( @@ -104,7 +105,7 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity } return .failure(.invalid("Unable to case to Time32")) - } else if infoType == ArrowType.ArrowTime64 { + case .time64: let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb) if let timeType = arrowType as? ArrowTypeTime64 { org_apache_arrow_flatbuf_Time.add( @@ -113,9 +114,12 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity } return .failure(.invalid("Unable to case to Time64")) + case .strct: + let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb) + return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset)) + default: + return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)")) } - - return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)")) } func addPadForAlignment(_ data: inout Data, alignment: Int = 8) { diff --git a/Sources/SparkConnect/ProtoUtil.swift b/Sources/SparkConnect/ProtoUtil.swift index 4f27b20..ee5f179 100644 --- a/Sources/SparkConnect/ProtoUtil.swift +++ b/Sources/SparkConnect/ProtoUtil.swift @@ -17,7 +17,7 @@ import Foundation -func fromProto( // swiftlint:disable:this cyclomatic_complexity +func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_length field: org_apache_arrow_flatbuf_Field ) -> ArrowField { let type = field.typeType @@ -74,7 +74,13 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity arrowType = ArrowTypeTime64(arrowUnit) } case .struct_: - arrowType = ArrowType(ArrowType.ArrowStruct) + var children = [ArrowField]() + for index in 0..