Skip to content

Commit 2f611bf

Browse files
authored
Add WordPieceDecoder tests that match Tokenizers (#139)
1 parent 5751308 commit 2f611bf

File tree

2 files changed

+30
-17
lines changed

2 files changed

+30
-17
lines changed

Sources/Tokenizers/Decoder.swift

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,34 +57,28 @@ class WordPieceDecoder: Decoder {
5757
let prefix: String
5858
let cleanup: Bool
5959

60+
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31
61+
private let re = try! NSRegularExpression(pattern: "\\s(\\.|\\?|\\!|\\,|'|n't|'m|'s|'ve|'re)", options: [])
62+
6063
required public init(config: Config) {
    // `prefix` is mandatory for WordPiece decoding; a missing value is a
    // configuration error we cannot recover from.
    if let configuredPrefix = config.prefix?.stringValue {
        self.prefix = configuredPrefix
    } else {
        fatalError("Missing `prefix` configuration for WordPieceDecoder.")
    }
    // Cleanup is opt-in and defaults to off when the config omits it.
    self.cleanup = config.cleanup?.boolValue ?? false
}
6568

6669
func decode(tokens: [String]) -> [String] {
    // Decodes WordPiece tokens: a token carrying the continuation prefix
    // (e.g. "##") is glued onto its predecessor; any other non-initial token
    // is preceded by a space. Mirrors the reference implementation:
    // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L31
    //
    // - Parameter tokens: the WordPiece tokens to merge back into text pieces.
    // - Returns: one decoded piece per input token; empty input yields [].
    guard let first = tokens.first else {
        // Nothing to decode. The previous implementation force-unwrapped
        // `tokens.first!` and crashed on an empty token list.
        return []
    }
    let firstToken = cleanup ? cleanUpTokenization(first) : first
    return [firstToken] + tokens.dropFirst().map { token in
        // A prefixed token continues the previous word; otherwise a new word
        // starts, separated by a space. `hasPrefix` guarantees the prefix is
        // present, so dropping `prefix.count` characters is safe (no force-
        // unwrapped `range(of:)` needed).
        let decoded = token.hasPrefix(prefix)
            ? String(token.dropFirst(prefix.count))
            : " \(token)"
        return cleanup ? cleanUpTokenization(decoded) : decoded
    }
}
8576

77+
// Port of the cleanup step in Hugging Face Tokenizers:
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40
// Removes the space before trailing punctuation/contractions matched by `re`,
// then restores the " don't" contraction.
private func cleanUpTokenization(_ token: String) -> String {
    let fullRange = NSRange(location: 0, length: token.utf16.count)
    let withoutSpaces = re.stringByReplacingMatches(
        in: token,
        options: [],
        range: fullRange,
        withTemplate: "$1"
    )
    return withoutSpaces.replacingOccurrences(of: " do not", with: " don't")
}
8983
}
9084

Tests/TokenizersTests/DecoderTests.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,23 @@ class DecoderTests: XCTestCase {
2424
["Hey", " my", " friend", " ", " <s>", " how", " are", " you"]
2525
)
2626
}
27+
28+
func testWordPieceDecoder() {
    let config = Config(["prefix": "##", "cleanup": true])
    let decoder = WordPieceDecoder(config: config)

    // Each case pairs input tokens with the decoded string expected after
    // joining, matching the behavior of the Hugging Face Tokenizers library.
    let cases: [(tokens: [String], expected: String)] = [
        (["##inter", "##national", "##ization"], "##internationalization"),
        (["##auto", "##mat", "##ic", "transmission"], "##automatic transmission"),
        (["who", "do", "##n't", "does", "n't", "can't"], "who don't doesn't can't"),
        (["##un", "##believ", "##able", "##fa", "##ntastic"], "##unbelievablefantastic"),
        (["this", "is", "un", "##believ", "##able", "fa", "##ntastic"], "this is unbelievable fantastic"),
        (["The", "##quick", "##brown", "fox"], "Thequickbrown fox"),
    ]

    for (tokens, expected) in cases {
        XCTAssertEqual(decoder.decode(tokens: tokens).joined(), expected)
    }
}
2746
}

0 commit comments

Comments
 (0)