@@ -9,13 +9,13 @@ import Hub
9
9
import Foundation
10
10
import Jinja
11
11
12
/// Errors thrown by the tokenizer loading and chat-template code in this file.
enum TokenizerError: Error {
    /// Thrown when no tokenizer configuration is available for the model.
    case missingConfig

    case missingTokenizerClassInConfig

    /// Carries a `String` payload (presumably the unsupported tokenizer name — confirm at the throw site).
    case unsupportedTokenizer(String)

    case missingVocab

    case malformedVocab

    /// A chat template could not be selected; the payload is a human-readable explanation.
    case chatTemplate(String)

    /// Carries a `String` payload describing the failure.
    case tooLong(String)
}
21
21
@@ -94,6 +94,13 @@ struct TokenizerModel {
94
94
}
95
95
}
96
96
97
/// Describes how a caller chooses the chat template applied to a conversation.
public enum ChatTemplateArgument {
    /// The source text of a Jinja template to use for the conversation.
    /// Supplying one is normally unnecessary, because the template is read from the tokenizer config.
    case literal(String)

    /// The name of a template, for models whose tokenizer config includes multiple chat templates.
    /// Normally this is not necessary.
    case name(String)
}
103
+
97
104
public protocol Tokenizer {
98
105
func tokenize( text: String ) -> [ String ]
99
106
@@ -117,15 +124,24 @@ public protocol Tokenizer {
117
124
var eosTokenId : Int ? { get }
118
125
var unknownToken : String ? { get }
119
126
var unknownTokenId : Int ? { get }
120
-
127
+
128
+ /// The appropriate chat template is selected from the tokenizer config
121
129
func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
122
-
130
+
131
+ /// The chat template is provided as a string literal or specified by name
132
+ func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
133
+
134
+ /// The chat template is provided as a string literal
135
+ func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
136
+
123
137
func applyChatTemplate(
124
138
messages: [ [ String : String ] ] ,
125
- chatTemplate: String ? ,
139
+ /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
140
+ chatTemplate: ChatTemplateArgument ? ,
126
141
addGenerationPrompt: Bool ,
127
142
truncation: Bool ,
128
- maxLength: Int ?
143
+ maxLength: Int ? ,
144
+ tools: [ [ String : Any ] ] ?
129
145
) throws -> [ Int ]
130
146
}
131
147
@@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer {
176
192
private let tokenizerConfig : Config
177
193
178
194
private let cleanUpTokenizationSpaces : Bool
179
-
180
- private let defaultChatTemplate : String = " {% for message in messages %}{{'<|im_start|>' + message['role'] + ' \n ' + message['content'] + '<|im_end|>' + ' \n '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant \n ' }}{% endif %} "
181
195
182
196
required public init ( tokenizerConfig: Config , tokenizerData: Config ) throws {
183
197
var addedTokens : [ String : Int ] = [ : ]
@@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer {
222
236
self . decoder = DecoderFactory . fromConfig ( config: tokenizerData. decoder)
223
237
self . cleanUpTokenizationSpaces = tokenizerConfig. cleanUpTokenizationSpaces? . boolValue ?? true
224
238
self . tokenizerConfig = tokenizerConfig
225
-
239
+
226
240
model = try TokenizerModel . from ( tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
227
241
}
228
242
@@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer {
316
330
/// Maps a token id to its token string by delegating to the underlying `model`.
///
/// - Parameter id: The token id to look up.
/// - Returns: Whatever `model.convertIdToToken(_:)` returns for `id`, which may be `nil`.
public func convertIdToToken(_ id: Int) -> String? {
    return model.convertIdToToken(id)
}
319
-
333
+
320
334
/// Applies the chat template from the tokenizer config to `messages`.
///
/// Convenience entry point: forwards to the full overload with `addGenerationPrompt` set to `true`
/// and every other option left at its default.
///
/// - Parameter messages: The conversation, one dictionary per message.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
    let encoded = try applyChatTemplate(messages: messages, addGenerationPrompt: true)
    return encoded
}
337
+
338
/// Applies a chat template, given either as a literal or by name, to `messages`.
///
/// Forwards to the full overload with `addGenerationPrompt` set to `true`.
///
/// - Parameters:
///   - messages: The conversation, one dictionary per message.
///   - chatTemplate: The template to use, as a literal or as a name.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
    let encoded = try applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: true
    )
    return encoded
}
323
-
341
+
342
/// Applies a chat template given as a string literal to `messages`.
///
/// The string is wrapped in `ChatTemplateArgument.literal` and forwarded to the full overload
/// with `addGenerationPrompt` set to `true`.
///
/// - Parameters:
///   - messages: The conversation, one dictionary per message.
///   - chatTemplate: The template source text.
/// - Returns: The `[Int]` produced by the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
    let literalTemplate = ChatTemplateArgument.literal(chatTemplate)
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: literalTemplate,
        addGenerationPrompt: true
    )
}
345
+
324
346
public func applyChatTemplate(
325
347
messages: [ [ String : String ] ] ,
326
- chatTemplate: String ? ,
348
+ chatTemplate: ChatTemplateArgument ? = nil ,
327
349
addGenerationPrompt: Bool = false ,
328
350
truncation: Bool = false ,
329
- maxLength: Int ?
351
+ maxLength: Int ? = nil ,
352
+ /// A list of tools (callable functions) that will be accessible to the model. If the template does not
353
+ /// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
354
+ /// giving the name, description and argument types for the tool. See the
355
+ /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
356
+ /// for more information.
357
+ /// Note: tool calling is not supported yet, it will be available in a future update.
358
+ tools: [ [ String : Any ] ] ? = nil
330
359
) throws -> [ Int ] {
331
- let template = try Template ( chatTemplate ?? tokenizerConfig. chatTemplate? . stringValue ?? defaultChatTemplate)
360
+ var selectedChatTemplate : String ?
361
+ if let chatTemplate, case . literal( let template) = chatTemplate {
362
+ // Use chat template from argument
363
+ selectedChatTemplate = template
364
+ } else if let valueFromConfig = tokenizerConfig. chatTemplate {
365
+ if let arrayValue = valueFromConfig. arrayValue {
366
+ // If the config specifies a list of chat templates, convert them to a dictionary
367
+ let templateDict = Dictionary < String , String > ( uniqueKeysWithValues: arrayValue. compactMap { item in
368
+ guard let name = item. name? . stringValue, let template = item. template? . stringValue else {
369
+ return nil
370
+ }
371
+ return ( name, template)
372
+ } )
373
+ if let chatTemplate, case . name( let name) = chatTemplate {
374
+ // Select chat template from config by name
375
+ if let matchingDictEntry = templateDict [ name] {
376
+ selectedChatTemplate = matchingDictEntry
377
+ } else {
378
+ throw TokenizerError . chatTemplate ( " No chat template named \" \( name) \" was found in the tokenizer config " )
379
+ }
380
+ } else if let tools, !tools. isEmpty, let toolUseTemplate = templateDict [ " tool_use " ] {
381
+ // Use tool use chat template from config
382
+ selectedChatTemplate = toolUseTemplate
383
+ } else if let defaultChatTemplate = templateDict [ " default " ] {
384
+ // Use default chat template from config
385
+ selectedChatTemplate = defaultChatTemplate
386
+ }
387
+ } else if let stringValue = valueFromConfig. stringValue {
388
+ // Use chat template from config
389
+ selectedChatTemplate = stringValue
390
+ }
391
+ }
392
+
393
+ guard let selectedChatTemplate else {
394
+ throw TokenizerError . chatTemplate ( " No chat template was specified " )
395
+ }
396
+
397
+ let template = try Template ( selectedChatTemplate)
332
398
var context : [ String : Any ] = [
333
399
" messages " : messages,
334
400
" add_generation_prompt " : addGenerationPrompt
401
+ // TODO: Add `tools` entry when support is added in Jinja
402
+ // "tools": tools
335
403
]
336
404
337
405
// TODO: maybe keep NSString here
@@ -397,15 +465,15 @@ extension AutoTokenizer {
397
465
398
466
return try AutoTokenizer . from ( tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
399
467
}
400
-
468
+
401
469
/// Creates a tokenizer from the configuration associated with a model folder.
///
/// - Parameters:
///   - modelFolder: The folder handed to `LanguageModelConfigurationFromHub` to locate the configuration.
///   - hubApi: The Hub API instance used to load the configuration. Defaults to `.shared`.
/// - Returns: A `PreTrainedTokenizer` initialized with the loaded tokenizer config and tokenizer data.
/// - Throws: `TokenizerError.missingConfig` when the tokenizer config is `nil`, or any error raised
///   while loading the configuration or initializing the tokenizer.
public static func from(
    modelFolder: URL,
    hubApi: HubApi = .shared
) async throws -> Tokenizer {
    let configuration = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)

    // The tokenizer config is required; without it there is nothing to build from.
    let loadedTokenizerConfig = try await configuration.tokenizerConfig
    guard let tokenizerConfig = loadedTokenizerConfig else {
        throw TokenizerError.missingConfig
    }
    let tokenizerData = try await configuration.tokenizerData

    // NOTE(review): this always instantiates `PreTrainedTokenizer` directly, while the loader
    // defined just before it goes through `AutoTokenizer.from(tokenizerConfig:tokenizerData:)`.
    // Confirm the difference is intentional.
    return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}
411
479
}
0 commit comments