Skip to content

Commit be1f482

Browse files
committed
Enable tools
1 parent 92b5072 commit be1f482

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

Sources/Models/LanguageModel.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ public extension LanguageModel {
195195
get async throws {
196196
guard _tokenizer == nil else { return _tokenizer! }
197197
guard let tokenizerConfig = try await tokenizerConfig else {
198-
throw "Cannot retrieve Tokenizer configuration"
198+
throw TokenizerError.tokenizerConfigNotFound
199199
}
200200
let tokenizerData = try await tokenizerData
201201
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
@@ -218,4 +218,6 @@ extension LanguageModel: TextGenerationModel {
218218
}
219219
}
220220

221-
extension String: Error {}
221+
public enum TokenizerError: Error {
222+
case tokenizerConfigNotFound
223+
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import Hub
99
import Foundation
1010
import Jinja
1111

12+
public typealias Message = [String: Any]
13+
public typealias ToolSpec = [String: Any]
14+
1215
enum TokenizerError: Error {
1316
case missingConfig
1417
case missingTokenizerClassInConfig
@@ -134,22 +137,25 @@ public protocol Tokenizer {
134137
var unknownTokenId: Int? { get }
135138

136139
/// The appropriate chat template is selected from the tokenizer config
137-
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
140+
func applyChatTemplate(messages: [Message]) throws -> [Int]
141+
142+
/// The appropriate chat template is selected from the tokenizer config
143+
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]
138144

139145
/// The chat template is provided as a string literal or specified by name
140-
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
146+
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]
141147

142148
/// The chat template is provided as a string literal
143-
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
149+
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
144150

145151
func applyChatTemplate(
146-
messages: [[String: String]],
152+
messages: [Message],
147153
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
148154
chatTemplate: ChatTemplateArgument?,
149155
addGenerationPrompt: Bool,
150156
truncation: Bool,
151157
maxLength: Int?,
152-
tools: [[String: Any]]?
158+
tools: [ToolSpec]?
153159
) throws -> [Int]
154160
}
155161

@@ -358,20 +364,24 @@ public class PreTrainedTokenizer: Tokenizer {
358364
model.convertIdToToken(id)
359365
}
360366

361-
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
367+
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
362368
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
363369
}
364370

365-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
371+
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
372+
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
373+
}
374+
375+
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
366376
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
367377
}
368378

369-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
379+
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
370380
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
371381
}
372382

373383
public func applyChatTemplate(
374-
messages: [[String: String]],
384+
messages: [Message],
375385
chatTemplate: ChatTemplateArgument? = nil,
376386
addGenerationPrompt: Bool = false,
377387
truncation: Bool = false,
@@ -381,8 +391,7 @@ public class PreTrainedTokenizer: Tokenizer {
381391
/// giving the name, description and argument types for the tool. See the
382392
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
383393
/// for more information.
384-
/// Note: tool calling is not supported yet, it will be available in a future update.
385-
tools: [[String: Any]]? = nil
394+
tools: [ToolSpec]? = nil
386395
) throws -> [Int] {
387396
var selectedChatTemplate: String?
388397
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -429,9 +438,12 @@ public class PreTrainedTokenizer: Tokenizer {
429438
var context: [String: Any] = [
430439
"messages": messages,
431440
"add_generation_prompt": addGenerationPrompt,
432-
// TODO: Add `tools` entry when support is added in Jinja
433-
// "tools": tools
434441
]
442+
if let tools {
443+
context["tools"] = tools
444+
// Performance might be better if the tools prompt is included in a system message rather than a user message, but then the system message must be present.
445+
context["tools_in_user_message"] = false // Default is true in Llama 3.1 and 3.2 template
446+
}
435447

436448
// TODO: maybe keep NSString here
437449
for (key, value) in tokenizerConfig.dictionary as [String: Any] {

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class TokenizerTester {
157157
guard _tokenizer == nil else { return _tokenizer! }
158158
do {
159159
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
160-
throw "Cannot retrieve Tokenizer configuration"
160+
throw TokenizerError.tokenizerConfigNotFound
161161
}
162162
let tokenizerData = try await configuration!.tokenizerData
163163
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)

0 commit comments

Comments
 (0)