Skip to content

[SPARK-52340] Update ArrowWriter(Helper)? and ProtoUtil with GH-43170 #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions Sources/SparkConnect/ArrowWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
import FlatBuffers
import Foundation

/// @nodoc
public protocol DataWriter {
var count: Int { get }
func append(_ data: Data)
}

/// @nodoc
public class ArrowWriter { // swiftlint:disable:this type_body_length
public class InMemDataWriter: DataWriter {
public private(set) var data: Data
Expand Down Expand Up @@ -77,11 +75,30 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
private func writeField(_ fbb: inout FlatBufferBuilder, field: ArrowField) -> Result<
Offset, ArrowError
> {
var fieldsOffset: Offset?
if let nestedField = field.type as? ArrowNestedType {
var offsets = [Offset]()
for field in nestedField.fields {
switch writeField(&fbb, field: field) {
case .success(let offset):
offsets.append(offset)
case .failure(let error):
return .failure(error)
}
}

fieldsOffset = fbb.createVector(ofOffsets: offsets)
}

let nameOffset = fbb.create(string: field.name)
let fieldTypeOffsetResult = toFBType(&fbb, arrowType: field.type)
let startOffset = org_apache_arrow_flatbuf_Field.startField(&fbb)
org_apache_arrow_flatbuf_Field.add(name: nameOffset, &fbb)
org_apache_arrow_flatbuf_Field.add(nullable: field.isNullable, &fbb)
if let childrenOffset = fieldsOffset {
org_apache_arrow_flatbuf_Field.addVectorOf(children: childrenOffset, &fbb)
}

switch toFBTypeEnum(field.type) {
case .success(let type):
org_apache_arrow_flatbuf_Field.add(typeType: type, &fbb)
Expand Down Expand Up @@ -109,7 +126,6 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
case .failure(let error):
return .failure(error)
}

}

let fieldsOffset: Offset = fbb.createVector(ofOffsets: fieldOffsets)
Expand All @@ -135,7 +151,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
writer.append(rbResult.0)
switch writeRecordBatchData(&writer, batch: batch) {
switch writeRecordBatchData(&writer, fields: batch.schema.fields, columns: batch.columns) {
case .success:
rbBlocks.append(
org_apache_arrow_flatbuf_Block(
Expand All @@ -153,40 +169,69 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(rbBlocks)
}

private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
let schema = batch.schema
var fbb = FlatBufferBuilder()

// write out field nodes
var fieldNodeOffsets = [Offset]()
fbb.startVector(
schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
for index in (0..<schema.fields.count).reversed() {
let column = batch.column(index)
private func writeFieldNodes(
_ fields: [ArrowField], columns: [ArrowArrayHolder], offsets: inout [Offset],
fbb: inout FlatBufferBuilder
) {
for index in (0..<fields.count).reversed() {
let column = columns[index]
let fieldNode =
org_apache_arrow_flatbuf_FieldNode(
length: Int64(column.length),
nullCount: Int64(column.nullCount))
fieldNodeOffsets.append(fbb.create(struct: fieldNode))
offsets.append(fbb.create(struct: fieldNode))
if let nestedType = column.type as? ArrowNestedType {
let structArray = column.array as? StructArray
writeFieldNodes(
nestedType.fields, columns: structArray!.arrowFields!, offsets: &offsets, fbb: &fbb)
}
}
}

let nodeOffset = fbb.endVector(len: schema.fields.count)

// write out buffers
var buffers = [org_apache_arrow_flatbuf_Buffer]()
var bufferOffset = Int(0)
for index in 0..<batch.schema.fields.count {
let column = batch.column(index)
private func writeBufferInfo(
_ fields: [ArrowField],
columns: [ArrowArrayHolder],
bufferOffset: inout Int,
buffers: inout [org_apache_arrow_flatbuf_Buffer],
fbb: inout FlatBufferBuilder
) {
for index in 0..<fields.count {
let column = columns[index]
let colBufferDataSizes = column.getBufferDataSizes()
for var bufferDataSize in colBufferDataSizes {
bufferDataSize = getPadForAlignment(bufferDataSize)
let buffer = org_apache_arrow_flatbuf_Buffer(
offset: Int64(bufferOffset), length: Int64(bufferDataSize))
buffers.append(buffer)
bufferOffset += bufferDataSize
if let nestedType = column.type as? ArrowNestedType {
let structArray = column.array as? StructArray
writeBufferInfo(
nestedType.fields, columns: structArray!.arrowFields!,
bufferOffset: &bufferOffset, buffers: &buffers, fbb: &fbb)
}
}
}
}

private func writeRecordBatch(batch: RecordBatch) -> Result<(Data, Offset), ArrowError> {
let schema = batch.schema
var fbb = FlatBufferBuilder()

// write out field nodes
var fieldNodeOffsets = [Offset]()
fbb.startVector(
schema.fields.count, elementSize: MemoryLayout<org_apache_arrow_flatbuf_FieldNode>.size)
writeFieldNodes(schema.fields, columns: batch.columns, offsets: &fieldNodeOffsets, fbb: &fbb)
let nodeOffset = fbb.endVector(len: fieldNodeOffsets.count)

// write out buffers
var buffers = [org_apache_arrow_flatbuf_Buffer]()
var bufferOffset = Int(0)
writeBufferInfo(
schema.fields, columns: batch.columns,
bufferOffset: &bufferOffset, buffers: &buffers,
fbb: &fbb)
org_apache_arrow_flatbuf_RecordBatch.startVectorOfBuffers(batch.schema.fields.count, in: &fbb)
for buffer in buffers.reversed() {
fbb.create(struct: buffer)
Expand All @@ -210,15 +255,32 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success((fbb.data, Offset(offset: UInt32(fbb.data.count))))
}

private func writeRecordBatchData(_ writer: inout DataWriter, batch: RecordBatch) -> Result<
Bool, ArrowError
> {
for index in 0..<batch.schema.fields.count {
let column = batch.column(index)
private func writeRecordBatchData(
_ writer: inout DataWriter, fields: [ArrowField],
columns: [ArrowArrayHolder]
)
-> Result<Bool, ArrowError>
{
for index in 0..<fields.count {
let column = columns[index]
let colBufferData = column.getBufferData()
for var bufferData in colBufferData {
addPadForAlignment(&bufferData)
writer.append(bufferData)
if let nestedType = column.type as? ArrowNestedType {
guard let structArray = column.array as? StructArray else {
return .failure(.invalid("Struct type array expected for nested type"))
}

switch writeRecordBatchData(
&writer, fields: nestedType.fields, columns: structArray.arrowFields!)
{
case .success:
continue
case .failure(let error):
return .failure(error)
}
}
}
}

Expand All @@ -244,11 +306,10 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
org_apache_arrow_flatbuf_Footer.addVectorOf(recordBatches: rbBlkEnd, &fbb)
let footerOffset = org_apache_arrow_flatbuf_Footer.endFooter(&fbb, start: footerStartOffset)
fbb.finish(offset: footerOffset)
return .success(fbb.data)
case .failure(let error):
return .failure(error)
}

return .success(fbb.data)
}

private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
Expand Down Expand Up @@ -285,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(true)
}

public func writeSteaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
let writer: any DataWriter = InMemDataWriter()
switch toMessage(info.schema) {
case .success(let schemaData):
Expand Down Expand Up @@ -363,7 +424,8 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
writer.append(message.0)
addPadForAlignment(&writer)
var dataWriter: any DataWriter = InMemDataWriter()
switch writeRecordBatchData(&dataWriter, batch: batch) {
switch writeRecordBatchData(&dataWriter, fields: batch.schema.fields, columns: batch.columns)
{
case .success:
return .success([
(writer as! InMemDataWriter).data, // swiftlint:disable:this force_cast
Expand Down Expand Up @@ -397,3 +459,4 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}
}
// swiftlint:disable:this file_length
62 changes: 33 additions & 29 deletions Sources/SparkConnect/ArrowWriterHelper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,77 +25,78 @@ extension Data {
}

func toFBTypeEnum(_ arrowType: ArrowType) -> Result<org_apache_arrow_flatbuf_Type_, ArrowError> {
let infoType = arrowType.info
if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowInt16
|| infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt8
|| infoType == ArrowType.ArrowUInt16 || infoType == ArrowType.ArrowUInt32
|| infoType == ArrowType.ArrowUInt64 || infoType == ArrowType.ArrowInt32
{
let typeId = arrowType.id
switch typeId {
case .int8, .int16, .int32, .int64, .uint8, .uint16, .uint32, .uint64:
return .success(org_apache_arrow_flatbuf_Type_.int)
} else if infoType == ArrowType.ArrowFloat || infoType == ArrowType.ArrowDouble {
case .float, .double:
return .success(org_apache_arrow_flatbuf_Type_.floatingpoint)
} else if infoType == ArrowType.ArrowString {
case .string:
return .success(org_apache_arrow_flatbuf_Type_.utf8)
} else if infoType == ArrowType.ArrowBinary {
case .binary:
return .success(org_apache_arrow_flatbuf_Type_.binary)
} else if infoType == ArrowType.ArrowBool {
case .boolean:
return .success(org_apache_arrow_flatbuf_Type_.bool)
} else if infoType == ArrowType.ArrowDate32 || infoType == ArrowType.ArrowDate64 {
case .date32, .date64:
return .success(org_apache_arrow_flatbuf_Type_.date)
} else if infoType == ArrowType.ArrowTime32 || infoType == ArrowType.ArrowTime64 {
case .time32, .time64:
return .success(org_apache_arrow_flatbuf_Type_.time)
case .strct:
return .success(org_apache_arrow_flatbuf_Type_.struct_)
default:
return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(typeId)"))
}
return .failure(.unknownType("Unable to find flatbuf type for Arrow type: \(infoType)"))
}

func toFBType( // swiftlint:disable:this cyclomatic_complexity
func toFBType( // swiftlint:disable:this cyclomatic_complexity function_body_length
_ fbb: inout FlatBufferBuilder,
arrowType: ArrowType
) -> Result<Offset, ArrowError> {
let infoType = arrowType.info
if infoType == ArrowType.ArrowInt8 || infoType == ArrowType.ArrowUInt8 {
switch arrowType.id {
case .int8, .uint8:
return .success(
org_apache_arrow_flatbuf_Int.createInt(
&fbb, bitWidth: 8, isSigned: infoType == ArrowType.ArrowInt8))
} else if infoType == ArrowType.ArrowInt16 || infoType == ArrowType.ArrowUInt16 {
case .int16, .uint16:
return .success(
org_apache_arrow_flatbuf_Int.createInt(
&fbb, bitWidth: 16, isSigned: infoType == ArrowType.ArrowInt16))
} else if infoType == ArrowType.ArrowInt32 || infoType == ArrowType.ArrowUInt32 {
case .int32, .uint32:
return .success(
org_apache_arrow_flatbuf_Int.createInt(
&fbb, bitWidth: 32, isSigned: infoType == ArrowType.ArrowInt32))
} else if infoType == ArrowType.ArrowInt64 || infoType == ArrowType.ArrowUInt64 {
case .int64, .uint64:
return .success(
org_apache_arrow_flatbuf_Int.createInt(
&fbb, bitWidth: 64, isSigned: infoType == ArrowType.ArrowInt64))
} else if infoType == ArrowType.ArrowFloat {
case .float:
return .success(
org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .single))
} else if infoType == ArrowType.ArrowDouble {
case .double:
return .success(
org_apache_arrow_flatbuf_FloatingPoint.createFloatingPoint(&fbb, precision: .double))
} else if infoType == ArrowType.ArrowString {
case .string:
return .success(
org_apache_arrow_flatbuf_Utf8.endUtf8(
&fbb, start: org_apache_arrow_flatbuf_Utf8.startUtf8(&fbb)))
} else if infoType == ArrowType.ArrowBinary {
case .binary:
return .success(
org_apache_arrow_flatbuf_Binary.endBinary(
&fbb, start: org_apache_arrow_flatbuf_Binary.startBinary(&fbb)))
} else if infoType == ArrowType.ArrowBool {
case .boolean:
return .success(
org_apache_arrow_flatbuf_Bool.endBool(
&fbb, start: org_apache_arrow_flatbuf_Bool.startBool(&fbb)))
} else if infoType == ArrowType.ArrowDate32 {
case .date32:
let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
org_apache_arrow_flatbuf_Date.add(unit: .day, &fbb)
return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
} else if infoType == ArrowType.ArrowDate64 {
case .date64:
let startOffset = org_apache_arrow_flatbuf_Date.startDate(&fbb)
org_apache_arrow_flatbuf_Date.add(unit: .millisecond, &fbb)
return .success(org_apache_arrow_flatbuf_Date.endDate(&fbb, start: startOffset))
} else if infoType == ArrowType.ArrowTime32 {
case .time32:
let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
if let timeType = arrowType as? ArrowTypeTime32 {
org_apache_arrow_flatbuf_Time.add(
Expand All @@ -104,7 +105,7 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity
}

return .failure(.invalid("Unable to case to Time32"))
} else if infoType == ArrowType.ArrowTime64 {
case .time64:
let startOffset = org_apache_arrow_flatbuf_Time.startTime(&fbb)
if let timeType = arrowType as? ArrowTypeTime64 {
org_apache_arrow_flatbuf_Time.add(
Expand All @@ -113,9 +114,12 @@ func toFBType( // swiftlint:disable:this cyclomatic_complexity
}

return .failure(.invalid("Unable to case to Time64"))
case .strct:
let startOffset = org_apache_arrow_flatbuf_Struct_.startStruct_(&fbb)
return .success(org_apache_arrow_flatbuf_Struct_.endStruct_(&fbb, start: startOffset))
default:
return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
}

return .failure(.unknownType("Unable to add flatbuf type for Arrow type: \(infoType)"))
}

func addPadForAlignment(_ data: inout Data, alignment: Int = 8) {
Expand Down
10 changes: 8 additions & 2 deletions Sources/SparkConnect/ProtoUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import Foundation

func fromProto( // swiftlint:disable:this cyclomatic_complexity
func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_length
field: org_apache_arrow_flatbuf_Field
) -> ArrowField {
let type = field.typeType
Expand Down Expand Up @@ -74,7 +74,13 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity
arrowType = ArrowTypeTime64(arrowUnit)
}
case .struct_:
arrowType = ArrowType(ArrowType.ArrowStruct)
var children = [ArrowField]()
for index in 0..<field.childrenCount {
let childField = field.children(at: index)!
children.append(fromProto(field: childField))
}

arrowType = ArrowNestedType(ArrowType.ArrowStruct, fields: children)
default:
arrowType = ArrowType(ArrowType.ArrowUnknown)
}
Expand Down
Loading