Skip to content

Commit 44e2c04

Browse files
authored
Add skipSpecialTokens option to Tokenizer.decode (#148)
1 parent fc95ce1 commit 44e2c04

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ public protocol Tokenizer {
112112

113113
/// Decode
114114
func decode(tokens: [Int]) -> String
115+
func decode(tokens: [Int], skipSpecialTokens: Bool) -> String
115116

116117
func convertTokenToId(_ token: String) -> Int?
117118
func convertTokensToIds(_ tokens: [String]) -> [Int?]
@@ -150,6 +151,10 @@ public extension Tokenizer {
150151
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
151152
encode(text: text, addSpecialTokens: addSpecialTokens)
152153
}
154+
155+
func decode(tokens: [Int]) -> String {
156+
decode(tokens: tokens, skipSpecialTokens: false)
157+
}
153158

154159
func convertTokensToIds(_ tokens: [String]) -> [Int?] {
155160
return tokens.map { convertTokenToId($0) }
@@ -315,10 +320,17 @@ public class PreTrainedTokenizer: Tokenizer {
315320
return encode(text: text, addSpecialTokens: true)
316321
}
317322

318-
/// Decode
319-
public func decode(tokens: [Int]) -> String {
323+
public func decode(tokens: [Int], skipSpecialTokens: Bool = false) -> String {
320324
// IDs to tokens
321-
let tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
325+
let tokenStrings: [String]
326+
if skipSpecialTokens {
327+
let specialTokenIDs = Set(specialTokens.values)
328+
tokenStrings = tokens
329+
.filter { !specialTokenIDs.contains($0) }
330+
.compactMap { model.convertIdToToken($0) }
331+
} else {
332+
tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
333+
}
322334
let decoded = decodeTokens(tokenStrings)
323335
// At this point we should have a single String
324336
return cleanUp(text: decoded.joined(separator: ""))

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ class TokenizerTester {
218218
tokenizer.decode(tokens: edgeCase.encoded.input_ids),
219219
edgeCase.decoded_with_special
220220
)
221+
XCTAssertEqual(
222+
tokenizer.decode(tokens: edgeCase.encoded.input_ids, skipSpecialTokens: true),
223+
edgeCase.decoded_without_special
224+
)
221225
}
222226
}
223227

0 commit comments

Comments
 (0)