@@ -9,6 +9,9 @@ import Hub
9
9
import Foundation
10
10
import Jinja
11
11
12
+ public typealias Message = [ String : Any ]
13
+ public typealias ToolSpec = [ String : Any ]
14
+
12
15
enum TokenizerError : Error {
13
16
case missingConfig
14
17
case missingTokenizerClassInConfig
@@ -141,22 +144,26 @@ public protocol Tokenizer {
141
144
var unknownTokenId : Int ? { get }
142
145
143
146
/// The appropriate chat template is selected from the tokenizer config
144
- func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
147
+ func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ]
148
+
149
+ /// The appropriate chat template is selected from the tokenizer config
150
+ func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ]
145
151
146
152
/// The chat template is provided as a string literal or specified by name
147
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
153
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
148
154
149
155
/// The chat template is provided as a string literal
150
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
156
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ]
151
157
152
158
func applyChatTemplate(
153
- messages: [ [ String : String ] ] ,
159
+ messages: [ Message ] ,
154
160
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
155
161
chatTemplate: ChatTemplateArgument ? ,
156
162
addGenerationPrompt: Bool ,
157
163
truncation: Bool ,
158
164
maxLength: Int ? ,
159
- tools: [ [ String : Any ] ] ?
165
+ tools: [ ToolSpec ] ? ,
166
+ additionalContext: [ String : Any ] ?
160
167
) throws -> [ Int ]
161
168
}
162
169
@@ -358,20 +365,35 @@ public class PreTrainedTokenizer: Tokenizer {
358
365
model. convertIdToToken ( id)
359
366
}
360
367
361
- public func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ] {
368
+ public func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ] {
362
369
try applyChatTemplate ( messages: messages, addGenerationPrompt: true )
363
370
}
364
371
365
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
372
+ public func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ] {
373
+ try applyChatTemplate ( messages: messages, addGenerationPrompt: true , tools: tools)
374
+ }
375
+
376
+ public func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] , additionalContext: [ String : Any ] ) throws
377
+ -> [ Int ]
378
+ {
379
+ try applyChatTemplate (
380
+ messages: messages,
381
+ addGenerationPrompt: true ,
382
+ tools: tools,
383
+ additionalContext: additionalContext
384
+ )
385
+ }
386
+
387
+ public func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
366
388
try applyChatTemplate ( messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true )
367
389
}
368
390
369
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ] {
391
+ public func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ] {
370
392
try applyChatTemplate ( messages: messages, chatTemplate: . literal( chatTemplate) , addGenerationPrompt: true )
371
393
}
372
394
373
395
public func applyChatTemplate(
374
- messages: [ [ String : String ] ] ,
396
+ messages: [ Message ] ,
375
397
chatTemplate: ChatTemplateArgument ? = nil ,
376
398
addGenerationPrompt: Bool = false ,
377
399
truncation: Bool = false ,
@@ -381,8 +403,8 @@ public class PreTrainedTokenizer: Tokenizer {
381
403
/// giving the name, description and argument types for the tool. See the
382
404
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
383
405
/// for more information.
384
- /// Note: tool calling is not supported yet, it will be available in a future update.
385
- tools : [ [ String : Any ] ] ? = nil
406
+ tools : [ ToolSpec ] ? = nil ,
407
+ additionalContext : [ String : Any ] ? = nil
386
408
) throws -> [ Int ] {
387
409
var selectedChatTemplate : String ?
388
410
if let chatTemplate, case . literal( let template) = chatTemplate {
@@ -424,10 +446,21 @@ public class PreTrainedTokenizer: Tokenizer {
424
446
let template = try Template ( selectedChatTemplate)
425
447
var context : [ String : Any ] = [
426
448
" messages " : messages,
427
- " add_generation_prompt " : addGenerationPrompt
428
- // TODO: Add `tools` entry when support is added in Jinja
429
- // "tools": tools
449
+ " add_generation_prompt " : addGenerationPrompt,
430
450
]
451
+ if let tools {
452
+ context [ " tools " ] = tools
453
+ }
454
+ if let additionalContext {
455
+ /*
456
+ Additional keys and values to be added to the context provided to the prompt templating engine.
457
+ For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
458
+ The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
459
+ */
460
+ for (key, value) in additionalContext {
461
+ context [ key] = value
462
+ }
463
+ }
431
464
432
465
// TODO: maybe keep NSString here
433
466
for (key, value) in tokenizerConfig. dictionary as [ String : Any ] {
0 commit comments