Skip to content

Commit 6b941a9

Browse files
committed
Enable tools
1 parent fd16c00 commit 6b941a9

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

Sources/Models/LanguageModel.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ extension LanguageModel {
195195
get async throws {
196196
guard _tokenizer == nil else { return _tokenizer! }
197197
guard let tokenizerConfig = try await tokenizerConfig else {
198-
throw "Cannot retrieve Tokenizer configuration"
198+
throw TokenizerError.tokenizerConfigNotFound
199199
}
200200
let tokenizerData = try await tokenizerData
201201
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
@@ -218,4 +218,6 @@ extension LanguageModel: TextGenerationModel {
218218
}
219219
}
220220

221-
extension String: Error {}
221+
public enum TokenizerError: Error {
222+
case tokenizerConfigNotFound
223+
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import Foundation
99
import Hub
1010
import Jinja
1111

12+
public typealias Message = [String: Any]
13+
public typealias ToolSpec = [String: Any]
14+
1215
enum TokenizerError: Error {
1316
case missingConfig
1417
case missingTokenizerClassInConfig
@@ -133,22 +136,26 @@ public protocol Tokenizer {
133136
var unknownTokenId: Int? { get }
134137

135138
/// The appropriate chat template is selected from the tokenizer config
136-
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
139+
func applyChatTemplate(messages: [Message]) throws -> [Int]
140+
141+
/// The appropriate chat template is selected from the tokenizer config, and `tools` is passed to the template context
142+
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]
137143

138144
/// The chat template is provided as a string literal or specified by name
139-
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
145+
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]
140146

141147
/// The chat template is provided as a string literal
142-
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
148+
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
143149

144150
func applyChatTemplate(
145-
messages: [[String: String]],
151+
messages: [Message],
146152
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
147153
chatTemplate: ChatTemplateArgument?,
148154
addGenerationPrompt: Bool,
149155
truncation: Bool,
150156
maxLength: Int?,
151-
tools: [[String: Any]]?
157+
tools: [ToolSpec]?,
158+
additionalContext: [String: Any]?
152159
) throws -> [Int]
153160
}
154161

@@ -356,20 +363,35 @@ public class PreTrainedTokenizer: Tokenizer {
356363
model.convertIdToToken(id)
357364
}
358365

359-
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
366+
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
360367
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
361368
}
362369

363-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
370+
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
371+
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
372+
}
373+
374+
public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
375+
-> [Int]
376+
{
377+
try applyChatTemplate(
378+
messages: messages,
379+
addGenerationPrompt: true,
380+
tools: tools,
381+
additionalContext: additionalContext
382+
)
383+
}
384+
385+
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
364386
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
365387
}
366388

367-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
389+
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
368390
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
369391
}
370392

371393
public func applyChatTemplate(
372-
messages: [[String: String]],
394+
messages: [Message],
373395
chatTemplate: ChatTemplateArgument? = nil,
374396
addGenerationPrompt: Bool = false,
375397
truncation: Bool = false,
@@ -379,8 +401,8 @@ public class PreTrainedTokenizer: Tokenizer {
379401
/// giving the name, description and argument types for the tool. See the
380402
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
381403
/// for more information.
382-
/// Note: tool calling is not supported yet, it will be available in a future update.
383-
tools: [[String: Any]]? = nil
404+
tools: [ToolSpec]? = nil,
405+
additionalContext: [String: Any]? = nil
384406
) throws -> [Int] {
385407
var selectedChatTemplate: String?
386408
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -425,9 +447,20 @@ public class PreTrainedTokenizer: Tokenizer {
425447
var context: [String: Any] = [
426448
"messages": messages,
427449
"add_generation_prompt": addGenerationPrompt,
428-
// TODO: Add `tools` entry when support is added in Jinja
429-
// "tools": tools
430450
]
451+
if let tools {
452+
context["tools"] = tools
453+
}
454+
if let additionalContext {
455+
/*
456+
Additional keys and values to be added to the context provided to the prompt templating engine.
457+
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
458+
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
459+
*/
460+
for (key, value) in additionalContext {
461+
context[key] = value
462+
}
463+
}
431464

432465
// TODO: maybe keep NSString here
433466
for (key, value) in tokenizerConfig.dictionary as [String: Any] {

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ class TokenizerTester {
257257
guard _tokenizer == nil else { return _tokenizer! }
258258
do {
259259
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
260-
throw "Cannot retrieve Tokenizer configuration"
260+
throw TokenizerError.tokenizerConfigNotFound
261261
}
262262
let tokenizerData = try await configuration!.tokenizerData
263263
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)

0 commit comments

Comments
 (0)