@@ -9,6 +9,9 @@ import Foundation
9
9
import Hub
10
10
import Jinja
11
11
12
/// A single chat message handed to the chat template, represented as a string-keyed dictionary.
public typealias Message = [String: Any]

/// A tool description made available to the chat template, represented as a string-keyed dictionary.
public typealias ToolSpec = [String: Any]
14
+
12
15
enum TokenizerError : Error {
13
16
case missingConfig
14
17
case missingTokenizerClassInConfig
@@ -133,22 +136,26 @@ public protocol Tokenizer {
133
136
var unknownTokenId : Int ? { get }
134
137
135
138
/// The appropriate chat template is selected from the tokenizer config
136
- func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
139
+ func applyChatTemplate( messages: [ Message ] ) throws -> [ Int ]
140
+
141
+ /// The appropriate chat template is selected from the tokenizer config, and `tools` is passed along to the template
142
+ func applyChatTemplate( messages: [ Message ] , tools: [ ToolSpec ] ) throws -> [ Int ]
137
143
138
144
/// The chat template is provided as a string literal or specified by name
139
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
145
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
140
146
141
147
/// The chat template is provided as a string literal
142
- func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
148
+ func applyChatTemplate( messages: [ Message ] , chatTemplate: String ) throws -> [ Int ]
143
149
144
150
func applyChatTemplate(
145
- messages: [ [ String : String ] ] ,
151
+ messages: [ Message ] ,
146
152
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
147
153
chatTemplate: ChatTemplateArgument ? ,
148
154
addGenerationPrompt: Bool ,
149
155
truncation: Bool ,
150
156
maxLength: Int ? ,
151
- tools: [ [ String : Any ] ] ?
157
+ tools: [ ToolSpec ] ? ,
158
+ additionalContext: [ String : Any ] ?
152
159
) throws -> [ Int ]
153
160
}
154
161
@@ -356,20 +363,35 @@ public class PreTrainedTokenizer: Tokenizer {
356
363
model. convertIdToToken ( id)
357
364
}
358
365
359
- public func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ] {
366
/// Applies the chat template selected from the tokenizer config to `messages`.
///
/// Convenience overload: forwards to the full `applyChatTemplate` method with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameter messages: The conversation to render.
/// - Returns: The `[Int]` result of the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true
    )
}
362
369
363
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
370
/// Applies the chat template selected from the tokenizer config to `messages`,
/// passing `tools` through to the template.
///
/// Convenience overload: forwards to the full `applyChatTemplate` method with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - tools: Tool descriptions forwarded to the full overload.
/// - Returns: The `[Int]` result of the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        addGenerationPrompt: true,
        tools: tools
    )
}
373
+
374
/// Applies the chat template selected from the tokenizer config to `messages`,
/// passing `tools` and `additionalContext` through to the template.
///
/// Convenience overload: forwards to the full `applyChatTemplate` method with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - tools: Tool descriptions forwarded to the full overload.
///   - additionalContext: Additional keys and values to be added to the context
///     provided to the prompt templating engine.
/// - Returns: The `[Int]` result of the full overload.
/// - Throws: Any error thrown by the full overload.
///
/// NOTE(review): no matching requirement is visible in the `Tokenizer` protocol
/// excerpt, so this overload may only be reachable on the concrete class — confirm.
public func applyChatTemplate(
    messages: [Message],
    tools: [ToolSpec],
    additionalContext: [String: Any]
) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages, addGenerationPrompt: true,
        tools: tools, additionalContext: additionalContext
    )
}
384
+
385
/// Applies a chat template that is provided as a string literal or specified by name.
///
/// Convenience overload: forwards to the full `applyChatTemplate` method with
/// `addGenerationPrompt` set to `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - chatTemplate: The template argument forwarded unchanged to the full overload.
/// - Returns: The `[Int]` result of the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: true
    )
}
366
388
367
- public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ] {
389
/// Applies a chat template that is provided as a string literal.
///
/// Convenience overload: wraps `chatTemplate` in `ChatTemplateArgument.literal` and
/// forwards to the full `applyChatTemplate` method with `addGenerationPrompt` set to
/// `true` and every other option left at its default.
///
/// - Parameters:
///   - messages: The conversation to render.
///   - chatTemplate: The template source text.
/// - Returns: The `[Int]` result of the full overload.
/// - Throws: Any error thrown by the full overload.
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
    let literalTemplate = ChatTemplateArgument.literal(chatTemplate)
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: literalTemplate,
        addGenerationPrompt: true
    )
}
370
392
371
393
public func applyChatTemplate(
372
- messages: [ [ String : String ] ] ,
394
+ messages: [ Message ] ,
373
395
chatTemplate: ChatTemplateArgument ? = nil ,
374
396
addGenerationPrompt: Bool = false ,
375
397
truncation: Bool = false ,
@@ -379,8 +401,8 @@ public class PreTrainedTokenizer: Tokenizer {
379
401
/// giving the name, description and argument types for the tool. See the
380
402
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
381
403
/// for more information.
382
- /// Note: tool calling is not supported yet, it will be available in a future update.
383
- tools : [ [ String : Any ] ] ? = nil
404
+ tools : [ ToolSpec ] ? = nil ,
405
+ additionalContext : [ String : Any ] ? = nil
384
406
) throws -> [ Int ] {
385
407
var selectedChatTemplate : String ?
386
408
if let chatTemplate, case . literal( let template) = chatTemplate {
@@ -425,9 +447,20 @@ public class PreTrainedTokenizer: Tokenizer {
425
447
var context : [ String : Any ] = [
426
448
" messages " : messages,
427
449
" add_generation_prompt " : addGenerationPrompt,
428
- // TODO: Add `tools` entry when support is added in Jinja
429
- // "tools": tools
430
450
]
451
+ if let tools {
452
+ context [ " tools " ] = tools
453
+ }
454
+ if let additionalContext {
455
+ /*
456
+ Additional keys and values to be added to the context provided to the prompt templating engine.
457
+ For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
458
+ The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
459
+ */
460
+ for (key, value) in additionalContext {
461
+ context [ key] = value
462
+ }
463
+ }
431
464
432
465
// TODO: maybe keep NSString here
433
466
for (key, value) in tokenizerConfig. dictionary as [ String : Any ] {
0 commit comments