@@ -112,6 +112,7 @@ public protocol Tokenizer {
 
     /// Decode
     func decode(tokens: [Int]) -> String
+    func decode(tokens: [Int], skipSpecialTokens: Bool) -> String
 
     func convertTokenToId(_ token: String) -> Int?
     func convertTokensToIds(_ tokens: [String]) -> [Int?]
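
This hunk adds a new protocol requirement: a decode overload that can drop special tokens (BOS/EOS, padding, and similar markers) from the output. A minimal usage sketch, assuming `tokenizer` is some `Tokenizer` conformer (the variable name and input text are illustrative, not part of the change):

// Hypothetical usage of the two decode overloads.
let ids = tokenizer.encode(text: "Hello world")
let raw = tokenizer.decode(tokens: ids)                            // may contain <s>, </s>, ...
let clean = tokenizer.decode(tokens: ids, skipSpecialTokens: true) // markers removed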
@@ -150,6 +151,10 @@ public extension Tokenizer {
     func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
         encode(text: text, addSpecialTokens: addSpecialTokens)
     }
+
+    func decode(tokens: [Int]) -> String {
+        decode(tokens: tokens, skipSpecialTokens: false)
+    }
 
     func convertTokensToIds(_ tokens: [String]) -> [Int?] {
         return tokens.map { convertTokenToId($0) }
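
The extension's default implementation keeps existing call sites source-compatible: `decode(tokens:)` simply forwards with `skipSpecialTokens: false`, so conforming types only have to implement the new two-parameter method and the one-argument form behaves exactly as before. In other words (a sketch; `tokenizer` and `ids` are illustrative):

let a = tokenizer.decode(tokens: ids)
let b = tokenizer.decode(tokens: ids, skipSpecialTokens: false)
assert(a == b) // the default implementation makes these identical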
@@ -315,10 +320,17 @@ public class PreTrainedTokenizer: Tokenizer {
         return encode(text: text, addSpecialTokens: true)
     }
 
-    /// Decode
-    public func decode(tokens: [Int]) -> String {
+    public func decode(tokens: [Int], skipSpecialTokens: Bool = false) -> String {
         // IDs to tokens
-        let tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
+        let tokenStrings: [String]
+        if skipSpecialTokens {
+            let specialTokenIDs = Set(specialTokens.values)
+            tokenStrings = tokens
+                .filter { !specialTokenIDs.contains($0) }
+                .compactMap { model.convertIdToToken($0) }
+        } else {
+            tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
+        }
         let decoded = decodeTokens(tokenStrings)
         // At this point we should have a single String
         return cleanUp(text: decoded.joined(separator: ""))
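
In `PreTrainedTokenizer`, skipping is implemented by filtering the incoming IDs against the tokenizer's `specialTokens` map before converting them to strings; IDs with no known token are dropped by `compactMap` in either branch. A standalone sketch of just the filtering step, with toy values standing in for a real vocabulary:

// Toy stand-ins: in the real class, `specialTokens` maps token strings to IDs.
let specialTokens: [String: Int] = ["<s>": 0, "</s>": 2]
let specialTokenIDs = Set(specialTokens.values) // {0, 2}

let tokens = [0, 15043, 3186, 2]
let kept = tokens.filter { !specialTokenIDs.contains($0) }
print(kept) // [15043, 3186] -- the wrapping <s>/</s> IDs are gone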