
[SPARK-52274] Update ArrowReader/Writer with GH-44910 #170

Closed · wants to merge 1 commit
91 changes: 83 additions & 8 deletions Sources/SparkConnect/ArrowReader.swift
@@ -19,7 +19,7 @@ import FlatBuffers
import Foundation

let FILEMARKER = "ARROW1"
let CONTINUATIONMARKER = -1
let CONTINUATIONMARKER = UInt32(0xFFFF_FFFF)
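For context: in the Arrow IPC encapsulated message format, each message is framed as a 4-byte continuation indicator (0xFFFFFFFF), a 4-byte little-endian metadata length, the flatbuffer metadata, and then the message body; a length of zero marks end-of-stream. A minimal framing sketch under those assumptions, using the getUInt32 helper added in ArrowReaderHelper.swift below (hypothetical helper, not the reader's actual control flow):

// Sketch: peek at the length of the next IPC frame starting at `offset`.
// Assumes a little-endian host, matching the stream's little-endian prefixes.
func peekFrameLength(_ data: Data, offset: Int) -> UInt32 {
  var offset = offset
  var length = getUInt32(data, offset: offset)
  if length == CONTINUATIONMARKER {  // 0xFFFF_FFFF continuation indicator
    offset += MemoryLayout<UInt32>.size
    length = getUInt32(data, offset: offset)  // metadata length (0 = end of stream)
  }
  return length
}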

/// @nodoc
public class ArrowReader { // swiftlint:disable:this type_body_length
@@ -240,7 +240,78 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
return .success(RecordBatch(arrowSchema, columns: columns))
}

public func fromStream( // swiftlint:disable:this function_body_length
/*
This is for reading the Arrow streaming format. The Arrow streaming format
differs slightly from the Arrow file format in that it does not contain a
header or footer.
*/
public func readStreaming( // swiftlint:disable:this function_body_length
_ input: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
let result = ArrowReaderResult()
var offset: Int = 0
var length = getUInt32(input, offset: offset)
var streamData = input
var schemaMessage: org_apache_arrow_flatbuf_Schema?
while length != 0 {
if length == CONTINUATIONMARKER {
offset += Int(MemoryLayout<UInt32>.size)
length = getUInt32(input, offset: offset)
if length == 0 {
return .success(result)
}
}

offset += Int(MemoryLayout<UInt32>.size)
streamData = input[offset...]
let dataBuffer = ByteBuffer(
data: streamData,
allowReadingUnalignedBuffers: true)
let message = org_apache_arrow_flatbuf_Message.getRootAsMessage(bb: dataBuffer)
switch message.headerType {
case .recordbatch:
do {
let rbMessage = message.header(type: org_apache_arrow_flatbuf_RecordBatch.self)!
let recordBatch = try loadRecordBatch(
rbMessage,
schema: schemaMessage!,
arrowSchema: result.schema!,
data: input,
messageEndOffset: (Int64(offset) + Int64(length))
).get()
result.batches.append(recordBatch)
offset += Int(message.bodyLength + Int64(length))
length = getUInt32(input, offset: offset)
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}
case .schema:
schemaMessage = message.header(type: org_apache_arrow_flatbuf_Schema.self)!
let schemaResult = loadSchema(schemaMessage!)
switch schemaResult {
case .success(let schema):
result.schema = schema
case .failure(let error):
return .failure(error)
}
offset += Int(message.bodyLength + Int64(length))
length = getUInt32(input, offset: offset)
default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}
return .success(result)
}
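A usage sketch for the new entry point (assuming ArrowReader keeps its default initializer; streamBytes is hypothetical Data holding an Arrow stream):

let reader = ArrowReader()
switch reader.readStreaming(streamBytes) {
case .success(let result):
  // result.schema is the decoded schema, result.batches the decoded record batches.
  print("decoded \(result.batches.count) record batch(es)")
case .failure(let error):
  print("failed to read stream: \(error)")
}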

/*
This is for reading the Arrow file format. The Arrow file format supports
random access to the data. It wraps the Arrow streaming format with a
header and footer.
*/
public func readFile( // swiftlint:disable:this function_body_length
_ fileData: Data,
useUnalignedBuffers: Bool = false
) -> Result<ArrowReaderResult, ArrowError> {
@@ -266,7 +337,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
for index in 0..<footer.recordBatchesCount {
let recordBatch = footer.recordBatches(at: index)!
var messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: Int32.self)
rawBuffer.loadUnaligned(fromByteOffset: Int(recordBatch.offset), as: UInt32.self)
}

var messageOffset: Int64 = 1
@@ -275,7 +346,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageLength = fileData.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(
fromByteOffset: Int(recordBatch.offset + Int64(MemoryLayout<Int32>.size)),
as: Int32.self)
as: UInt32.self)
}
}

@@ -299,8 +370,10 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
messageEndOffset: messageEndOffset
).get()
result.batches.append(recordBatch)
} catch let error {
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}
default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
@@ -320,7 +393,7 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
let markerLength = FILEMARKER.utf8.count
let footerLengthEnd = Int(fileData.count - markerLength)
let data = fileData[..<(footerLengthEnd)]
return fromStream(data)
return readFile(data)
} catch {
return .failure(.unknownError("Error loading file: \(error)"))
}
@@ -360,13 +433,15 @@ public class ArrowReader { // swiftlint:disable:this type_body_length
).get()
result.batches.append(recordBatch)
return .success(())
} catch let error {
} catch let error as ArrowError {
return .failure(error)
} catch {
return .failure(.unknownError("Unexpected error: \(error)"))
}

default:
return .failure(.unknownError("Unhandled header type: \(message.headerType)"))
}
}

}
// swiftlint:disable:this file_length
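For the file format the flow is symmetric; a minimal sketch, assuming fileBytes is Data already prepared by the caller shown above (markers validated and the trailing footer-length bytes stripped):

let reader = ArrowReader()
switch reader.readFile(fileBytes) {
case .success(let result):
  print("decoded \(result.batches.count) record batch(es) from the file")
case .failure(let error):
  print("failed to read file: \(error)")
}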
7 changes: 7 additions & 0 deletions Sources/SparkConnect/ArrowReaderHelper.swift
@@ -312,3 +312,10 @@ func validateFileData(_ data: Data) -> Bool {
let endString = String(decoding: data[(data.count - markerLength)...], as: UTF8.self)
return startString == FILEMARKER && endString == FILEMARKER
}

func getUInt32(_ data: Data, offset: Int) -> UInt32 {
let token = data.withUnsafeBytes { rawBuffer in
rawBuffer.loadUnaligned(fromByteOffset: offset, as: UInt32.self)
}
return token
}
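A note on the helper: loadUnaligned returns the four bytes in host byte order, which matches the stream's little-endian length prefixes on the little-endian platforms Swift currently targets (Apple platforms, x86_64 and arm64 Linux). A hypothetical call reading the metadata length that follows a 4-byte continuation marker:

// Hypothetical offset: length prefix immediately after the continuation marker.
let metadataLength = getUInt32(streamBytes, offset: MemoryLayout<UInt32>.size)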
41 changes: 37 additions & 4 deletions Sources/SparkConnect/ArrowWriter.swift
@@ -132,6 +132,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
let startIndex = writer.count
switch writeRecordBatch(batch: batch) {
case .success(let rbResult):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: rbResult.1.o.littleEndian) { writer.append(Data($0)) }
writer.append(rbResult.0)
switch writeRecordBatchData(&writer, batch: batch) {
@@ -250,7 +251,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(fbb.data)
}

private func writeStream(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
private func writeFile(_ writer: inout DataWriter, info: ArrowWriter.Info) -> Result<
Bool, ArrowError
> {
var fbb: FlatBufferBuilder = FlatBufferBuilder()
@@ -284,9 +285,41 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length
return .success(true)
}

public func toStream(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
public func writeStreaming(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
let writer: any DataWriter = InMemDataWriter()
switch toMessage(info.schema) {
case .success(let schemaData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(schemaData.count).littleEndian) { writer.append(Data($0)) }
writer.append(schemaData)
case .failure(let error):
return .failure(error)
}

for batch in info.batches {
switch toMessage(batch) {
case .success(let batchData):
withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(batchData[0].count).littleEndian) { writer.append(Data($0)) }
writer.append(batchData[0])
writer.append(batchData[1])
case .failure(let error):
return .failure(error)
}
}

withUnsafeBytes(of: CONTINUATIONMARKER.littleEndian) { writer.append(Data($0)) }
withUnsafeBytes(of: UInt32(0).littleEndian) { writer.append(Data($0)) }
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
} else {
return .failure(.invalid("Unable to cast writer"))
}
}
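Taken together, writeStreaming and readStreaming allow an in-memory round trip; a minimal sketch, assuming info is an already-constructed ArrowWriter.Info with a schema and batches, and default initializers for ArrowWriter and ArrowReader:

let writer = ArrowWriter()
switch writer.writeStreaming(info) {
case .success(let streamBytes):
  // Feed the produced stream straight back into the reader added above.
  if case .success(let decoded) = ArrowReader().readStreaming(streamBytes) {
    print("round-tripped \(decoded.batches.count) record batch(es)")
  }
case .failure(let error):
  print("failed to write stream: \(error)")
}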

public func writeFile(_ info: ArrowWriter.Info) -> Result<Data, ArrowError> {
var writer: any DataWriter = InMemDataWriter()
switch writeStream(&writer, info: info) {
switch writeFile(&writer, info: info) {
case .success:
if let memWriter = writer as? InMemDataWriter {
return .success(memWriter.data)
@@ -313,7 +346,7 @@ public class ArrowWriter { // swiftlint:disable:this type_body_length

var writer: any DataWriter = FileDataWriter(fileHandle)
writer.append(FILEMARKER.data(using: .utf8)!)
switch writeStream(&writer, info: info) {
switch writeFile(&writer, info: info) {
case .success:
writer.append(FILEMARKER.data(using: .utf8)!)
case .failure(let error):