|
| 1 | +import Hub |
| 2 | +import Tokenizers |
| 3 | +import CoreML |
| 4 | +import Accelerate |
| 5 | + |
| 6 | + |
/// Marker protocol adopted by the embedding model types in this file.
///
/// It declares no requirements; conforming types only gain a common existential type.
public protocol Embedding {}
| 8 | + |
/// Namespace type that hosts the `from(pretrained:hubApi:)` factory.
public struct AutoEmbedding {} // Otherwise AutoModel

extension AutoEmbedding {
    /// Creates an embedding model for a Hub repository.
    ///
    /// The current implementation always builds a `BGEM3Model`, whatever `model` names.
    ///
    /// - Parameters:
    ///   - model: Repository name forwarded to `BGEM3Model.init(repoName:hubApi:)`.
    ///   - hubApi: Hub client forwarded to the model initializer. Defaults to `.shared`.
    /// - Returns: The newly created model, typed as `Embedding`.
    /// - Throws: Whatever `BGEM3Model.init(repoName:hubApi:)` throws.
    public static func from(pretrained model: String, hubApi: HubApi = .shared) async throws -> Embedding {
        let embedding: Embedding = try await BGEM3Model(repoName: model, hubApi: hubApi)
        return embedding
    }
}
| 16 | + |
/// Placeholder for a BERT-style embedding module.
///
/// Nothing is implemented yet: the initializer and `callAsFunction` both trap,
/// so the stored layers below are never populated.
class BERTEmbedding: Embedding { // Otherwise BERTModel
    private let wordEmbedding: BNNS.EmbeddingLayer
    private let positionEmbedding: BNNS.EmbeddingLayer
    private let tokenTypeEmbedding: BNNS.EmbeddingLayer
    private let normalization: BNNS.NormalizationLayer
    private let dropout: BNNS.DropoutLayer

    private let positionEmbeddingType = "absolute"

    /// Not implemented; traps unconditionally.
    ///
    /// - Parameter repoName: Currently unused.
    init(repoName: String) {
        // A message makes the crash log identify which stub was reached.
        fatalError("BERTEmbedding.init(repoName:) is not implemented")
    }

    /// Not implemented; traps unconditionally. All parameters are currently unused.
    ///
    /// - Parameters:
    ///   - inputIds: Optional, defaults to `nil`.
    ///   - tokenTypeIDs: Optional, defaults to `nil`.
    ///   - positionIDs: Optional, defaults to `nil`.
    ///   - inputEmbeds: Optional, defaults to `nil`.
    ///   - pastKeyValuesLength: Defaults to `0`.
    /// - Returns: Never returns in the current implementation.
    public func callAsFunction(
        inputIds: MLMultiArray? = nil,
        tokenTypeIDs: MLMultiArray? = nil,
        positionIDs: MLMultiArray? = nil,
        inputEmbeds: MLMultiArray? = nil,
        pastKeyValuesLength: Int = 0
    ) -> MLMultiArray {
        fatalError("BERTEmbedding.callAsFunction(inputIds:tokenTypeIDs:positionIDs:inputEmbeds:pastKeyValuesLength:) is not implemented")
    }
}
| 36 | + |
/// Embedding model with dense, sparse and ColBERT output heads.
///
/// Work in progress: the encoder call (`callAsFunction`), the sparse and ColBERT
/// heads, and output normalization all trap when reached.
class BGEM3Model: Embedding {

    /// Failures raised while constructing the model.
    enum InitializationError: Error {
        /// BNNS returned `nil` when asked to create the named fully connected layer.
        case layerCreationFailed(name: String)
    }

    struct Output {
        // NOTE(review): "Hiddden" is misspelled; the name is kept so existing references keep compiling.
        let lastHidddenState: MLMultiArray // batchSize, sequenceLength, hiddenSize
        let hiddenStates: MLMultiArray?
        let attentions: MLMultiArray?

        let loss: MLMultiArray?
        let scores: MLMultiArray?
        let pReps: MLMultiArray?
        let qReps: MLMultiArray?
    }

    // Which heads `forward(textInput:)` evaluates.
    let withSparse = false
    let withDense = true
    let withColbert = false

    let shouldNormalize = false
//    let poolingMethod = "cls"
//    let negativesCrossDevice = false
//    let temperature = 1.0
//    let enableSubBatch = true
//    let unifiedFinetuning = true
//    let useSelfDistill = false
//    let colbertDim: Int? = nil
//    let selfDistillStartStep: Int? = nil

    private let tokenizer: Tokenizer
    private let denseLayer: BNNS.FullyConnectedLayer
    private let sparseLayer: BNNS.FullyConnectedLayer
    private let colbertLayer: BNNS.FullyConnectedLayer

