@@ -7,6 +7,7 @@
 
 import Hub
 import Foundation
+import Jinja
 
 enum TokenizerError: Error {
     case missingConfig
@@ -98,7 +99,8 @@ public protocol Tokenizer {
 
     /// Main entry point
     func encode(text: String) -> [Int]
-    func callAsFunction(_ text: String) -> [Int]
+    func encode(text: String, addSpecialTokens: Bool) -> [Int]
+    func callAsFunction(_ text: String, addSpecialTokens: Bool) -> [Int]
 
     /// Decode
     func decode(tokens: [Int]) -> String
@@ -115,11 +117,21 @@ public protocol Tokenizer {
     var eosTokenId: Int? { get }
     var unknownToken: String? { get }
     var unknownTokenId: Int? { get }
+
+    func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
+
+    func applyChatTemplate(
+        messages: [[String: String]],
+        chatTemplate: String?,
+        addGenerationPrompt: Bool,
+        truncation: Bool,
+        maxLength: Int?
+    ) throws -> [Int]
 }
 
 public extension Tokenizer {
-    func callAsFunction(_ text: String) -> [Int] {
-        encode(text: text)
+    func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
+        encode(text: text, addSpecialTokens: addSpecialTokens)
     }
 
     func convertTokensToIds(_ tokens: [String]) -> [Int?] {
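The default argument in the extension keeps existing call sites compiling while letting new ones opt out of special tokens. A minimal sketch of the two call styles, assuming this package's `Tokenizers` module (the helper function and input string are illustrative, not part of the diff):

```swift
import Tokenizers

// Sketch only: `tokenizer` is any Tokenizer-conforming value obtained elsewhere.
func encodeBothWays(_ tokenizer: Tokenizer) -> ([Int], [Int]) {
    // Default call: the post-processor may wrap the output with special tokens.
    let withSpecials = tokenizer("Hello world")
    // Opting out: the new addSpecialTokens flag is forwarded to encode(text:addSpecialTokens:).
    let bare = tokenizer("Hello world", addSpecialTokens: false)
    return (withSpecials, bare)
}
```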
@@ -131,6 +143,17 @@ public extension Tokenizer {
     }
 }
 
+let specialTokenAttributes: [String] = [
+    "bos_token",
+    "eos_token",
+    "unk_token",
+    "sep_token",
+    "pad_token",
+    "cls_token",
+    "mask_token",
+    "additional_special_tokens"
+]
+
 public class PreTrainedTokenizer: Tokenizer {
     let model: TokenizingModel
 
@@ -150,8 +173,11 @@ public class PreTrainedTokenizer: Tokenizer {
     private let normalizer: Normalizer?
     private let postProcessor: PostProcessor?
     private let decoder: Decoder?
+    private let tokenizerConfig: Config
 
     private let cleanUpTokenizationSpaces: Bool
+
+    private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 
     required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
         var addedTokens: [String: Int] = [:]
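The fallback template is the ChatML layout. For illustration, rendering it over a single user message (an invented example, not from the diff) with `add_generation_prompt` set to true would produce:

```
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
```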
@@ -195,7 +221,8 @@ public class PreTrainedTokenizer: Tokenizer {
         self.postProcessor = PostProcessorFactory.fromConfig(config: tokenizerData.postProcessor)
         self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
         self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true
-
+        self.tokenizerConfig = tokenizerConfig
+
         model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
     }
 
@@ -209,9 +236,9 @@ public class PreTrainedTokenizer: Tokenizer {
         return normalizer(text: text)
     }
 
-    func postProcess(_ tokens: [String]) -> [String] {
+    func postProcess(_ tokens: [String], addSpecialTokens: Bool = true) -> [String] {
         guard let postProcessor = postProcessor else { return tokens }
-        return postProcessor(tokens: tokens)
+        return postProcessor(tokens: tokens, addSpecialTokens: addSpecialTokens)
     }
 
     func decodeTokens(_ tokens: [String]) -> [String] {
@@ -265,8 +292,12 @@ public class PreTrainedTokenizer: Tokenizer {
     }
 
     /// Main entry point
+    public func encode(text: String, addSpecialTokens: Bool = true) -> [Int] {
+        return postProcess(tokenize(text: text), addSpecialTokens: addSpecialTokens).map { model.convertTokenToId($0)! }
+    }
+
     public func encode(text: String) -> [Int] {
-        return postProcess(tokenize(text: text)).map { model.convertTokenToId($0)! }
+        return encode(text: text, addSpecialTokens: true)
     }
 
     /// Decode
@@ -285,6 +316,43 @@ public class PreTrainedTokenizer: Tokenizer {
     public func convertIdToToken(_ id: Int) -> String? {
         model.convertIdToToken(id)
     }
+
+    public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
+        try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
+    }
+
+    public func applyChatTemplate(
+        messages: [[String: String]],
+        chatTemplate: String?,
+        addGenerationPrompt: Bool = false,
+        truncation: Bool = false,
+        maxLength: Int?
+    ) throws -> [Int] {
+        let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
+        var context: [String: Any] = [
+            "messages": messages,
+            "add_generation_prompt": addGenerationPrompt
+        ]
+
+        // TODO: maybe keep NSString here
+        for (key, value) in tokenizerConfig.dictionary as [String: Any] {
+            if specialTokenAttributes.contains(key), !(value is NSNull) {
+                context[key] = value
+            }
+        }
+
+        let rendered = try template.render(context)
+        var encodedTokens = encode(text: rendered, addSpecialTokens: false)
+        var maxLength = maxLength ?? encodedTokens.count
+        maxLength = min(maxLength, tokenizerConfig.modelMaxLength?.intValue ?? maxLength)
+        if encodedTokens.count > maxLength {
+            if truncation {
+                encodedTokens = Array(encodedTokens.prefix(maxLength))
+            }
+        }
+
+        return encodedTokens
+    }
 }
 
 // MARK: - Building
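Taken together, the new surface can be exercised end to end. A minimal usage sketch, assuming a tokenizer loaded through the package's `AutoTokenizer.from(pretrained:)` entry point (the checkpoint name and message contents are placeholders, not part of this diff):

```swift
import Tokenizers

// Placeholder checkpoint; any repo with tokenizer_config.json works the same way.
let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-4k-instruct")

let messages = [
    ["role": "user", "content": "Describe the Swift programming language."]
]

// Convenience overload: uses the config's chat_template (or the ChatML fallback),
// appends the generation prompt, and applies no truncation.
let inputIds = try tokenizer.applyChatTemplate(messages: messages)

// Full overload: cap the rendered prompt at 512 tokens, truncating if needed.
let capped = try tokenizer.applyChatTemplate(
    messages: messages,
    chatTemplate: nil,
    addGenerationPrompt: true,
    truncation: true,
    maxLength: 512
)
```

Note that the rendered template is encoded with `addSpecialTokens: false`, since chat templates are expected to spell out their own special tokens.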