Skip to content

Commit ac4d917

Browse files
committed
Create embedding module
1 parent 74b9421 commit ac4d917

File tree

2 files changed

+146
-2
lines changed

2 files changed

+146
-2
lines changed

Package.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ let package = Package(
77
name: "swift-transformers",
88
platforms: [.iOS(.v16), .macOS(.v13)],
99
products: [
10-
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models"]),
10+
.library(name: "Transformers", targets: ["Tokenizers", "Generation", "Models", "Embedding"]),
1111
.executable(name: "transformers", targets: ["TransformersCLI"]),
1212
.executable(name: "hub-cli", targets: ["HubCLI"]),
1313
],
@@ -26,11 +26,13 @@ let package = Package(
2626
.target(name: "TensorUtils"),
2727
.target(name: "Generation", dependencies: ["Tokenizers", "TensorUtils"]),
2828
.target(name: "Models", dependencies: ["Tokenizers", "Generation", "TensorUtils"]),
29+
.target(name: "Embedding", dependencies: ["Hub", "Tokenizers"]),
2930
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
3031
.testTarget(name: "HubTests", dependencies: ["Hub"]),
3132
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
3233
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
3334
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
34-
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
35+
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
36+
.testTarget(name: "EmbeddingTests", dependencies: ["Embedding", "Tokenizers", "Hub"])
3537
]
3638
)

Sources/Embedding/Embedding.swift

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import Hub
2+
import Tokenizers
3+
import CoreML
4+
import Accelerate
5+
6+
7+
public protocol Embedding {}
8+
9+
public struct AutoEmbedding {} // Otherwise AutoModel
10+
11+
extension AutoEmbedding {
12+
public static func from(pretrained model: String, hubApi: HubApi = .shared) async throws -> Embedding {
13+
return try await BGEM3Model(repoName: model, hubApi: hubApi)
14+
}
15+
}
16+
17+
class BERTEmbedding: Embedding { // Otherwise BERTModel
18+
private let wordEmbedding: BNNS.EmbeddingLayer
19+
private let positionEmbedding: BNNS.EmbeddingLayer
20+
private let tokenTypeEmbedding: BNNS.EmbeddingLayer
21+
private let normalization: BNNS.NormalizationLayer
22+
private let dropout: BNNS.DropoutLayer
23+
24+
private let positionEmbeddingType = "absolute"
25+
26+
init(repoName: String) { fatalError() }
27+
28+
public func callAsFunction(inputIds: MLMultiArray? = nil,
29+
tokenTypeIDs: MLMultiArray? = nil,
30+
positionIDs: MLMultiArray? = nil,
31+
inputEmbeds: MLMultiArray? = nil,
32+
pastKeyValuesLength: Int = 0) -> MLMultiArray {
33+
fatalError()
34+
}
35+
}
36+
37+
class BGEM3Model: Embedding {
38+
39+
struct Output {
40+
let lastHidddenState: MLMultiArray // batchSize, sequenceLength, hiddenSize
41+
let hiddenStates: MLMultiArray?
42+
let attentions: MLMultiArray?
43+
44+
let loss: MLMultiArray?
45+
let scores: MLMultiArray?
46+
let pReps: MLMultiArray?
47+
let qReps: MLMultiArray?
48+
}
49+
50+
let withSparse = false
51+
let withDense = true
52+
let withColbert = false
53+
54+
let shouldNormalize = false
55+
// let poolingMethod = "cls"
56+
// let negativesCrossDevice = false
57+
// let temperature = 1.0
58+
// let enableSubBatch = true
59+
// let unifiedFinetuning = true
60+
// let useSelfDistill = false
61+
// let colbertDim: Int? = nil
62+
// let selfDistillStartStep: Int? = nil
63+
64+
private let tokenizer: Tokenizer
65+
private let denseLayer: BNNS.FullyConnectedLayer
66+
private let sparseLayer: BNNS.FullyConnectedLayer
67+
private let colbertLayer: BNNS.FullyConnectedLayer
68+
69+
init(repoName: String, hubApi: HubApi) async throws {
70+
let config = LanguageModelConfigurationFromHub(modelName: repoName)
71+
self.tokenizer = try await AutoTokenizer.from(pretrained: repoName, hubApi: hubApi)
72+
73+
let hiddenSize = try await config.modelConfig.hiddenSize?.intValue ?? 384
74+
let colbertDim: Int? = nil
75+
let denseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
76+
let denseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(colbertDim ?? hiddenSize, stride: 2))
77+
let denseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
78+
self.denseLayer = BNNS.FullyConnectedLayer(input: denseInput, output: denseOutput, weights: denseWeights, bias: nil, activation: .identity)!
79+
80+
let sparseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
81+
let sparseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
82+
let sparseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
83+
self.sparseLayer = BNNS.FullyConnectedLayer(input: sparseInput, output: sparseOutput, weights: sparseWeights, bias: nil, activation: .identity)!
84+
85+
let colbertInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
86+
let colbertOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
87+
let colbertWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
88+
self.colbertLayer = BNNS.FullyConnectedLayer(input: colbertInput, output: colbertOutput, weights: colbertWeights, bias: nil, activation: .identity)!
89+
}
90+
91+
public func callAsFunction(_ textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> Output {
92+
fatalError()
93+
}
94+
95+
private func forward(textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> [String: MLMultiArray] {
96+
let lastHiddenState = self(textInput).lastHidddenState
97+
98+
var output = [String: MLMultiArray]()
99+
if withDense {
100+
output["dense"] = self.dense(hiddenState: lastHiddenState, mask: textInput.attentionMask)
101+
}
102+
if withSparse {
103+
output["sparse"] = self.sparse(hiddenState: lastHiddenState, mask: textInput.attentionMask)
104+
}
105+
if withColbert {
106+
output["colbert"] = self.colbert(hiddenState: lastHiddenState, mask: textInput.attentionMask)
107+
}
108+
109+
if shouldNormalize {
110+
if withDense {
111+
// TODO: Normalize output["dense"] =
112+
fatalError()
113+
}
114+
if withColbert {
115+
// TODO: Normalize output["colbert"] =
116+
fatalError()
117+
}
118+
}
119+
120+
return output
121+
}
122+
123+
private func dense(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
124+
assert(hiddenState.shape.count == 2)
125+
var data = [Float]()
126+
data.reserveCapacity(hiddenState.count)
127+
128+
for index in 0..<hiddenState.count {
129+
data.append(hiddenState[index].floatValue)
130+
}
131+
132+
return try! MLMultiArray(data)
133+
}
134+
135+
private func sparse(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
136+
fatalError()
137+
}
138+
139+
private func colbert(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
140+
fatalError()
141+
}
142+
}

0 commit comments

Comments
 (0)