    /// Loads the tokenizer and configuration for `repoName` and creates the three head layers.
    ///
    /// - Parameters:
    ///   - repoName: Hub repository name used for both the configuration and the tokenizer.
    ///   - hubApi: Hub client used for both lookups.
    /// - Throws: Errors from loading the tokenizer or the model configuration, or
    ///   `InitializationError.layerCreationFailed` when BNNS cannot create a layer.
    init(repoName: String, hubApi: HubApi) async throws {
        // Use the caller-supplied Hub client for the configuration too, not only for the tokenizer.
        let config = LanguageModelConfigurationFromHub(modelName: repoName, hubApi: hubApi)
        self.tokenizer = try await AutoTokenizer.from(pretrained: repoName, hubApi: hubApi)

        // Falls back to 384 when the configuration has no hidden size entry.
        let hiddenSize = try await config.modelConfig.hiddenSize?.intValue ?? 384
        let colbertDim: Int? = nil

        self.denseLayer = try BGEM3Model.makeLinearLayer(
            named: "dense", inputSize: hiddenSize, outputSize: colbertDim ?? hiddenSize)
        self.sparseLayer = try BGEM3Model.makeLinearLayer(
            named: "sparse", inputSize: hiddenSize, outputSize: 1)
        // NOTE(review): `colbertDim` feeds the dense head while the ColBERT head outputs 1 value —
        // confirm the two output sizes are not swapped.
        self.colbertLayer = try BGEM3Model.makeLinearLayer(
            named: "colbert", inputSize: hiddenSize, outputSize: 1)
    }

    /// Builds one bias-free, identity-activation fully connected layer over float16 descriptors.
    ///
    /// - Parameters:
    ///   - name: Label reported in the thrown error.
    ///   - inputSize: Element count of the input and weights descriptors.
    ///   - outputSize: Element count of the output descriptor.
    /// - Throws: `InitializationError.layerCreationFailed` when BNNS returns `nil`.
    private static func makeLinearLayer(
        named name: String,
        inputSize: Int,
        outputSize: Int
    ) throws -> BNNS.FullyConnectedLayer {
        // NOTE(review): `stride: 2` is kept from the original — confirm whether the stride is
        // counted in elements or in bytes.
        let input = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(inputSize, stride: 2))
        let output = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(outputSize, stride: 2))
        // NOTE(review): the weights are described as a vector of `inputSize` values, independent of
        // `outputSize` — confirm this is the shape BNNS expects for a fully connected layer.
        let weights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(inputSize, stride: 2))

        guard let layer = BNNS.FullyConnectedLayer(
            input: input,
            output: output,
            weights: weights,
            bias: nil,
            activation: .identity
        ) else {
            throw InitializationError.layerCreationFailed(name: name)
        }
        return layer
    }

    /// Not implemented; traps unconditionally.
    ///
    /// - Parameter textInput: Token indices and attention mask. Currently unused.
    /// - Returns: Never returns in the current implementation.
    public func callAsFunction(_ textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> Output {
        fatalError("BGEM3Model.callAsFunction(_:) is not implemented")
    }

    /// Runs the model on `textInput` and evaluates every enabled head.
    ///
    /// - Parameter textInput: Token indices and attention mask.
    /// - Returns: A dictionary with a "dense", "sparse" and/or "colbert" entry, one per enabled head.
    private func forward(textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> [String: MLMultiArray] {
        let lastHiddenState = self(textInput).lastHidddenState
        let mask = textInput.attentionMask

        var output = [String: MLMultiArray]()
        if withDense {
            output["dense"] = dense(hiddenState: lastHiddenState, mask: mask)
        }
        if withSparse {
            output["sparse"] = sparse(hiddenState: lastHiddenState, mask: mask)
        }
        if withColbert {
            output["colbert"] = colbert(hiddenState: lastHiddenState, mask: mask)
        }

        if shouldNormalize {
            if withDense {
                // TODO: Normalize output["dense"] =
                fatalError("BGEM3Model: normalization of the dense output is not implemented")
            }
            if withColbert {
                // TODO: Normalize output["colbert"] =
                fatalError("BGEM3Model: normalization of the colbert output is not implemented")
            }
        }

        return output
    }

    /// Copies every element of `hiddenState`, in linear index order, into a new array of `Float` values.
    ///
    /// - Parameters:
    ///   - hiddenState: Source array; a debug assertion requires exactly two dimensions.
    ///   - mask: Currently unused.
    /// - Returns: A new `MLMultiArray` created from the copied values.
    private func dense(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
        // NOTE(review): `Output.lastHidddenState` is commented as (batchSize, sequenceLength, hiddenSize),
        // which has three dimensions — confirm which rank is intended here.
        assert(hiddenState.shape.count == 2)

        let values = (0..<hiddenState.count).map { hiddenState[$0].floatValue }

        do {
            return try MLMultiArray(values)
        } catch {
            // Still a trap, as before, but the crash now reports the underlying error.
            fatalError("BGEM3Model.dense: could not create MLMultiArray: \(error)")
        }
    }

    /// Not implemented; traps unconditionally. Both parameters are currently unused.
    private func sparse(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
        fatalError("BGEM3Model.sparse(hiddenState:mask:) is not implemented")
    }

    /// Not implemented; traps unconditionally. Both parameters are currently unused.
    private func colbert(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
        fatalError("BGEM3Model.colbert(hiddenState:mask:) is not implemented")
    }
}
0 commit comments