@@ -9,6 +9,9 @@ import Hub
9
9
import Foundation
10
10
import Jinja
11
11
12
/// A single chat message, as passed in the `messages` argument of `applyChatTemplate`:
/// a dictionary with `String` keys whose values may be of any type.
+ public typealias Message = [ String : Any ]
13
/// A single tool specification, as passed in the `tools` argument of `applyChatTemplate`:
/// a dictionary with `String` keys whose values may be of any type.
+ public typealias ToolSpec = [ String : Any ]
14
+
12
15
enum TokenizerError : Error {
13
16
case missingConfig
14
17
case missingTokenizerClassInConfig
@@ -142,23 +145,57 @@ public protocol Tokenizer {
142
145
var unknownTokenId : Int ? { get }
143
146
144
147
/// The appropriate chat template is selected from the tokenizer config
145
- func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
148
+ func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ]
149
+
150
+ /// The appropriate chat template is selected from the tokenizer config; the given tool specifications are supplied along with the messages
151
+ func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ]
146
152
147
153
/// The chat template is provided as a string literal or specified by name
148
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
154
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
149
155
150
156
/// The chat template is provided as a string literal
151
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
157
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ]
152
158
153
159
func applyChatTemplate(
154
- messages: [ [ String : String ] ] ,
160
+ messages: [ Message ] ,
155
161
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
156
162
chatTemplate: ChatTemplateArgument ? ,
157
163
addGenerationPrompt: Bool ,
158
164
truncation: Bool ,
159
165
maxLength: Int ? ,
160
- tools: [ [ String : Any ] ] ?
166
+ tools: [ ToolSpec ] ?
161
167
) throws -> [ Int ]
168
+
169
+ func applyChatTemplate(
170
+ messages: [ Message ] ,
171
+ /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
172
+ chatTemplate: ChatTemplateArgument ? ,
173
+ addGenerationPrompt: Bool ,
174
+ truncation: Bool ,
175
+ maxLength: Int ? ,
176
+ tools: [ ToolSpec ] ? ,
177
+ additionalContext: [ String : Any ] ?
178
+ ) throws -> [ Int ]
179
+ }
180
+
181
extension Tokenizer {
    /// Default implementation of the `additionalContext` overload, provided for backwards compatibility.
    ///
    /// Conforming types written against the previous requirement (the one without
    /// `additionalContext`) do not have to implement this overload: when no additional
    /// context is supplied, the call is forwarded to the previous signature.
    ///
    /// The method is `public` so that conforming types declared outside this module can
    /// also use it to satisfy the protocol requirement.
    ///
    /// - Parameters:
    ///   - messages: The chat messages, forwarded unchanged.
    ///   - chatTemplate: A chat template can optionally be provided or specified by name when
    ///     several templates are included in the tokenizer config. Normally this is not necessary.
    ///   - addGenerationPrompt: Forwarded unchanged.
    ///   - truncation: Forwarded unchanged.
    ///   - maxLength: Forwarded unchanged.
    ///   - tools: Forwarded unchanged.
    ///   - additionalContext: Must be `nil`; this default implementation cannot honour it.
    /// - Returns: Whatever the previous-signature `applyChatTemplate` returns.
    /// - Throws: `TokenizerError.chatTemplate("Not implemented")` when `additionalContext`
    ///   is not `nil`; otherwise any error thrown by the forwarded call.
    public func applyChatTemplate(
        messages: [Message],
        chatTemplate: ChatTemplateArgument?,
        addGenerationPrompt: Bool,
        truncation: Bool,
        maxLength: Int?,
        tools: [ToolSpec]?,
        additionalContext: [String: Any]?
    ) throws -> [Int] {
        // The previous signature has no way to receive extra context, so fail loudly
        // rather than silently dropping it.
        guard additionalContext == nil else {
            throw TokenizerError.chatTemplate("Not implemented")
        }
        return try applyChatTemplate(
            messages: messages,
            chatTemplate: chatTemplate,
            addGenerationPrompt: addGenerationPrompt,
            truncation: truncation,
            maxLength: maxLength,
            tools: tools
        )
    }
}
163
200
164
201
public extension Tokenizer {
@@ -359,20 +396,46 @@ public class PreTrainedTokenizer: Tokenizer {
359
396
model. convertIdToToken ( id)
360
397
}
361
398
362
- public func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ] {
399
/// Applies the chat template to `messages`, asking for a generation prompt.
///
/// Every other option is left to the defaults of the overload this forwards to.
///
/// - Parameter messages: The chat messages to render.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true
    )
}
365
402
366
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
403
/// Applies the chat template to `messages` together with `tools`, asking for a generation prompt.
///
/// - Parameters:
///   - messages: The chat messages to render.
///   - tools: The tool specifications, forwarded unchanged.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true,
        tools: tools
    )
}
406
+
407
/// Applies the chat template to `messages` with `tools` and extra template context,
/// asking for a generation prompt.
///
/// - Parameters:
///   - messages: The chat messages to render.
///   - tools: The tool specifications, forwarded unchanged.
///   - additionalContext: Extra keys and values, forwarded unchanged.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(
    messages: [Message],
    tools: [ToolSpec],
    additionalContext: [String: Any]
) throws -> [Int] {
    return try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools, additionalContext: additionalContext)
}
417
+
418
/// Applies the chat template identified by `chatTemplate` to `messages`, asking for a generation prompt.
///
/// - Parameters:
///   - messages: The chat messages to render.
///   - chatTemplate: The template argument, forwarded unchanged.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: true
    )
}
369
421
370
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ] {
422
/// Applies a chat template given as a string literal to `messages`, asking for a generation prompt.
///
/// The string is wrapped as `.literal(chatTemplate)` before being forwarded.
///
/// - Parameters:
///   - messages: The chat messages to render.
///   - chatTemplate: The template source text.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: .literal(chatTemplate),
        addGenerationPrompt: true
    )
}
373
425
374
426
/// Overload without `additionalContext`.
///
/// Forwards every argument unchanged to the overload that takes `additionalContext`,
/// passing `nil` for it.
///
/// - Parameters:
///   - messages: The chat messages to render.
///   - chatTemplate: Optional template argument; defaults to `nil`.
///   - addGenerationPrompt: Defaults to `false`.
///   - truncation: Defaults to `false`.
///   - maxLength: Defaults to `nil`.
///   - tools: Optional tool specifications; defaults to `nil`.
/// - Returns: The `[Int]` produced by the forwarded call.
/// - Throws: Any error thrown by the forwarded call.
public func applyChatTemplate(
    messages: [Message],
    chatTemplate: ChatTemplateArgument? = nil,
    addGenerationPrompt: Bool = false,
    truncation: Bool = false,
    maxLength: Int? = nil,
    tools: [ToolSpec]? = nil
) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: addGenerationPrompt,
        truncation: truncation,
        maxLength: maxLength,
        tools: tools,
        additionalContext: nil
    )
}
436
+
437
+ public func applyChatTemplate(
438
+ messages: [ Message ] ,
376
439
chatTemplate: ChatTemplateArgument ? = nil ,
377
440
addGenerationPrompt: Bool = false ,
378
441
truncation: Bool = false ,
@@ -382,8 +445,8 @@ public class PreTrainedTokenizer: Tokenizer {
382
445
/// giving the name, description and argument types for the tool. See the
383
446
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
384
447
/// for more information.
385
- /// Note: tool calling is not supported yet, it will be available in a future update.
386
- tools : [ [ String : Any ] ] ? = nil
448
+ tools : [ ToolSpec ] ? = nil ,
449
+ additionalContext : [ String : Any ] ? = nil
387
450
) throws -> [ Int ] {
388
451
var selectedChatTemplate : String ?
389
452
if let chatTemplate, case . literal( let template) = chatTemplate {
@@ -425,10 +488,21 @@ public class PreTrainedTokenizer: Tokenizer {
425
488
let template = try Template ( selectedChatTemplate)
426
489
var context : [ String : Any ] = [
427
490
" messages " : messages,
428
- " add_generation_prompt " : addGenerationPrompt
429
- // TODO: Add `tools` entry when support is added in Jinja
430
- // "tools": tools
491
+ " add_generation_prompt " : addGenerationPrompt,
431
492
]
493
+ if let tools {
494
+ context [ " tools " ] = tools
495
+ }
496
+ if let additionalContext {
497
+ /*
498
+ Additional keys and values to be added to the context provided to the prompt templating engine.
499
+ For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
500
+ The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
501
+ */
502
+ for (key, value) in additionalContext {
503
+ context [ key] = value
504
+ }
505
+ }
432
506
433
507
// TODO: maybe keep NSString here
434
508
for (key, value) in tokenizerConfig. dictionary as [ String : Any ] {
0 commit comments