
Commit 0bb1a9d
Merge branch 'main' into feature/swift-format
2 parents 1df5c74 + 0b07561

File tree

2 files changed: +39 -4 lines changed

Sources/Tokenizers/BertTokenizer.swift

Lines changed: 6 additions & 3 deletions
@@ -45,10 +45,13 @@ public class BertTokenizer {
         self.fuseUnknownTokens = fuseUnknownTokens
     }
 
-    public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws {
-        guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else {
-            throw TokenizerError.missingVocab
+    public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
+        guard var vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else { throw TokenizerError.missingVocab }
+        if let addedTokens = tokenizerData.added_tokens?.dictionary["value"] as? [[String: Any]],
+           let pairs = addedTokens.compactMap({ ($0["content"] as? String, $0["id"] as? Int) }) as? [(String, Int)] {
+            vocab.merge(pairs, uniquingKeysWith: {$1})
         }
+        vocab.merge(addedTokens, uniquingKeysWith: {$1})
         let merges = tokenizerData.model?.merges?.value as? [String]
         let tokenizeChineseChars = tokenizerConfig.handleChineseChars?.boolValue ?? true
         let eosToken = tokenizerConfig.eosToken?.stringValue
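
Note: the new initializer builds its vocabulary in two passes of Dictionary.merge(_:uniquingKeysWith:): first the added_tokens entries found in tokenizer.json, then the caller-supplied addedTokens map. Below is a minimal standalone sketch of that merging step, using a hypothetical in-memory entry list in place of the Config values the initializer actually reads.

// Base vocabulary, standing in for tokenizerData.model.vocab (token -> id).
var vocab: [String: Int] = ["[PAD]": 0, "[UNK]": 100, "hello": 7592]

// Entries shaped like the added_tokens array of a Hugging Face tokenizer.json;
// each entry carries at least a "content" string and an "id".
let addedTokenEntries: [[String: Any]] = [
    ["content": "[CLS]", "id": 101, "special": true],
    ["content": "[ROAD]", "id": 60_001, "special": false],
]

// First pass: extract (content, id) pairs and merge them into the vocab.
// { $1 } keeps the incoming id whenever the token already exists in the base vocab.
let pairs: [(String, Int)] = addedTokenEntries.compactMap { entry -> (String, Int)? in
    guard let content = entry["content"] as? String, let id = entry["id"] as? Int else { return nil }
    return (content, id)
}
vocab.merge(pairs, uniquingKeysWith: { $1 })

// Second pass: caller-supplied added tokens are merged last, mirroring the
// trailing vocab.merge(addedTokens, uniquingKeysWith: {$1}) in the initializer.
vocab.merge(["[RIVER]": 60_002], uniquingKeysWith: { $1 })

print(vocab["[ROAD]"]!)   // 60001
print(vocab["[RIVER]"]!)  // 60002

Unlike the `as? [(String, Int)]` cast in the diff, the sketch drops malformed entries with a guard; the merge semantics are the same.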

Tests/TokenizersTests/BertTokenizerTests.swift

Lines changed: 33 additions & 1 deletion
@@ -7,7 +7,8 @@
 //
 
 @testable import Tokenizers
-import XCTest
+@testable import Hub
+import XCTest
 
 class BertTokenizerTests: XCTestCase {
     override func setUp() {
@@ -175,4 +176,35 @@ class BertTokenizerTests: XCTestCase {
             XCTAssertEqual(decoded, String(expected))
         }
     }
+
+    func testBertTokenizerAddedTokensRecognized() async throws {
+        let base: URL = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests")
+        let hubApi = HubApi(downloadBase: base)
+        let configuration = LanguageModelConfigurationFromHub(modelName: "google-bert/bert-base-uncased", hubApi: hubApi)
+        guard let tokenizerConfig = try await configuration.tokenizerConfig else { fatalError("missing tokenizer config") }
+        let tokenizerData = try await configuration.tokenizerData
+        let addedTokens = [
+            "[ROAD]": 60_001,
+            "[RIVER]": 60_002,
+            "[BUILDING]": 60_003,
+            "[PARK]": 60_004,
+            "[BUFFER]": 60_005,
+            "[INTERSECT]": 60_006,
+            "[UNION]": 60_007,
+        ]
+        let tokenizer = try BertTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
+        for (token, idx) in addedTokens {
+            XCTAssertEqual(tokenizer.convertTokenToId(token), idx)
+        }
+        for (token, idx) in addedTokens {
+            XCTAssertEqual(tokenizer.convertIdToToken(idx), token)
+        }
+
+        // Reading added_tokens from tokenizer.json
+        XCTAssertEqual(tokenizer.convertTokenToId("[PAD]"), 0)
+        XCTAssertEqual(tokenizer.convertTokenToId("[UNK]"), 100)
+        XCTAssertEqual(tokenizer.convertTokenToId("[CLS]"), 101)
+        XCTAssertEqual(tokenizer.convertTokenToId("[SEP]"), 102)
+        XCTAssertEqual(tokenizer.convertTokenToId("[MASK]"), 103)
+    }
 }
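
Note: the two loops in the new test check the mapping in both directions, so the id-to-token lookup has to cover the merged added tokens as well as the base vocabulary. A minimal sketch of such an inverse lookup (not the library's implementation), assuming a vocab that already contains the bert-base-uncased special tokens and one custom added token:

// Vocabulary after merging: standard bert-base-uncased special-token ids plus one custom token.
let vocab: [String: Int] = ["[PAD]": 0, "[UNK]": 100, "[CLS]": 101, "[SEP]": 102, "[MASK]": 103, "[ROAD]": 60_001]

// Invert token -> id into id -> token; ids are unique here, so uniqueKeysWithValues is safe.
let idsToTokens = Dictionary(uniqueKeysWithValues: vocab.map { ($0.value, $0.key) })

assert(idsToTokens[60_001] == "[ROAD]")
assert(idsToTokens[103] == "[MASK]")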
