Skip to content

Commit f7277e2

Browse files
committed
Enable tools
1 parent 8a83416 commit f7277e2

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

Sources/Models/LanguageModel.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,9 @@ public extension LanguageModel {
190190
var tokenizer: Tokenizer {
191191
get async throws {
192192
guard _tokenizer == nil else { return _tokenizer! }
193-
guard let tokenizerConfig = try await tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
193+
guard let tokenizerConfig = try await tokenizerConfig else {
194+
throw TokenizerError.tokenizerConfigNotFound
195+
}
194196
let tokenizerData = try await tokenizerData
195197
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
196198
return _tokenizer!
@@ -212,4 +214,6 @@ extension LanguageModel: TextGenerationModel {
212214
}
213215
}
214216

215-
extension String: Error {}
217+
public enum TokenizerError: Error {
218+
case tokenizerConfigNotFound
219+
}

Sources/Tokenizers/Tokenizer.swift

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import Hub
99
import Foundation
1010
import Jinja
1111

12+
public typealias Message = [String: Any]
13+
public typealias ToolSpec = [String: Any]
14+
1215
enum TokenizerError: Error {
1316
case missingConfig
1417
case missingTokenizerClassInConfig
@@ -141,22 +144,26 @@ public protocol Tokenizer {
141144
var unknownTokenId: Int? { get }
142145

143146
/// The appropriate chat template is selected from the tokenizer config
144-
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]
147+
func applyChatTemplate(messages: [Message]) throws -> [Int]
148+
149+
/// The appropriate chat template is selected from the tokenizer config
150+
func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int]
145151

146152
/// The chat template is provided as a string literal or specified by name
147-
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]
153+
func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int]
148154

149155
/// The chat template is provided as a string literal
150-
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]
156+
func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int]
151157

152158
func applyChatTemplate(
153-
messages: [[String: String]],
159+
messages: [Message],
154160
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
155161
chatTemplate: ChatTemplateArgument?,
156162
addGenerationPrompt: Bool,
157163
truncation: Bool,
158164
maxLength: Int?,
159-
tools: [[String: Any]]?
165+
tools: [ToolSpec]?,
166+
additionalContext: [String: Any]?
160167
) throws -> [Int]
161168
}
162169

@@ -358,20 +365,35 @@ public class PreTrainedTokenizer: Tokenizer {
358365
model.convertIdToToken(id)
359366
}
360367

361-
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
368+
public func applyChatTemplate(messages: [Message]) throws -> [Int] {
362369
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
363370
}
364371

365-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
372+
public func applyChatTemplate(messages: [Message], tools: [ToolSpec]) throws -> [Int] {
373+
try applyChatTemplate(messages: messages, addGenerationPrompt: true, tools: tools)
374+
}
375+
376+
public func applyChatTemplate(messages: [Message], tools: [ToolSpec], additionalContext: [String: Any]) throws
377+
-> [Int]
378+
{
379+
try applyChatTemplate(
380+
messages: messages,
381+
addGenerationPrompt: true,
382+
tools: tools,
383+
additionalContext: additionalContext
384+
)
385+
}
386+
387+
public func applyChatTemplate(messages: [Message], chatTemplate: ChatTemplateArgument) throws -> [Int] {
366388
try applyChatTemplate(messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true)
367389
}
368390

369-
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
391+
public func applyChatTemplate(messages: [Message], chatTemplate: String) throws -> [Int] {
370392
try applyChatTemplate(messages: messages, chatTemplate: .literal(chatTemplate), addGenerationPrompt: true)
371393
}
372394

373395
public func applyChatTemplate(
374-
messages: [[String: String]],
396+
messages: [Message],
375397
chatTemplate: ChatTemplateArgument? = nil,
376398
addGenerationPrompt: Bool = false,
377399
truncation: Bool = false,
@@ -381,8 +403,8 @@ public class PreTrainedTokenizer: Tokenizer {
381403
/// giving the name, description and argument types for the tool. See the
382404
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
383405
/// for more information.
384-
/// Note: tool calling is not supported yet, it will be available in a future update.
385-
tools: [[String: Any]]? = nil
406+
tools: [ToolSpec]? = nil,
407+
additionalContext: [String: Any]? = nil
386408
) throws -> [Int] {
387409
var selectedChatTemplate: String?
388410
if let chatTemplate, case .literal(let template) = chatTemplate {
@@ -424,10 +446,21 @@ public class PreTrainedTokenizer: Tokenizer {
424446
let template = try Template(selectedChatTemplate)
425447
var context: [String: Any] = [
426448
"messages": messages,
427-
"add_generation_prompt": addGenerationPrompt
428-
// TODO: Add `tools` entry when support is added in Jinja
429-
// "tools": tools
449+
"add_generation_prompt": addGenerationPrompt,
430450
]
451+
if let tools {
452+
context["tools"] = tools
453+
}
454+
if let additionalContext {
455+
/*
456+
Additional keys and values to be added to the context provided to the prompt templating engine.
457+
For example, the app could set "tools_in_user_message" to false for Llama 3.1 and 3.2 if a system message is provided.
458+
The default value is true in the Llama 3.1 and 3.2 chat templates, but these models will perform better if the tools are included in a system message.
459+
*/
460+
for (key, value) in additionalContext {
461+
context[key] = value
462+
}
463+
}
431464

432465
// TODO: maybe keep NSString here
433466
for (key, value) in tokenizerConfig.dictionary as [String : Any] {

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,9 @@ class TokenizerTester {
239239
get async {
240240
guard _tokenizer == nil else { return _tokenizer! }
241241
do {
242-
guard let tokenizerConfig = try await configuration!.tokenizerConfig else { throw "Cannot retrieve Tokenizer configuration" }
242+
guard let tokenizerConfig = try await configuration!.tokenizerConfig else {
243+
throw TokenizerError.tokenizerConfigNotFound
244+
}
243245
let tokenizerData = try await configuration!.tokenizerData
244246
_tokenizer = try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
245247
} catch {

0 commit comments

Comments
 (0)