diff --git a/Package.swift b/Package.swift index de98aa6..5050e92 100644 --- a/Package.swift +++ b/Package.swift @@ -7,7 +7,7 @@ let package = Package( name: "swift-transformers", platforms: [.iOS(.v16), .macOS(.v13)], products: [ - .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]), + .library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models", "Embedding"]), .executable(name: "transformers", targets: ["TransformersCLI"]), .executable(name: "hub-cli", targets: ["HubCLI"]), ], @@ -26,11 +26,13 @@ let package = Package( .target(name: "TensorUtils"), .target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]), .target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]), + .target(name: "Embedding", dependencies: ["Hub", "Tokenizers"]), .testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]), .testTarget(name: "HubTests", dependencies: ["Hub"]), .testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]), .testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]), .testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]), - .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]) + .testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]), + .testTarget(name: "EmbeddingTests", dependencies: ["Embedding", "Tokenizers", "Hub", "TensorUtils"], resources: [.process("Resources"), .process("Vocabs")]) ] ) diff --git a/Sources/Embedding/Embedding.swift b/Sources/Embedding/Embedding.swift new file mode 100644 index 0000000..e4ef3e0 --- /dev/null +++ b/Sources/Embedding/Embedding.swift @@ -0,0 +1,154 @@ +import Hub +import Tokenizers +import CoreML +import Accelerate + + +class BERTEmbedding { + + typealias Weights = [String: MLMultiArray] + + var shape: [NSNumber] {[ + NSNumber(value: maxPositionEmbeddings), + NSNumber(value: hiddenSize), + ]} + + private let weights: Weights + + private let positionEmbeddingType: String + private let hiddenSize: Int + private let vocabSize: Int + private let maxPositionEmbeddings: Int + private let typeVocabSize: Int + private let padTokenID: Int + private let normalizationEpsilon: Float + private let dropoutRate: Float = 1e-1 + private let hiddenActivation: BNNS.ActivationFunction = .geluApproximation2(alpha: 1e-1, beta: 1e-1) + + private var allocations: [BNNSNDArrayDescriptor] = [] + + private lazy var wordEmbedding: BNNS.EmbeddingLayer = { + let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings)) + allocations.append(input) + let dictData: [Float32] = weights["bert.embeddings.word_embeddings.weight"]!.toArray() + let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, vocabSize)) + allocations.append(dict) + let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(output) + + return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: 0, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: false)! + }() + + private lazy var positionEmbedding: BNNS.EmbeddingLayer = { + let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings)) + allocations.append(input) + let dictData: [Float32] = weights["bert.embeddings.position_embeddings.weight"]!.toArray() + let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(dict) + let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(output) + + return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)! + }() + + private lazy var tokenTypeEmbedding: BNNS.EmbeddingLayer = { + let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings)) + allocations.append(input) + let dictData: [Float32] = weights["bert.embeddings.token_type_embeddings.weight"]!.toArray() + let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, typeVocabSize)) + allocations.append(dict) + let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(output) + + return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)! + }() + + private lazy var normalization: BNNS.NormalizationLayer = { + let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize)) + allocations.append(input) + let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize)) + allocations.append(output) + + let betaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.beta"] ?? weights["bert.embeddings.LayerNorm.bias"] + let beta = BNNSNDArrayDescriptor.allocate(initializingFrom: betaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(beta) + + let gammaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.gamma"] ?? weights["bert.embeddings.LayerNorm.weight"] + let gamma = BNNSNDArrayDescriptor.allocate(initializingFrom: gammaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(gamma) + + return BNNS.NormalizationLayer(type: .batch(movingMean: nil, movingVariance: nil), input: input, output: output, beta: beta, gamma: gamma, epsilon: normalizationEpsilon, activation: hiddenActivation)! + }() + + private lazy var dropout: BNNS.DropoutLayer = { + let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(input) + let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + allocations.append(output) + + return BNNS.DropoutLayer(input: input, output: output, rate: dropoutRate, seed: 0, control: 0)! + }() + + deinit { + allocations.forEach({ $0.deallocate() }) + } + + init(config: Config, weights: Weights = [:]) { + assert(config.model_type!.stringValue == "bert") + for key in [ + "bert.embeddings.word_embeddings.weight", + "bert.embeddings.position_embeddings.weight", + "bert.embeddings.token_type_embeddings.weight", + ] { assert(weights.keys.contains(where: { $0 == key })) } + assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.beta" || $0 == "bert.embeddings.LayerNorm.bias" })) + assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.gamma" || $0 == "bert.embeddings.LayerNorm.weight" })) + assert(config.hidden_act!.stringValue == "gelu") + assert("absolute" == config.position_embedding_type!.stringValue!) + self.positionEmbeddingType = config.position_embedding_type!.stringValue! + self.hiddenSize = config.hidden_size!.intValue! + self.vocabSize = config.vocab_size!.intValue! + self.maxPositionEmbeddings = config.max_position_embeddings!.intValue! + self.typeVocabSize = config.type_vocab_size!.intValue! + self.padTokenID = config.pad_token_id!.intValue! + self.normalizationEpsilon = Float(config.layer_norm_eps!.doubleValue!) + self.weights = weights + } + + public func callAsFunction(inputIDs: [Int64], + tokenTypeIDs: [Int64]? = nil, + positionIDs: [Int64]? = nil) -> MLMultiArray { + let inputLength = inputIDs.count + let inputIDs: [Int64] = inputIDs.padded(length: maxPositionEmbeddings) + let wordInput = BNNSNDArrayDescriptor.allocate(initializingFrom: inputIDs, shape: .vector(inputIDs.count)) + let wordOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, inputIDs.count)) + defer { + wordInput.deallocate() + wordOutput.deallocate() + } + try! wordEmbedding.apply(batchSize: 1, input: wordInput, output: wordOutput) + + let positionIDs = positionIDs ?? Array(stride(from: 0, through: Int64(inputLength - 1), by: 1)) + let positionInput = BNNSNDArrayDescriptor.allocate(initializingFrom: positionIDs.padded(length: maxPositionEmbeddings), shape: .vector(maxPositionEmbeddings)) + let positionOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + defer { + positionInput.deallocate() + positionOutput.deallocate() + } + try! self.positionEmbedding.apply(batchSize: 1, input: positionInput, output: positionOutput) + + let tokenTypeIDs: [Int64] = tokenTypeIDs ?? Array(repeating: 0, count: maxPositionEmbeddings) + let typeInput = BNNSNDArrayDescriptor.allocate(initializingFrom: tokenTypeIDs, shape: .vector(maxPositionEmbeddings)) + let typeOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings)) + defer { + typeInput.deallocate() + typeOutput.deallocate() + } + try! self.tokenTypeEmbedding.apply(batchSize: 1, input: typeInput, output: typeOutput) + + let multiWord = try! wordOutput.makeMultiArray(of: Float32.self, shape: shape) + let multiPosition = try! positionOutput.makeMultiArray(of: Float32.self, shape: shape) + let multiType = try! typeOutput.makeMultiArray(of: Float32.self, shape: shape) + + return multiWord + multiPosition + multiType + } +} diff --git a/Sources/Hub/Hub.swift b/Sources/Hub/Hub.swift index 8deeee8..d5aa649 100644 --- a/Sources/Hub/Hub.swift +++ b/Sources/Hub/Hub.swift @@ -90,6 +90,7 @@ public struct Config { } public var intValue: Int? { value as? Int } + public var doubleValue: Double? { value as? Double } public var boolValue: Bool? { value as? Bool } public var stringValue: String? { value as? String } diff --git a/Sources/TensorUtils/Array+Utils.swift b/Sources/TensorUtils/Array+Utils.swift new file mode 100644 index 0000000..38c32a9 --- /dev/null +++ b/Sources/TensorUtils/Array+Utils.swift @@ -0,0 +1,8 @@ +import Foundation + + +public extension Array where Element: Numeric { + func padded(length maxLength: Int) -> Array { + self + Array(repeating: 0, count: Swift.max(maxLength - count, 0)) + } +} diff --git a/Sources/TensorUtils/BNNS+Utils.swift b/Sources/TensorUtils/BNNS+Utils.swift new file mode 100644 index 0000000..f5b70ae --- /dev/null +++ b/Sources/TensorUtils/BNNS+Utils.swift @@ -0,0 +1,14 @@ +import Accelerate +import CoreML.MLMultiArray + + +public extension BNNSNDArrayDescriptor { + func makeMultiArray(of numericType: T.Type, shape: [NSNumber]) throws -> MLMultiArray { + assert(numericType == Float32.self) + let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in + acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) + } + + return try MLMultiArray(dataPointer: self.data!, shape: shape, dataType: .float32, strides: strides) + } +} diff --git a/Sources/TensorUtils/MLMultiArray+Utils.swift b/Sources/TensorUtils/MLMultiArray+Utils.swift index ddb2760..623bc87 100644 --- a/Sources/TensorUtils/MLMultiArray+Utils.swift +++ b/Sources/TensorUtils/MLMultiArray+Utils.swift @@ -8,6 +8,7 @@ import Foundation import CoreML +import Accelerate public extension MLMultiArray { /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) @@ -198,3 +199,54 @@ extension MLMultiArray { return s + "]" } } + +public extension MLMultiArray { + func toArray() -> Array { + let stride = MemoryLayout.stride + let allocated = UnsafeMutableRawBufferPointer.allocate(byteCount: self.count * stride, alignment: MemoryLayout.alignment) + return self.withUnsafeBytes { ptr in + memcpy(allocated.baseAddress!, ptr.baseAddress!, self.count * stride) + let start = allocated.bindMemory(to: T.self).baseAddress! + return Array(UnsafeBufferPointer(start: start, count: self.count)) + } + } +} + +public extension MLMultiArray { + static func +(lhs: MLMultiArray, rhs: MLMultiArray) -> MLMultiArray { + assert(lhs.dataType == rhs.dataType && lhs.dataType == .float32) + assert(lhs.shape.count == rhs.shape.count && lhs.shape[1].intValue == rhs.shape[1].intValue) + + let outShape: [NSNumber] + let outLength: Int + var ptr0: UnsafeMutablePointer + var ptr1: UnsafeMutablePointer + if lhs.shape[0].intValue >= rhs.shape[0].intValue { + assert(rhs.shape[0].intValue == 1 || lhs.shape == rhs.shape) // A[m, n], B[1, n] || B[m, n] + outShape = lhs.shape + outLength = lhs.count + ptr0 = UnsafeMutablePointer(OpaquePointer(lhs.withUnsafeMutableBytes({ ptr, _ in ptr.baseAddress! }))) + ptr1 = UnsafeMutablePointer(OpaquePointer(rhs.withUnsafeMutableBytes({ ptr, _ in ptr.baseAddress! }))) + } else { + assert(lhs.shape[0].intValue == 1) // Swap when A[1, n], B[m, n] + outShape = rhs.shape + outLength = rhs.count + ptr0 = UnsafeMutablePointer(OpaquePointer(rhs.withUnsafeMutableBytes({ ptr, _ in ptr.baseAddress! }))) + ptr1 = UnsafeMutablePointer(OpaquePointer(lhs.withUnsafeMutableBytes({ ptr, _ in ptr.baseAddress! }))) + } + + let output = try! MLMultiArray(shape: outShape, dataType: .float32) + var ptrOutput = UnsafeMutablePointer(OpaquePointer(output.withUnsafeMutableBytes({ ptr, _ in ptr.baseAddress! }))) + vDSP_vadd(ptr0, 1, ptr1, 1, ptrOutput, 1, vDSP_Length(outLength)) + + if lhs.shape[0].intValue != rhs.shape[0].intValue { + for _ in 1...stride + let allocated = UnsafeMutableRawBufferPointer.allocate(byteCount: array.count * stride, alignment: MemoryLayout.alignment) + defer { allocated.deallocate() } + _ = array.withUnsafeBufferPointer { ptr in + memcpy(allocated.baseAddress!, ptr.baseAddress!, array.count * stride) + } + let multiArray = try MLMultiArray(dataPointer: allocated.baseAddress!, shape: [4, 10], dataType: .float32, strides: [10, 1]) + let output = multiArray + multiArray + XCTAssertEqual(output.count, array.count) + XCTAssertEqual(output.count, multiArray.count) + + for index in 0...stride + let allocA = UnsafeMutableRawBufferPointer.allocate(byteCount: array.count * stride, alignment: MemoryLayout.alignment) + defer { allocA.deallocate() } + let allocB = UnsafeMutableRawBufferPointer.allocate(byteCount: 10 * stride, alignment: MemoryLayout.alignment) + defer { allocB.deallocate() } + + _ = array.withUnsafeBufferPointer { ptr in + memcpy(allocA.baseAddress!, ptr.baseAddress!, array.count * stride) + } + _ = Array(repeating: 10, count: 10).withUnsafeBufferPointer { ptr in + memcpy(allocB.baseAddress!, ptr.baseAddress!, 10 * stride) + } + + let A = try MLMultiArray(dataPointer: allocA.baseAddress!, shape: [4, 10], dataType: .float32, strides: [10, 1]) + XCTAssertEqual(A.count, 40) + let B = try MLMultiArray(dataPointer: allocB.baseAddress!, shape: [1, 10], dataType: .float32, strides: [10, 1]) + XCTAssertEqual(B.count, 10) + + _ = A + B + _ = A + B + B + _ = A + B + B + let expectedArray: [Float32] = [ + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27 ,28 ,29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + ] + let output = A + B + XCTAssertEqual(output.floats, expectedArray) + } + + func testAdditionRowReverseOrder() throws { + let array: [Float32] = [ + 01, 02, 03, 04, 05, 06, 07, 08, 09, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27 ,28 ,29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + ] + let stride = MemoryLayout.stride + let allocA = UnsafeMutableRawBufferPointer.allocate(byteCount: array.count * stride, alignment: MemoryLayout.alignment) + defer { allocA.deallocate() } + let allocB = UnsafeMutableRawBufferPointer.allocate(byteCount: 10 * stride, alignment: MemoryLayout.alignment) + defer { allocB.deallocate() } + + _ = array.withUnsafeBufferPointer { ptr in + memcpy(allocA.baseAddress!, ptr.baseAddress!, array.count * stride) + } + _ = Array(repeating: 10, count: 10).withUnsafeBufferPointer { ptr in + memcpy(allocB.baseAddress!, ptr.baseAddress!, 10 * stride) + } + + let A = try MLMultiArray(dataPointer: allocA.baseAddress!, shape: [4, 10], dataType: .float32, strides: [10, 1]) + XCTAssertEqual(A.count, 40) + let B = try MLMultiArray(dataPointer: allocB.baseAddress!, shape: [1, 10], dataType: .float32, strides: [10, 1]) + XCTAssertEqual(B.count, 10) + XCTAssertEqual(B + A, A + B) + _ = A + B + _ = A + B + B + _ = A + B + B + let expectedArray: [Float32] = [ + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27 ,28 ,29, 30, + 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, + ] + let output = B + A + XCTAssertEqual(output.floats, expectedArray) + } +}