
Commit 4c8cf07

Tokenizer fixes (#113)
* Bring over hf token envvar from preview branch
* Add tests for Gemma, including edge cases. Edge cases also added for other BPE tokenizers, but not for T5 yet.
* Sort added tokens by length (descending) to avoid early partial matches. Similar to huggingface/transformers.js@c305c38
* Store vocab as NSString to allow multiple tokens with the same Unicode canonical representation.
* Remove comments
* Go back to making vocab dictionaries private
* Use ungated copy of Gemma tokenizer
* Use NSString in UnigramTokenizer
* Switch test to microsoft tokenizer, verify in Python
1 parent e72d032 commit 4c8cf07

10 files changed, +129 -40 lines changed


Sources/Hub/Hub.swift

Lines changed: 6 additions & 6 deletions
@@ -38,9 +38,9 @@ public extension Hub {

 @dynamicMemberLookup
 public struct Config {
-    public private(set) var dictionary: [String: Any]
+    public private(set) var dictionary: [NSString: Any]

-    public init(_ dictionary: [String: Any]) {
+    public init(_ dictionary: [NSString: Any]) {
         self.dictionary = dictionary
     }

@@ -76,8 +76,8 @@ public struct Config {


     public subscript(dynamicMember member: String) -> Config? {
-        let key = dictionary[member] != nil ? member : uncamelCase(member)
-        if let value = dictionary[key] as? [String: Any] {
+        let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString
+        if let value = dictionary[key] as? [NSString: Any] {
             return Config(value)
         } else if let value = dictionary[key] {
             return Config(["value": value])

@@ -96,7 +96,7 @@ public struct Config {
     // Instead of doing this we could provide custom classes and decode to them
     public var arrayValue: [Config]? {
         guard let list = value as? [Any] else { return nil }
-        return list.map { Config($0 as! [String : Any]) }
+        return list.map { Config($0 as! [NSString : Any]) }
     }

     /// Tuple of token identifier and string value

@@ -206,7 +206,7 @@ public class LanguageModelConfigurationFromHub {
         do {
             let data = try Data(contentsOf: url)
             let parsed = try JSONSerialization.jsonObject(with: data, options: [])
-            guard let dictionary = parsed as? [String: Any] else { return nil }
+            guard let dictionary = parsed as? [NSString: Any] else { return nil }
             return Config(dictionary)
         } catch {
             return nil
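
Why NSString keys: Swift String equality uses canonical Unicode equivalence, so two vocab entries that normalize to the same character collide under String keys. A minimal standalone sketch of the difference (the values are arbitrary, not from any real vocab):

import Foundation

let decomposed = "a\u{300}"   // "à" as U+0061 U+0300
let precomposed = "\u{E0}"    // "à" as U+00E0

// String keys compare canonically, so the two forms collapse into one entry:
let stringKeyed = [decomposed: 1].merging([precomposed: 2]) { $1 }
print(stringKeyed.count)      // 1

// NSString keys compare code units, so both entries survive:
let nsKeyed: [NSString: Int] = [decomposed as NSString: 1, precomposed as NSString: 2]
print(nsKeyed.count)          // 2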

Sources/Hub/HubApi.swift

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@ public struct HubApi {
     public typealias Repo = Hub.Repo

     public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) {
-        self.hfToken = hfToken
+        self.hfToken = hfToken ?? ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"]
         if let downloadBase {
             self.downloadBase = downloadBase
         } else {

@@ -102,7 +102,7 @@ public extension HubApi {
     func configuration(fileURL: URL) throws -> Config {
         let data = try Data(contentsOf: fileURL)
         let parsed = try JSONSerialization.jsonObject(with: data, options: [])
-        guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
+        guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
         return Config(dictionary)
     }
 }

@@ -116,7 +116,7 @@ public extension HubApi {
         let (data, _) = try await httpGet(for: url)

         let parsed = try JSONSerialization.jsonObject(with: data, options: [])
-        guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
+        guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
         return Config(dictionary)
     }
 }
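
With the environment fallback in place, callers no longer need to thread a token through explicitly. A usage sketch (the token string is a placeholder, not a real credential):

import Hub

// Resolution order, per the init above: explicit argument first,
// then the HUGGING_FACE_HUB_TOKEN environment variable, else nil.
let api = HubApi()                    // picks up HUGGING_FACE_HUB_TOKEN if exported
let gated = HubApi(hfToken: "hf_xxx") // placeholder; an explicit token still takes precedence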

Sources/Tokenizers/BPETokenizer.swift

Lines changed: 13 additions & 11 deletions
@@ -33,9 +33,11 @@ struct BytePair: Hashable {

 class BPETokenizer: PreTrainedTokenizerModel {
     let bpeRanks: Dictionary<BytePair, Int>
-    private let tokensToIds: [String: Int]
-    private let idsToTokens: [Int: String]
-
+    private let tokensToIds: [NSString: Int]
+    private let idsToTokens: [Int: NSString]
+
+    var vocabCount: Int { tokensToIds.count }
+
     public let bosToken: String?
     public let bosTokenId: Int?
     public let eosToken: String?

@@ -45,7 +47,7 @@ class BPETokenizer: PreTrainedTokenizerModel {

     required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
         guard let merges = tokenizerData.model?.merges?.value as? [String] else { fatalError("BPETokenizer requires merges") }
-        guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else {
+        guard let vocab = tokenizerData.model?.vocab?.dictionary as? [NSString: Int] else {
             throw TokenizerError.missingVocab
         }
         var bpeRanks: Dictionary<BytePair, Int> = [:]

@@ -56,31 +58,31 @@
         }
         self.bpeRanks = bpeRanks

-        self.tokensToIds = vocab.merging(addedTokens) { $1 }
+        self.tokensToIds = vocab.merging(addedTokens as [NSString : Int]) { $1 }
         self.idsToTokens = Utils.invert(self.tokensToIds)

         // Populate tokens
         if let unknownToken = TokenizerModel.unknownToken(from: tokenizerConfig) {
             self.unknownToken = unknownToken
-            self.unknownTokenId = self.tokensToIds[unknownToken]
+            self.unknownTokenId = self.tokensToIds[unknownToken as NSString]
         } else {
             self.unknownToken = nil
             self.unknownTokenId = nil
         }

         eosToken = tokenizerConfig.eosToken?.stringValue
-        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
+        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]

         bosToken = tokenizerConfig.bosToken?.stringValue
-        bosTokenId = bosToken == nil ? nil : tokensToIds[bosToken!]
+        bosTokenId = bosToken == nil ? nil : tokensToIds[bosToken! as NSString]
     }

     func convertTokenToId(_ token: String) -> Int? {
-        return tokensToIds[token] ?? self.unknownTokenId
+        return tokensToIds[token as NSString] ?? self.unknownTokenId
     }

     func convertIdToToken(_ id: Int) -> String? {
-        return idsToTokens[id]
+        return idsToTokens[id] as String?
     }

     func byteEncode(text: String) -> [String] {

@@ -162,7 +164,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
         var tokens: [String] = []
         let bpeTokens = self.bpe(token: text).split(separator: " ").map { String($0) }
         for token in bpeTokens {
-            if let _ = tokensToIds[token] {
+            if convertTokenToId(token) != unknownTokenId {
                 tokens.append(token)
             } else {
                 // TODO: if config.byte_fallback is False, append the unknown token instead
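
The rewritten membership check routes through convertTokenToId, so any token that resolves to the unknown id now takes the byte-fallback branch. A self-contained sketch of the lookup semantics, with a made-up three-entry vocab (not a real tokenizer's):

import Foundation

let tokensToIds: [NSString: Int] = ["<unk>": 0, "hello": 1, "<0x0A>": 2]
let unknownTokenId: Int? = 0

func convertTokenToId(_ token: String) -> Int? {
    return tokensToIds[token as NSString] ?? unknownTokenId
}

print(convertTokenToId("hello")!)   // 1
// "\n" is not in this vocab, so it resolves to the unknown id; in the
// tokenizer that triggers the byte-fallback path ("<0x0A>") instead.
print(convertTokenToId("\n")!)      // 0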

Sources/Tokenizers/Tokenizer.swift

Lines changed: 17 additions & 6 deletions
@@ -163,12 +163,23 @@ public class PreTrainedTokenizer: Tokenizer {
             }
         }

-        let addedTokensRegexString = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
-            guard let content = addedToken.content?.stringValue else { return nil }
-            let prefix = (addedToken.lstrip?.boolValue ?? false ? #"\s*"# : "")
-            let suffix = (addedToken.rstrip?.boolValue ?? false ? #"\s*"# : "")
-            let token = NSRegularExpression.escapedPattern(for: content)
-            return "\(prefix)(\(token))\(suffix)"
+        // Convert to tuples for easier access, then sort by length (descending) to avoid early partial matches
+        // (https://github.com/xenova/transformers.js/commit/c305c3824f628f1f02806a6310bd3b18b0f7f8f5)
+        let unwrappedAddedTokens: [(content: String, prefix: Bool, suffix: Bool)] = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
+            guard let content = addedToken.content?.stringValue else { return nil }
+            let prefix = addedToken.lstrip?.boolValue ?? false
+            let suffix = addedToken.rstrip?.boolValue ?? false
+            return (content: content, prefix: prefix, suffix: suffix)
+        }.sorted {
+            $0.content.count > $1.content.count
+        }
+
+        // then concatenate into regular expression
+        let addedTokensRegexString = unwrappedAddedTokens.map {
+            let token = NSRegularExpression.escapedPattern(for: $0.content)
+            let prefix = $0.prefix ? #"\s*"# : ""
+            let suffix = $0.suffix ? #"\s*"# : ""
+            return "\(prefix)(\(token))\(suffix)"
         }.joined(separator: "|")
         addedTokensRegex = try? NSRegularExpression(pattern: addedTokensRegexString, options: [])
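
Why sort before concatenating: NSRegularExpression alternation is ordered, so a shorter token listed first can shadow a longer token that starts with the same characters. A standalone sketch of the pitfall; the token strings are illustrative, not from any real tokenizer config:

import Foundation

let addedTokens = ["<s>", "<s><s>"]

func alternation(_ tokens: [String]) -> String {
    tokens.map { NSRegularExpression.escapedPattern(for: $0) }.joined(separator: "|")
}

let text = "<s><s>"
let range = NSRange(text.startIndex..., in: text)

// Unsorted: the shorter "<s>" is tried first and wins as a partial match.
let unsorted = try! NSRegularExpression(pattern: alternation(addedTokens))
print(unsorted.firstMatch(in: text, range: range)!.range.length)  // 3

// Sorted longest-first, the full "<s><s>" matches, as in the diff above.
let sorted = try! NSRegularExpression(pattern: alternation(addedTokens.sorted { $0.count > $1.count }))
print(sorted.firstMatch(in: text, range: range)!.range.length)    // 6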

Sources/Tokenizers/UnigramTokenizer.swift

Lines changed: 9 additions & 9 deletions
@@ -23,8 +23,8 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     public var unknownToken: String? { unknownPiece.token }

     let minScore: Float
-    let tokensToIds: [String: Int]
-
+    let tokensToIds: [NSString: Int]
+
     let bosToken: String? = " "
     let bosTokenId: Int?
     let eosToken: String?

@@ -63,20 +63,20 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
         self.unknownTokenId = unknownTokenId
         self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10)

-        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { ($1, $0) })
-        bosTokenId = tokensToIds[bosToken!] // May be nil
-
+        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token as NSString }.enumerated().map { ($1, $0) })
+        bosTokenId = tokensToIds[bosToken! as NSString] // May be nil
+
         eosToken = tokenizerConfig.eosToken?.stringValue
-        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
-
+        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]
+
         trie = Trie()
         trie.append(contentsOf: vocab.map { $0.token })

         // TODO: set fuse_unk to true
     }

     func convertTokenToId(_ token: String) -> Int? {
-        return tokensToIds[token] ?? self.unknownTokenId
+        return tokensToIds[token as NSString] ?? self.unknownTokenId
     }

     func convertIdToToken(_ id: Int) -> String? {

@@ -95,7 +95,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {

         let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos)
         for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) {
-            guard let tokenId = tokensToIds[token] else { fatalError("Token not in vocab: \(token)") }
+            guard let tokenId = tokensToIds[token as NSString] else { fatalError("Token not in vocab: \(token)") }
             let tokenScore = vocab[tokenId].score
             lattice.insert(startOffset: beginPos, length: token.count, score: tokenScore, tokenId: tokenId)
             if !hasSingleNode && token.count == mblen {
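
The NSString cast in the Dictionary construction above is load-bearing: with String keys, a vocab containing both the precomposed and decomposed form of the same character would count as duplicate keys and trap at runtime. A toy sketch with a two-entry vocab:

import Foundation

let vocab = ["a\u{300}", "\u{E0}"]  // decomposed and precomposed "à"

// NSString keys are distinct code-unit sequences — the shape used in the init above:
let ids = Dictionary(uniqueKeysWithValues: vocab.map { $0 as NSString }.enumerated().map { ($1, $0) })
print(ids.count)  // 2

// The String-keyed equivalent would crash with "duplicate keys",
// because Swift considers the two strings canonically equal:
// let crash = Dictionary(uniqueKeysWithValues: vocab.enumerated().map { ($1, $0) })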

Tests/HubTests/HubTests.swift

Lines changed: 21 additions & 0 deletions
@@ -97,4 +97,25 @@ class HubTests: XCTestCase {
             XCTFail("Cannot download test configuration from the Hub: \(error)")
         }
     }
+
+    func testConfigUnicode() {
+        // These are two different characters: "a" + combining grave (U+0061 U+0300) vs precomposed "à" (U+00E0)
+        let json = "{\"vocab\": {\"a\\u0300\": 1, \"\\u00e0\": 2}}"
+        let data = json.data(using: .utf8)
+        let dict = try! JSONSerialization.jsonObject(with: data!, options: []) as! [NSString: Any]
+        let config = Config(dict)
+
+        let vocab_nsdict = config.dictionary["vocab"] as! NSDictionary
+        let vocab_nsstring = config.dictionary["vocab"] as! [NSString: Int]
+        let vocab = config.vocab!.dictionary
+
+        XCTAssertEqual(vocab_nsdict.count, 2)
+        XCTAssertEqual(vocab_nsstring.count, 2)
+        XCTAssertEqual(vocab.count, 2)
+
+        // This is expected because, unlike with NSString, String comparison uses the canonical Unicode representation
+        // https://developer.apple.com/documentation/swift/string#Modifying-and-Comparing-Strings
+        let vocab_dict = config.dictionary["vocab"] as! [String: Int]
+        XCTAssertNotEqual(vocab_dict.count, 2)
+    }
 }

Tests/TokenizersTests/AddedTokensTests.swift

Lines changed: 12 additions & 3 deletions
@@ -11,12 +11,21 @@ import Hub

 class AddedTokensTests: XCTestCase {
     func testPhiAddedTokens() async throws {
-        let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")
+        let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
         let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
-        XCTAssertEqual(inputIds, [1, 910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])
+        XCTAssertEqual(inputIds, [910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])

         let decoded = tokenizer.decode(tokens: inputIds)
-        XCTAssertEqual(decoded, "<s> This is the <|end|>. My only friend, the <|end|>")
+        XCTAssertEqual(decoded, "This is the <|end|>. My only friend, the <|end|>")
+    }
+
+    func testGemmaAddedTokens() async throws {
+        let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer")
+        let inputIds = tokenizer("This\n\nis\na\ntest.")
+        XCTAssertEqual(inputIds, [2, 1596, 109, 502, 108, 235250, 108, 2195, 235265])
+
+        let decoded = tokenizer.decode(tokens: inputIds)
+        XCTAssertEqual(decoded, "<bos>This\n\nis\na\ntest.")
     }

     func testSplitWithCaptureGroups() {
Tests/TokenizersTests/Resources/gemma_encoded.json

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
{"text": "Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer.", "bpe_tokens": ["Fat", "ou", "ville", "-", "G", "rest", "ain", "\u2581est", "\u2581une", "\u2581commune", "\u2581du", "\u2581Nord", "-", "Ouest", "\u2581du", "\u2581d\u00e9partement", "\u2581de", "\u2581l", "'", "Eure", "\u2581situ\u00e9e", "\u2581au", "\u2581", "\n", "bord", "\u2581de", "\u2581l", "'", "est", "uaire", "\u2581de", "\u2581la", "\u2581Seine", "\u2581et", "\u2581\u00e0", "\u2581proximit\u00e9", "\u2581du", "\u2581d\u00e9partement", "\u2581du", "\u2581Cal", "vados", ".", "\u2581Selon", "\u2581l", "'", "atlas", "\u2581des", "\u2581paysages", "\u2581", "\n", "de", "\u2581Haute", "-", "Norman", "die", ",", "\u2581elle", "\u2581appartient", "\u2581\u00e0", "\u2581la", "\u2581r\u00e9gion", "\u2581naturelle", "\u2581du", "\u2581Lieu", "vin", ".", "\u2581Toutefois", ",", "\u2581l", "'", "Ag", "reste", ",", "\u2581le", "\u2581service", "\u2581", "\n", "de", "\u2581la", "\u2581statistique", "\u2581et", "\u2581de", "\u2581la", "\u2581prospective", "\u2581du", "\u2581minist\u00e8re", "\u2581de", "\u2581l", "'", "Agriculture", ",", "\u2581de", "\u2581l", "'", "Agro", "alimenta", "ire", "\u2581et", "\u2581de", "\u2581la", "\u2581For", "\u00eat", ",", "\u2581", "\n", "la", "\u2581classe", "\u2581au", "\u2581sein", "\u2581du", "\u2581pays", "\u2581d", "'", "Au", "ge", "\u2581(", "en", "\u2581tant", "\u2581que", "\u2581r\u00e9gion", "\u2581agricole", ").", "La", "\u2581commune", "\u2581est", "\u2581\u00e0", "\u2581moins", "\u2581de", "\u2581dix", "\u2581kilom\u00e8tres", "\u2581\u00e0", "\u2581", "\n", "l", "'", "est", "\u2581de", "\u2581Hon", "fleur", ",", "\u2581\u00e0", "\u2581autant", "\u2581de", "\u2581Be", "uze", "ville", "\u2581et", "\u2581\u00e0", "\u2581environ", "\u2581dix", "-", "sept", "\u2581kilom\u00e8tres", "\u2581de", "\u2581Pont", "-", "Au", "de", "mer", "."], "token_ids": [2, 33690, 507, 5259, 235290, 235319, 4803, 985, 1455, 2360, 34960, 1344, 14852, 235290, 101323, 1344, 57781, 581, 533, 235303, 128985, 80493, 992, 235248, 108, 51123, 581, 533, 235303, 644, 106910, 581, 683, 53876, 1008, 1305, 72883, 1344, 57781, 1344, 2659, 119613, 235265, 86721, 533, 235303, 64117, 848, 141362, 235248, 108, 495, 70628, 235290, 74906, 3917, 235269, 11340, 133635, 1305, 683, 33927, 72277, 1344, 174959, 2964, 235265, 145673, 235269, 533, 235303, 6665, 62423, 235269, 709, 2566, 235248, 108, 495, 683, 160719, 1008, 581, 683, 40675, 1344, 85986, 581, 533, 235303, 79742, 235269, 581, 533, 235303, 166317, 104544, 844, 1008, 581, 683, 1699, 19941, 235269, 235248, 108, 522, 30739, 992, 8399, 1344, 11928, 499, 235303, 2159, 541, 591, 479, 21482, 907, 33927, 113917, 846, 2841, 34960, 1455, 1305, 15006, 581, 51102, 118516, 1305, 235248, 108, 235257, 235303, 644, 581, 9073, 129564, 235269, 1305, 54409, 581, 2065, 52172, 5259, 1008, 1305, 15265, 51102, 235290, 91012, 118516, 581, 52291, 235290, 2159, 495, 977, 235265], 
"decoded_text": "<bos>Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer."}

Tests/TokenizersTests/Resources/tokenizer_tests.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 46 additions & 1 deletion
@@ -28,6 +28,13 @@ class LlamaTokenizerTests: TokenizerTests {
     override class var hubModelName: String? { "coreml-projects/Llama-2-7b-chat-coreml" }
     override class var encodedSamplesFilename: String? { "llama_encoded" }
     override class var unknownTokenId: Int? { 0 }
+
+    func testHexaEncode() async {
+        if let tester = Self._tester {
+            let tokenized = await tester.tokenizer?.tokenize(text: "\n")
+            XCTAssertEqual(tokenized, ["\u{2581}", "<0x0A>"])
+        }
+    }
 }

 class WhisperLargeTokenizerTests: TokenizerTests {

@@ -48,6 +55,41 @@ class T5TokenizerTests: TokenizerTests {
     override class var unknownTokenId: Int? { 2 }
 }

+class GemmaTokenizerTests: TokenizerTests {
+    override class var hubModelName: String? { "pcuenq/gemma-tokenizer" }
+    override class var encodedSamplesFilename: String? { "gemma_encoded" }
+    override class var unknownTokenId: Int? { 3 }
+
+    func testUnicodeEdgeCase() async {
+        guard let tester = Self._tester else {
+            XCTFail()
+            return
+        }
+
+        // These are two different characters
+        let cases = ["a\u{300}" /* 0x61 0x300 */, "\u{E0}" /* 0xe0 */]
+        let expected = [217138, 1305]
+
+        for (s, expected) in zip(cases, expected) {
+            let encoded = await tester.tokenizer?.encode(text: " " + s)
+            XCTAssertEqual(encoded, [2, expected])
+        }
+    }
+}
+
+class GemmaUnicodeTests: XCTestCase {
+    func testGemmaVocab() async throws {
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer") as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+
+        // FIXME: This should be 256_000, I believe
+        XCTAssertEqual((tokenizer.model as? BPETokenizer)?.vocabCount, 255994)
+    }
+}

 struct EncodedTokenizerSamplesDataset: Decodable {
     let text: String

@@ -156,7 +198,10 @@ class TokenizerTester {

     /// Test encode and decode for a few edge cases
     func testEdgeCases() async {
-        guard let edgeCases = edgeCases else { return }
+        guard let edgeCases = edgeCases else {
+            print("Edge cases test ignored")
+            return
+        }
         guard let tokenizer = await tokenizer else { return }
         for edgeCase in edgeCases {
             print("Testing \(edgeCase.input)")
